【JZOJ4587】Snow的追寻 题解

题目大意

  有一棵有 $n$ 个节点的根节点为 1 的树,他只能走一条不经过重复节点的路径。
  给出 $q$ 个形如“$x\ y$”的询问,表示他不能走到 $x$ 和 $y$ 的子树中。现在他想知道对于每组询问,他能走的最长路径是多少,如果没有,输出 0。
  $n, q \le 10^5$

【显然】

  可走的地方还是一棵树,而我们要求的路径就是这棵树的直径。

【50%】$n,q \le 2000$

  每个询问暴力找直径。

【100%】$n,q \le 10^5$

  我们需要用线段树维护树的直径
  去掉两棵子树,相当于在 dfs 序上去掉两个区间(注意这两个区间可能有交或存在包含关系),我们把剩下的区间合并就行了。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#include<cmath>
#include<cstdio>
#include<algorithm>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)
using namespace std;

typedef long long LL;

const int maxn=(1e5)+5, MX=18;

struct TR{
int x,y,len;

TR(int X=0,int Y=0,int LEN=0) {x=X, y=Y, len=LEN;}
};

int n;

int tot,go[2*maxn],next[2*maxn],f1[maxn];
void ins(int x,int y)
{
go[++tot]=y;
next[tot]=f1[x];
f1[x]=tot;
}

int fa[2*maxn][MX+5],deep[maxn],ap[2*maxn],fir[2*maxn],Log[2*maxn],er[MX+5];
void rmq_pre()
{
fo(i,1,ap[0]) fa[i][0]=ap[i], Log[i]=log(i)/log(2);
fo(i,0,MX) er[i]=1<<i;
fo(j,1,MX)
fo(i,1,ap[0])
{
fa[i][j]=fa[i][j-1];
if (i+er[j-1]<=ap[0] && deep[fa[i+er[j-1]][j-1]]<deep[fa[i][j]])
fa[i][j]=fa[i+er[j-1]][j-1];
}
}
int lca(int x,int y)
{
x=fir[x], y=fir[y];
if (x>y) swap(x,y);
int t=Log[y-x+1];
return (deep[fa[x][t]]<deep[fa[y-er[t]+1][t]]) ?fa[x][t] :fa[y-er[t]+1][t] ;
}

int st[maxn],en[maxn],sum,Tbh[maxn];
void dfs_pre(int k,int last)
{
deep[k]=deep[last]+1;
ap[++ap[0]]=k, fir[k]=ap[0];
Tbh[++sum]=k, st[k]=sum;
for(int p=f1[k]; p; p=next[p]) if (go[p]!=last)
{
dfs_pre(go[p],k);
ap[++ap[0]]=k;
}
en[k]=sum;
}

TR tr[4*maxn];
int DIS(int x,int y) {return deep[x]+deep[y]-deep[lca(x,y)]*2;}
TR merge(TR a,TR b)
{
if (a.len==-1) return b;
TR re= (a.len>b.len) ?a :b;
if (DIS(a.x,b.x)>re.len) re=TR(a.x,b.x,DIS(a.x,b.x));
if (DIS(a.x,b.y)>re.len) re=TR(a.x,b.y,DIS(a.x,b.y));
if (DIS(a.y,b.x)>re.len) re=TR(a.y,b.x,DIS(a.y,b.x));
if (DIS(a.y,b.y)>re.len) re=TR(a.y,b.y,DIS(a.y,b.y));
return re;
}
void tr_js(int k,int l,int r)
{
if (l==r)
{
tr[k].x=tr[k].y=Tbh[l];
tr[k].len=0;
return;
}
int t=k<<1, t1=(l+r)>>1;
tr_js(t,l,t1), tr_js(t+1,t1+1,r);
tr[k]=merge(tr[t],tr[t+1]);
}
TR tr_cx(int k,int l,int r,int x,int y)
{
if (l==x && r==y) return tr[k];
int t=k<<1, t1=(l+r)>>1;
if (y<=t1) return tr_cx(t,l,t1,x,y);
else if (x>t1) return tr_cx(t+1,t1+1,r,x,y);
else return merge(tr_cx(t,l,t1,x,t1),tr_cx(t+1,t1+1,r,t1+1,y));
}

int q;
int main()
{
freopen("snow.in","r",stdin);
freopen("snow.out","w",stdout);

scanf("%d %d",&n,&q);
fo(i,1,n-1)
{
int x,y;;
scanf("%d %d",&x,&y);
ins(x,y), ins(y,x);
}

dfs_pre(1,0);
rmq_pre();
tr_js(1,1,n);

while (q--)
{
int x,y;
scanf("%d %d",&x,&y);

if (st[x]>st[y]) swap(x,y);
TR ans=TR(0,0,-1);
if (1<st[x]) ans=merge(ans,tr_cx(1,1,n,1,st[x]-1));
if (en[x]+1<st[y]) ans=merge(ans,tr_cx(1,1,n,en[x]+1,st[y]-1));
int En=max(en[x],en[y]);
if (En<n) ans=merge(ans,tr_cx(1,1,n,En+1,n));

printf("%d\n",(ans.len==-1) ?0 :ans.len );
}
}