【2020牛客多校第七场 E】NeoMole Synthesis 题解

题目大意

  给定一棵 $n$ 个点的目标树,以及 $m$ 棵模板树,每棵模板树有一个单价 $c_i$,数量无限多。这里的树都是无根树。
  现在要用若干模板树拼成目标树(就是用模板去覆盖目标树,使得目标树的每个点恰好被覆盖一次),求最小代价。

  $n \leq 500,\ m \leq 200$,所有模板树的结点数总和 $N \le 500$
  $c_i \leq 10^6$
  1s

题解

  妙啊。。。

  首先大的框架是个树形 dp。把目标树当有根树,设 $dp_{i,(j,pj)}$ 表示目标树的第 $i$ 个结点,匹配模板里的结点 $j$,$i$ 连向父亲的边匹配 $j$ 连向 $pj$ 的边,的最小代价;设 $g_i$ 表示目标树以 $i$ 为根的子树完全匹配的最小代价。
  $dp$ 数组的状态数是 $O(nN)$ 的。$g_i$ 也可以视为 $\min_{j} dp_{i,(j,0)}$($0$ 就表示 $j$ 没有父亲,是所在其模板树的根),所以以下就是求 $dp$ 数组。

  而这个转移就是让 $i$ 的儿子去匹配 $j$ 的儿子。这是一个二分图最小权匹配:左边一排 $deg_j-1$ 个点表示模板树上 $j$ 的儿子(如果是在做 $dp_{i,(j,0)}$,那么 $j$ 没有父亲,左边就有 $deg_j$ 个点),右边一排 $deg_i-1$ 个点表示目标树上 $i$ 的儿子,左边 $x$ 连向右边 $y$ 的边权是 $dp_{y,(x,j)}-g_y$,最后答案加上 $\sum g_y$。
  注意到只有 $j$ 的儿子数 $\le i$ 的儿子数才有意义,因此这样一次 KM 的时间复杂度是 $O(deg_j^2 \cdot deg_i)$,如果每一对 $(i,j,pj)$ 都做一次 KM 的话,时间复杂度是

  T 掉了。

  所以这里加一个改进,观察到,$dp_{i,(j,pj)}$ 的二分图匹配,实际上就是 $dp_{i,(j,0)}$ 的二分图匹配删去左边 $pj$ 这个点。既然如此,就没必要重新跑一边 KM,直接用最短路退流就好了。
  具体来说,首先做出 $dp_{i,(j,0)}$ 的 KM,答案为 $ans$,然后求出左边每个点走交错路到达右边结点的最短路(从左到右只能退流匹配边,从右到左只能走非匹配边),记为 $aug_{pj}$,那么 $dp_{i,(j,pj)}=ans+aug_{pj}$。这个最短路用 floyd 就好了。

  来分析时间复杂度。KM 的部分现在是

  floyd 的部分需要注意姿势,如果直接对 $deg_i+deg_j$ 个点跑最短路,或者对右边的 $deg_i$ 个点跑最短路,时间都是不对的:

  要对左边的点跑 floyd,用到的性质是“左边的点数 $\le$ 右边的点数”,因此时间复杂度是

  具体来说,对于左边的两点 $x,y$,设它们的 KM 匹配点分别为 $x’,y’$,二分图从 $i$ 到 $j$ 的边权为 $w_{i,j}$,那么在 floyd 的初始距离中,$x$ 到 $y$ 的距离为 $-w_{x,x’}+w_{y,x’}$。这样求出最短路以后,退流 $x$ 的答案为 $aug_x=\min_{y} dis_{x,y}-w_{y,y’}$ 。

  注意一个细节,如果 $j$ 的儿子数比 $i$ 多 1 个(即 $deg_j=deg_i-1+1=deg_i$),那么 $dp_{i,(j,0)}$ 的 KM 是不合法的,但 $dp_{i,(j,pj)}$ 的 KM 都是合法的。这里可以给右边加一个空点,跑 $dp_{i,(j,0)}$ 的 KM 但不要更新答案,然后退流的时候,强制 floyd 的终点是右边这个空点。

代码

// 这里的 KM 跑的是二分图最大权匹配,所以边权取反,floyd 跑最长路

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
#include<bits/stdc++.h>
#define fo(i,a,b) for(int i=a;i<=b;i++)
using namespace std;

typedef long long LL;

const int maxn=505;
const LL inf=2139062143;

