【JZOJ4970】B 题解

题目大意

  给定 $N, M, K$,数组 $a[N][M], b[N]$。定义

  求第 $K$ 大的 $c[i]$。

  $N \le 250000$ 且 $N$ 为质数,$2 \le m \le 4$
  $0 \le a_{ij} < 1024,\ 0 \le b < m$

题解

  观察式子,发现有两个不好做的地方:

  • $i$ 是乘上 $j$;
  • $b$ 数组是下标。

  所以我们要做一些变化。

  对于 $i=0$ 或 $j=0$,特殊处理掉,以下不考虑。
  那么剩下的 $i$ 和 $j$ 都可以用原根的幂来表示了,这样就将乘法化成加法了。设 $i=g^x,\ j=g^y$
  于是原式变成

  对 $a$ 数组的每一列单独考虑。比如说我们考虑到第 $i$ 列,那我们定义一个 $B$ 数组:$B[j]=(b[j]==i)$
  这样原式变成

  把 $a$ 数组倒过来,即把 $a[y][i]$ 放到 $a[-y][i]$ 的位置上:

  这样就是个循环卷积了。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#include<cmath>
#include<cstdio>
#include<algorithm>
#define fo(i,a,b) for(int i=a;i<=b;i++)
using namespace std;

typedef long long LL;

const int maxn=3e5+5, maxlen=6e5+5;
const double pi=acos(-1), eps=1e-3;

struct Z{
double x,y;
Z(double X=0, double Y=0) {x=X, y=Y;}
};
Z operator +(const Z &a,const Z &b) {return Z(a.x+b.x,a.y+b.y);}
Z operator -(const Z &a,const Z &b) {return Z(a.x-b.x,a.y-b.y);}
Z operator *(const Z &a,const Z &b) {return Z(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);}

int n,m,k,a[maxn][5],b[maxn];
LL c[maxn];

int len,rv[maxlen];
Z W[maxlen],B[maxlen],A[maxlen],tp[maxlen];
void DFT(Z *a,int sig)
{
fo(i,0,len-1) tp[rv[i]]=a[i];
for(int m=2; m<=len; m<<=1)
{
int hal=m>>1;
fo(j,0,hal-1)
{
Z w=W[(j*(len/m)*sig+len)%len];
for(int k=j; k<len; k+=m)
{
Z u=tp[k], v=tp[k+hal]*w;
tp[k]=u+v;
tp[k+hal]=u-v;
}
}
}
fo(i,0,len-1) a[i]=tp[i];
}
void FFT(Z *a,Z *b)
{
DFT(a,1), DFT(b,1);
fo(i,0,len-1) a[i]=a[i]*b[i];
DFT(a,-1);
fo(i,0,len-1) a[i].x/=len;
}

int G[22]={5,127,509,2039,8191,32749,65521,131071,249989,249973,249971,2,3,2,7,17,2,17,3,2,5,6};
int gx[maxn];
void Pre()
{
int g;
fo(i,0,10) if (G[i]==n) {g=G[i+11]; break;} //本题加上样例共11个数据,每个n已列出
gx[0]=1;
fo(i,1,n) gx[i]=gx[i-1]*g%n;
for(len=1; len<2*n; len<<=1);
for(int i=0, j, k, l; i<len; rv[k]=i++)
{
W[i]=Z(cos(i*2*pi/len),sin(i*2*pi/len));
for(j=i, k=0, l=1; l<len; j>>=1, l<<=1) k=(k<<1)+(j&1);
}
}

int main()
{
scanf("%d %d %d",&n,&m,&k);
fo(j,0,m-1)
fo(i,0,n-1) scanf("%d",&a[i][j]);
fo(i,0,n-1) scanf("%d",&b[i]);

Pre();

fo(i,0,n-1) c[i]+=a[0][b[0]];
fo(j,1,n-1) c[0]+=a[j][b[0]];

fo(i,0,m-1)
{
fo(j,0,len-1) A[j]=B[j]=Z(0,0);
fo(j,1,n-1)
{
A[j-1]=Z(a[gx[n-j]][i],0);
B[j]=(b[gx[j]]==i) ?Z(1,0) :Z(0,0) ;
}
FFT(A,B);
fo(j,0,len-1) c[gx[j%(n-1)]]+=(LL)(A[j].x+0.5);
}

sort(c,c+n);
printf("%lld\n",c[n-k]);
}