kd-tree (k-dimensional tree)

概要

2次元平面上の点集合を kd-tree（x 座標と y 座標で交互に分割する二分木）で管理し、矩形領域内の点の列挙と最近傍点探索（二乗距離を返す）を行う。

計算量

  • 構築: \(O(N \log^2 N)\)

  • 領域探索クエリ: \(O(\sqrt{N} + K)\) (\(K\) は出力する点の数)

  • 最近傍点クエリ: 点集合が一様分布している場合の期待計算量 \(O(\log N)\)（最悪計算量は \(O(N)\)）

実装

#include<vector>
#include<algorithm>
using namespace std;
using ll = long long;

// Capacity of the global point buffer p (maximum number of points that can be stored).
#define N 500007

// A point in the plane plus an identifier that the range-search query reports.
struct Point {
  ll x;   // x coordinate
  ll y;   // y coordinate
  int i;  // identifier pushed into the range-search result
};

// NOTE(review): n is not read by the code shown here; presumably the caller
// stores the number of points in use in it.
int n = 0;
// Point buffer; make() permutes the ranges it builds over in place.
Point p[N] = {};

bool comp_x(const Point &p1, const Point &p2) { return p1.x < p2.x; }
bool comp_y(const Point &p1, const Point &p2) { return p1.y < p2.y; }

// One kd-tree node: the point stored at this node and its two subtrees
// (nullptr when a subtree is empty).
struct Node {
  Node *left;   // built by make() from the elements ordered before p on this node's axis
  Node *right;  // built by make() from the elements ordered after p on this node's axis
  Point p;      // splitting point stored at this node

  Node(Node *lo, Node *hi, Point pt) : left(lo), right(hi), p(pt) {}
};

// Root of the kd-tree; presumably assigned by the caller from make(...). Null until then.
Node *root = nullptr;

// Build a kd-tree over the half-open range p[l, r) and return its root
// (nullptr for an empty range).
// The splitting axis alternates by depth: odd depth splits on x, even depth on y.
// Each call fully sorts its range along the axis and stores the median element
// as the node. The left child is built from p[l, mid) and the right child from
// p[mid+1, r).
// Side effect: permutes the global array p within [l, r).
Node* make(int l, int r, int depth) {
  if(l >= r) return nullptr;

  bool (*by_axis)(const Point &, const Point &) = (depth % 2) ? comp_x : comp_y;
  sort(p + l, p + r, by_axis);

  // Same value as (l + r) >> 1 for l < r, written without the l + r sum.
  const int mid = l + (r - l) / 2;
  Node *lower = make(l, mid, depth + 1);
  Node *upper = make(mid + 1, r, depth + 1);
  return new Node(lower, upper, p[mid]);
}

// Nearest-neighbour search.
// Returns the smaller of r and the minimum SQUARED Euclidean distance from
// (x, y) to any point in the subtree rooted at nd.
//   r     : best squared distance found so far, or -1 if none yet.
//           Returns -1 only if nd is empty and r == -1.
//   depth : must equal the depth that make() used for nd (0 for a root built
//           with make(l, r, 0)), so the splitting axis matches.
ll find(Node *nd, ll x, ll y, int depth, ll r) {
  if(nd == nullptr) return r;
  const Point &q = nd->p;
  const ll dx = x - q.x;
  const ll dy = y - q.y;
  const ll d = dx*dx + dy*dy;
  if(r == -1 || d < r) r = d;

  // Signed offset of the query from this node's splitting line.
  // Must mirror make(): odd depth splits on x, even depth on y.
  const ll diff = (depth % 2) ? dx : dy;

  // The left subtree holds coordinates <= the split value, the right one >=.
  // Search the side containing the query first so r shrinks before the far
  // side is tested.
  Node *near_side = (diff <= 0) ? nd->left : nd->right;
  Node *far_side  = (diff <= 0) ? nd->right : nd->left;
  r = find(near_side, x, y, depth + 1, r);

  // Every far-side point is at least |diff| away along this axis. Compare the
  // squared offset with r: r is a squared distance, so comparing it with a
  // linear offset (as before) prunes almost nothing.
  if(diff*diff <= r) {
    r = find(far_side, x, y, depth + 1, r);
  }
  return r;
}

// Range search.
// Appends to result the identifier i of every point in the subtree rooted at
// nd that lies in the closed rectangle [sx, tx] x [sy, ty].
// Points are reported in pre-order: node, then left subtree, then right subtree.
//   depth : must equal the depth that make() used for nd (0 for a root built
//           with make(l, r, 0)), so the splitting axis matches.
void find(Node *nd, vector<int> &result, ll sx, ll tx, ll sy, ll ty, int depth) {
  // An empty (sub)tree, e.g. a root built from zero points, has nothing to
  // report. Previously nd was dereferenced unconditionally.
  if(nd == nullptr) return;

  const Point &q = nd->p;
  if(sx <= q.x && q.x <= tx && sy <= q.y && q.y <= ty) {
    result.push_back(q.i);
  }

  // Must mirror make(): odd depth splits on x, even depth on y.
  const bool split_on_x = (depth % 2) != 0;
  const ll key = split_on_x ? q.x : q.y;
  const ll lo  = split_on_x ? sx : sy;
  const ll hi  = split_on_x ? tx : ty;

  // The left subtree holds coordinates <= key, the right one >= key.
  if(lo <= key) find(nd->left, result, sx, tx, sy, ty, depth + 1);
  if(key <= hi) find(nd->right, result, sx, tx, sy, ty, depth + 1);
}

Verified

  • AOJ: "DSL_2_C: Range Query - Range Search (kD Tree)": source (C++14, 0.45sec)

  • AtCoder: "AtCoder Regular Contest 010 - D問題: 情報伝搬": source (C++14, 3392ms)

参考


戻る