KD-Tree

在一堆工程向的代码里翻来翻去总算还是学会了。

KD-Tree

作用

用来解决一类多维空间内的统计问题,在OI中多用于二维平面内的统计问题,如K邻近查找、二维数点。

思想:

其实就是把一个空间割成了一块块,然后暴力统计贡献。以二维空间为例子,每次横着找个中点切一刀,再竖着找个中点切一刀,KD树上的一个节点代表了一个矩形。

二维统计就是类似于线段树,如果当前矩形被完全包含于询问矩形,就把贡献加上。

K临近查找就是在KD-TreeDFS,只不过最优化剪枝了一下而已,但是此算法复杂度似乎被证明是O(n\sqrt{n})的了。

重构

对于插入,我们类似于平衡树的插入即可。

然而这样可能树的高度会有问题,每个点代表的矩形可能也会被拉成扁平状,所以要类似于替罪羊树的暴力重构。

注意:这里的重构要在不平衡的最高点进行,不然一次重构可能会被直接卡成O(n^2)

细节

1.KD-Tree节点的左右儿子是不包含当前节点表示的真实节点的,所以统计的时候要另算。

2.up/kill的时候要判断是否有左右儿子。

3.kill的时候要把节点删干净,一般左右儿子是一定要删的,其他看情况。

优化技巧

其实就是与出题人斗智斗勇,尽量防止自己被卡。

1.K邻近查找和多维最值查询的时候一定要最优化剪枝

2.如果对于实数领域的几何类问题,最好在读入之后把点随机旋转一个角度,这样建出来的KD-Tree会更加平均,可以防止被卡成长条状。

Code:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<ctime>
using namespace std;
#define gc getchar()
inline int read(){
register int x(0),f(1);register char c(gc);
while(c>'9'||c<'0')f=c=='-'?-1:1,c=gc;
while(c>='0'&&c<='9')x=x*10+c-48,c=gc;
return f*x;
}
const int N=610000;
struct node{
int son[2],d[2],mn[2],mx[2],siz;
}t[N];
struct node1{
int d[2];
}a[N];
bool cmp0(node1 x,node1 y){return x.d[0]<y.d[0];}
bool cmp1(node1 x,node1 y){return x.d[1]<y.d[1];}
#define mid ((l+r)>>1)
#define ls(x) t[x].son[0]
#define rs(x) t[x].son[1]
#define cmin(x,y) (x=(x<y?x:y))
#define cmax(x,y) (x=(x>y?x:y))
int sta[N],top,cnt;
int n,m,root,Num;
const double alpha=0.75;
inline int get(){return top?sta[top--]:++cnt;}
inline void up(int x){
if(ls(x)){
cmin(t[x].mn[0],t[ls(x)].mn[0]);cmin(t[x].mn[1],t[ls(x)].mn[1]);
cmax(t[x].mx[0],t[ls(x)].mx[0]);cmax(t[x].mx[1],t[ls(x)].mx[1]);
}
if(rs(x)){
cmin(t[x].mn[1],t[rs(x)].mn[1]);cmax(t[x].mx[0],t[rs(x)].mx[0]);
cmax(t[x].mx[1],t[rs(x)].mx[1]);cmin(t[x].mn[0],t[rs(x)].mn[0]);
}
t[x].siz=t[ls(x)].siz+t[rs(x)].siz+1;
}

int build(int l,int r,int kind){
nth_element(a+l,a+mid,a+1+r,(kind)?cmp1:cmp0);
int x=get();
t[x].d[0]=t[x].mn[0]=t[x].mx[0]=a[mid].d[0];
t[x].d[1]=t[x].mn[1]=t[x].mx[1]=a[mid].d[1];
ls(x)=rs(x)=0;
if(l<mid)ls(x)=build(l,mid-1,kind^1);
if(r>mid)rs(x)=build(mid+1,r,kind^1);
up(x);
return x;
}
inline bool check(int x){
return (t[ls(x)].siz>alpha*t[x].siz||t[rs(x)].siz>alpha*t[x].siz);
}
void kill(int x){
if(ls(x))kill(ls(x));
if(rs(x))kill(rs(x));
a[++Num].d[0]=t[x].d[0];
a[Num].d[1]=t[x].d[1];
sta[++top]=x;
}

inline int rebuild(int u,int kind){
Num=0;
kill(u);
u=build(1,Num,kind);
return u;
}
int insert(int x,node1 u,int kind){
if(!x){
x=get();
t[x].d[0]=t[x].mn[0]=t[x].mx[0]=u.d[0];
t[x].d[1]=t[x].mn[1]=t[x].mx[1]=u.d[1];
ls(x)=rs(x)=0;
t[x].siz=1;
return x;
}
if(check(x))x=rebuild(x,kind);
if(u.d[kind]<t[x].d[kind])ls(x)=insert(ls(x),u,kind^1);
else rs(x)=insert(rs(x),u,kind^1);
up(x);
return x;
}
inline int Dis(int u,int x,int y){
return abs(t[u].d[0]-x)+abs(t[u].d[1]-y);
}
inline int Dist(int u,int x,int y){
return max(y-t[u].mx[1],0)+max(t[u].mn[1]-y,0)+max(x-t[u].mx[0],0)+max(t[u].mn[0]-x,0);
}
int ans;
void query(int u,int x,int y){
ans=min(ans,Dis(u,x,y));
int d[2];
d[0]=ls(u)?Dist(ls(u),x,y):2e9;
d[1]=rs(u)?Dist(rs(u),x,y):2e9;
int tmp=(d[1]<d[0]?1:0);
if(ans>d[tmp])query(t[u].son[tmp],x,y);
if(ans>d[tmp^1])query(t[u].son[tmp^1],x,y);
}
int main(){
n=read();m=read();
register int i;
for(i=1;i<=n;i++)a[i].d[0]=read(),a[i].d[1]=read();
root=build(1,n,0);
for(i=1;i<=m;i++){
int opt=read();
if(opt==1){
int x=read(),y=read();
root=insert(root,(node1){x,y},0);
}
else{
int x=read(),y=read();
ans=2e9;query(root,x,y);
cout<<ans<<'\n';
}
}
return 0;
}
0%