int n,m,N,c[maxn];
vector<int> e[maxn],em[maxn];
map<pair<int,int>,int> M;

LL lx[maxn],ly[maxn],slack[maxn],mp[maxn][maxn];
int f[maxn],pre[maxn];
bool vis[maxn];
LL KM(int nl,int nr)
{
fo(i,1,nl)
{
lx[i]=-inf;
fo(j,1,nr) lx[i]=max(lx[i],mp[i][j]);
}
memset(ly,0,sizeof(LL)*(nr+1));
memset(f,0,sizeof(int)*(nr+1));
memset(pre,0,sizeof(int)*(nr+1));
fo(i,1,nl)
{
memset(slack,127,sizeof(LL)*(nr+1));
memset(vis,0,sizeof(bool)*(nr+1));
f[0]=i;
int py=0, nextpy;
for(; f[py]; py=nextpy)
{
int px=f[py];
LL d=inf<<3;
vis[py]=1;
fo(j,1,nr) if (!vis[j])
{
if (lx[px]+ly[j]-mp[px][j]<slack[j]) slack[j]=lx[px]+ly[j]-mp[px][j], pre[j]=py;
if (slack[j]<d) d=slack[j], nextpy=j;
}
fo(j,0,nr) if (vis[j]) lx[f[j]]-=d, ly[j]+=d;
else slack[j]-=d;
}
for(; py; py=pre[py]) f[py]=f[pre[py]];
}
LL re=0;
fo(i,1,nl) re+=lx[i];
fo(i,1,nr) re+=ly[i];
return re;
}

LL dis[maxn][maxn],aug[maxn];
int ff[maxn];
void floyd(int nl,int nr,bool ty)
{
fo(y,1,nr) if (f[y]) ff[f[y]]=y;
fo(i,1,nl)
fo(j,1,nl) dis[i][j]=(i==j) ?0 :mp[j][ff[i]]-mp[i][ff[i]];

fo(k,1,nl)
fo(i,1,nl) if (i!=k)
fo(j,1,nl) if (j!=i && j!=k) dis[i][j]=max(dis[i][j],dis[i][k]+dis[k][j]);

if (ty)
{
fo(x,1,nl) aug[x]=dis[x][f[nr]];
} else
{
fo(x,1,nl)
{
aug[x]=-mp[x][ff[x]];
fo(y,1,nl) aug[x]=max(aug[x],dis[x][y]-mp[y][ff[y]]);
}
}
}

LL dp[maxn][3*maxn],g[maxn];
int s0,s[maxn];
void dfs(int k,int last)
{
for(int son:e[k]) if (son!=last) dfs(son,k);

s0=0;
LL gsum=0;
for(int son:e[k]) if (son!=last) s[++s0]=son, gsum+=g[son];
g[k]=inf;
for(int i=1, cnt=0, sz; i<=N; i++, cnt+=sz)
{
sz=em[i].size();
if (sz-1>s0)
{
fo(j,0,sz-1) dp[k][cnt+j]=inf;
continue;
}
fo(x,0,sz-1)
{
int id=M[make_pair(em[i][x],i)];
fo(y,1,s0) mp[x+1][y]=g[s[y]]-dp[s[y]][id];
}
if (sz>s0)
{
s[++s0]=0;
fo(x,1,sz) mp[x][s0]=0;
}
LL ans=gsum-KM(sz,s0);
if (s[s0] || !s0) g[k]=min(g[k],c[i]+ans);

floyd(sz,s0,(s[s0]==0 && s0>0));
fo(x,0,sz-1)
dp[k][cnt+x]=min(ans-aug[x+1],inf);

s0-=(s0>0 && s[s0]==0);
}
}

int main()
{
scanf("%d",&n);
fo(i,2,n)
{
int x,y;
scanf("%d %d",&x,&y);
e[x].push_back(y), e[y].push_back(x);
}
scanf("%d",&m);
fo(i,1,m)
{
int tn,tc;
scanf("%d %d",&tn,&tc);
fo(j,2,tn)
{
int x,y;
scanf("%d %d",&x,&y);
em[N+x].push_back(N+y), em[N+y].push_back(N+x);
}
fo(j,1,tn) c[N+j]=tc;
N+=tn;
}

int tot=0;
fo(i,1,N)
for(int go:em[i]) M[make_pair(i,go)]=tot++;

dfs(1,0);

if (g[1]>=inf) puts("impossible"); else printf("%lld\n",g[1]);
}