【2022icpc Regional 南京 E】Color the Tree 题解

题目大意

  有一棵 $n$ 个节点的树,初始节点颜色全白,每次操作可以选择一个节点 $u$ 和一个距离 $i$,将 $u$ 子树内距离它 $i$ 的点全部染黑。一次距离为 $i$ 的操作的代价为 $a_i$。求最小代价把整棵树染黑。

  $n \leq 10^5,\ \ 1 \leq a_i \leq 10^9$
  多测,$\sum n \leq 3 \times 10^5$,3s

题解

  官方题解给了个建虚树的做法,但这题更有长链剖分的味道,长链剖分的复杂度也更优秀(确切地说,预处理 rmq 要 $O(n \log n)$,剩余都是 $O(n)$)。

  首先我们很容易想到一些关于深度的树形 dp:记 $dp_{x,i}$ 表示以 $x$ 为根的子树内与 $x$ 距离为 $i$ 的点全染黑的最小代价,那么就有

  这种深度相关的树形 dp 看着就很能长链剖分。在长链剖分的框架下,子树合并没什么大问题,唯一的问题在于把 $dp_x$ 向 $dp_{fa_x}$ 转移的时候,需要 $\min$ 上一些 $a_i$。
  但仔细一想,假如某个时刻 $dp_{x,i}$ 与 $a_i$ 取 $\min$ 了,转移到 $fa_x$ 的时候这个值没有被动过(即没有跟 $fa_x$ 的其他子树合并),那么它就要 $\min$ 上 $a_{i+1}$;如果转移到 $fa_{fa_x}$ 它还是没有被动过,那么它就要 $\min$ 上 $a_{i+2}$……所以可以发现,如果一个 dp 值它没有被动过,最终它就会被 $\min$ 上 $a$ 序列的一个区间最小值。那我们就不着急每次都去 $\min$ 了,每当一个 dp 值被更新了以后,我们给它打上一个懒标记 $t$,表示它跟 $a_t$ 取 $\min$ 了,等下次它被更新或者最终求答案的时候,假设那时它离根的距离为 $t’$,那么它就要跟 $a_t, a_{t+1}, \cdots, a_{t’}$ 取 $\min$,这个 rmq 一下就好了。
  有了这个懒标记,长链剖分做 dp 就是线性的了。不过很不幸预处理 rmq 还是 $O(n \log n)$ 的。。。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#include<bits/stdc++.h>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)
using namespace std;

typedef long long LL;

const int maxn=1e5+5, MX=17;

int n;
LL a[maxn];
vector<int> e[maxn];

int Log[maxn];
LL nmin[MX+2][maxn];
void rmq_pre() {
fo(i,0,n-1) nmin[0][i]=a[i];
fo(i,2,n) Log[i]=Log[i>>1]+1;
fo(j,1,MX)
fo(i,0,n-1) nmin[j][i]=min(nmin[j-1][i],nmin[j-1][i+(1<<(j-1))]);
}
LL rmq(int l,int r) {
int t=Log[r-l+1];
return min(nmin[t][l],nmin[t][r-(1<<t)+1]);
}

int deepest[maxn],lson[maxn],st[maxn],en[maxn],tot;
void dfs_link(int k,int last,int deep) {
deepest[k]=deep;
for(int son:e[k]) if (son!=last) {
dfs_link(son,k,deep+1);
if (deepest[son]>deepest[lson[k]]) lson[k]=son;
deepest[k]=max(deepest[k],deepest[son]);
}
}
int dfs_index(int k,int last) {
st[k]=en[k]=++tot;
if (lson[k]) en[k]=dfs_index(lson[k],k);
for(int son:e[k]) if (son!=last && son!=lson[k]) dfs_index(son,k);
return en[k];
}

LL f[maxn];
int tag[maxn];
void dfs(int k,int last,int deep) {
if (lson[k]) dfs(lson[k],k,deep+1);
f[st[k]]=a[0];
tag[st[k]]=0;
int sondeepest=0;
for(int son:e[k]) if (son!=last && son!=lson[k]) {
dfs(son,k,deep+1);
sondeepest=max(sondeepest,deepest[son]);
}
fo(i,deep,sondeepest) {
int index=st[k]+i-deep;
f[index]=min(f[index],rmq(tag[index],i-deep));
tag[index]=i-deep;
}
for(int son:e[k]) if (son!=last && son!=lson[k]) {
fo(i,deep+1,deepest[son]) {
int sonindex=st[son]+i-(deep+1);
f[sonindex]=min(f[sonindex],rmq(tag[sonindex],i-deep));
f[st[k]+i-deep]+=f[sonindex];
}
}
}

int main() {
int T;
scanf("%d",&T);
while (T--) {
scanf("%d",&n);
fo(i,1,n) {
scanf("%lld",&a[i-1]);
e[i].clear();
}
fo(i,2,n) {
int u,v;
scanf("%d %d",&u,&v);
e[u].push_back(v), e[v].push_back(u);
}

rmq_pre();
tot=0;
memset(lson,0,(n+2)*sizeof(int));
dfs_link(1,0,1);
dfs_index(1,0);

memset(f,0,(n+2)*sizeof(LL));
memset(tag,0,(n+2)*sizeof(int));
dfs(1,0,1);

LL ans=0;
fo(i,1,deepest[1]) {
f[i]=min(f[i],rmq(tag[i],i-1));
ans+=f[i];
}

printf("%lld\n",ans);
}
}