【ICPC Camp PTZ-Shanghai 2022 Day2 G】Gross LCS 题解

题目大意

  给定两个数组 $A=\langle a_1,\cdots,a_n \rangle$ 和 $B=\langle b_1,\cdots,b_m \rangle$,定义 $A+x$ 表示 $\langle a_1+x,\cdots,a_n+x \rangle$,求 $\sum_{x=-10^{100}}^{10^{100}} LCS(A+x,B)$。

  $n,m \le 4000,\ \ |a_i|,|b_i| \le 10^8$
  10s,16MB

题解

  首先,这么大的 $x$ 范围里,必定有很多是无贡献的。有贡献的 $x$ 的数量不超过 $nm$ 个,因为 $a_i+x=b_j$ 这样的等式只有 $nm$ 个。
  这相当于说,对于一个固定的 $x$,在一个 $n$ 行 $m$ 列的矩阵上,把符合 $a_i+x=b_j$ 的格子 $(i,j)$ 做标记,在每一个标记点可以走到其右下方的任意标记点,问从左上走到右下的最长距离。
  再往下想,每一对 $(i,j)$ 对应的 $x$ 是唯一的,也就是每种 $x$ 产生的标记总数也只有 $nm$。那么只要想办法把每种 $x$ 产生的标记弄出来,就可以 dp 了。

  如果没有空间限制,那么就 $O(nm)$ 枚举所有的数对即可。这个 dp 是个简单的 LIS。
  现在有空间限制,那么就需要用一些方法按顺序生成 $(x,i,j)$ 三元组。可以把 $b$ 数组从小到大排序(记为 $b’_1,\cdots,b’_m$),初始在堆里加入所有 $(b’_1-a_i,i,1)$,然后每从堆里取出一个元素,就把 $b’$ 向前推进一位,这样就能顺序生成所有 $(x,i,j)$ 三元组了。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#include<bits/stdc++.h>
#define fo(i,a,b) for(int i=a;i<=b;i++)
using namespace std;

typedef pair<int,int> pr;

const int maxn=4005;
const int inf=2e9;

struct B {
int val,i;
};
bool cmpB(const B &a,const B &b) {return a.val<b.val || a.val==b.val && a.i>b.i;}

struct QST {
int x,i,j;
};
bool operator < (const QST &a,const QST &b) {
return a.x>b.x || a.x==b.x && a.i>b.i || a.x==b.x && a.i==b.i && a.j<b.j;
}

int n,m,a[maxn],p[maxn];
B b[maxn];

int c[maxn];
unordered_set<int> RC;
int lowbit(const int &x) {return x&(-x);}
int get(int x) {
int re=0;
for(; x; x-=lowbit(x)) re=max(re,c[x]);
return re;
}
void xg(int x,int z) {
RC.insert(x);
for(; x<=m; x+=lowbit(x)) c[x]=max(c[x],z);
}
void c_clear() {
for(int x:RC)
for(; x<=m; x+=lowbit(x)) c[x]=0;
RC.clear();
}

int main() {
scanf("%d %d",&n,&m);
fo(i,1,n) scanf("%d",&a[i]);
fo(i,1,m) scanf("%d",&b[i].val), b[i].i=i;

sort(b+1,b+1+m,cmpB);
fo(i,1,m) p[b[i].i]=i;

priority_queue<QST> Q;
fo(i,1,n) Q.push((QST){b[1].val-a[i],i,b[1].i});
int lastx=inf, ans=0, curans=0;
while (!Q.empty()) {
QST cur=Q.top(); Q.pop();
if (lastx!=cur.x && lastx!=inf) {
ans+=curans;
curans=0;
c_clear();
}

int dp=get(cur.j-1)+1;
curans=max(curans,dp);
xg(cur.j,dp);

lastx=cur.x;
if (p[cur.j]<m) {
int newj=p[cur.j]+1;
Q.push((QST){b[newj].val-a[cur.i],cur.i,b[newj].i});
}
}
ans+=curans;

printf("%d\n",ans);
}