【JZOJ4939】平均值 题解

题目大意

  给定一个长度为 $n$ 的序列 $a_1,\cdots,a_n$,求所有区间的 $mex$ 平均值之和,即

  $1 \leq n \leq 5 \times 10^5,\ \ 0 \leq a_i \leq 5 \times 10^5$

题解

  这题大概是有两种解法,官方题解是转化成 $\sum_{i=0}^{\max\{a\}}\sum_{l \leq r} \frac{[mex(a_l,\cdots,a_r) \geq i]}{r-l+1}$ 来做,但它比较简略我没搞懂。于是又找到了HOWARLI的做法,我这个就是参考他的。被吊打啦

  有一种传统建模是,左端点从右往左移动,用线段树维护右端点的答案。但是在这题,左端点往左移相当于给若干个序列插入一个元素, $mex$ 怎么变就很难搞了。

  但是变通一下,左端点从左往右移,就很好搞了。因为这对于 $mex$ 来说相当于是区间取 $\min$ 的操作。
  当左端点从 $i-1$ 移到 $i$,就相当于删去 $a_{i-1}$,假设 $i-1$ 后面的第一个跟 $a_{i-1}$ 相同的元素是 $a_{next_{i-1}}$,那么删去 $a_{i-1}$ 影响到的右端点范围就是 $[i,next_{i-1})$。在这个范围里所有 $mex$ 值大于 $a_{i-1}$ 的都会变成 $a_{i-1}$。
  由于 $mex$ 是分段递增的,我们在这个范围内找到所有 $mex$ 值大于 $a_{i-1}$ 的段,结算它们的贡献,并更新它们。
  假设一个段 $[l,r]$,$mex$ 值为 $m$。对于里面的一个元素 $x~(x \in [l,r])$,它最早在左端点为 $t$ 时把 $mex$ 值更新为 $m$,那么它要结算的贡献就是 $m(\frac{1}{x-t+1}+\cdots+\frac{1}{x-(i-1)+1})$,即 $m(s_{x-t+1}-s_{x-(i-1)})$(记 $s_n=\sum_{i=1}^n \frac 1i$)。
  于是用线段树维护一下 $mex$ 的分段、区间 $\sum s_{x-t+1}$ 即可。
  每次左端点的移动,最多新增 2 个段,删除若干个段,因此段的总量也是 $O(n)$ 的。时间复杂度 $O(n \log n)$。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#include<cstdio>
#include<cstring>
#include<algorithm>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)
using namespace std;

typedef long long LL;

const int maxn=5e5+5;
const LL mo=998244353;

struct TR{
int minmex,maxmex;
LL st;
};

int n,a[maxn];

LL Pow(LL x,LL y)
{
LL re=1;
for(; y; y>>=1, x=x*x%mo) if (y&1) re=re*x%mo;
return re;
}

int ap[maxn],nxt[maxn],init_mex[maxn];
LL s[maxn],ss[maxn];
void Pre()
{
fd(i,n,1)
{
nxt[i]=(!ap[a[i]]) ?n+1 :ap[a[i]];
ap[a[i]]=i;
}

memset(ap,0,sizeof(ap));
int mex=0;
fo(i,1,n)
{
ap[a[i]]=1;
while (ap[mex]) mex++;
init_mex[i]=mex;
}

fo(i,1,n) s[i]=(s[i-1]+Pow(i,mo-2))%mo, ss[i]=(ss[i-1]+s[i])%mo;
}

TR tr[4*maxn];
pair<int,int> bz[4*maxn];
TR merge(const TR &a,const TR &b)
{
return (TR){min(a.minmex,b.minmex),max(a.maxmex,b.maxmex),(a.st+b.st)%mo};
}
void update(int k,int t,int l,int mid,int r)
{
if (!bz[k].first) return;
tr[t]=(TR){bz[k].second,bz[k].second,(ss[mid-bz[k].first+1]-ss[l-bz[k].first]+mo)%mo};
tr[t+1]=(TR){bz[k].second,bz[k].second,(ss[r-bz[k].first+1]-ss[mid+1-bz[k].first]+mo)%mo};
bz[t]=bz[t+1]=bz[k];
bz[k]=make_pair(0,0);
}
void tr_js(int k,int l,int r)
{
if (l==r)
{
tr[k]=(TR){init_mex[r],init_mex[r],s[r]};
return;
}
int t=k<<1, mid=(l+r)>>1;
tr_js(t,l,mid), tr_js(t+1,mid+1,r);
tr[k]=merge(tr[t],tr[t+1]);
}
pair<int,int> cx_bs(int k,int l,int r,int x,int z)
{
if (l==r) return (tr[k].maxmex>z) ?make_pair(l,tr[k].maxmex) :make_pair(-1,-1) ;
int t=k<<1, mid=(l+r)>>1;
update(k,t,l,mid,r);
return (x<=mid && tr[t].maxmex>z) ?cx_bs(t,l,mid,x,z) :cx_bs(t+1,mid+1,r,x,z);
}
int cx_r(int k,int l,int r,int x,int z)
{
if (l==r) return l;
int t=k<<1, mid=(l+r)>>1;
update(k,t,l,mid,r);
return (x<=mid && tr[t+1].minmex>z) ?cx_r(t,l,mid,x,z) :cx_r(t+1,mid+1,r,x,z);
}
LL cx_st(int k,int l,int r,int x,int y)
{
if (x<=l && r<=y) return tr[k].st;
int t=k<<1, mid=(l+r)>>1;
update(k,t,l,mid,r);
LL re=0;
if (x<=mid) re=cx_st(t,l,mid,x,y);
if (mid<y) (re+=cx_st(t+1,mid+1,r,x,y))%=mo;
return re;
}
void tr_xg(int k,int l,int r,int x,int y,pair<int,int> z)
{
if (x<=l && r<=y)
{
tr[k]=(TR){z.second,z.second,(ss[r-z.first+1]-ss[l-z.first]+mo)%mo};
bz[k]=z;
return;
}
int t=k<<1, mid=(l+r)>>1;
update(k,t,l,mid,r);
if (x<=mid) tr_xg(t,l,mid,x,y,z);
if (mid<y) tr_xg(t+1,mid+1,r,x,y,z);
tr[k]=merge(tr[t],tr[t+1]);
}

int main()
{
freopen("average.in","r",stdin);
freopen("average.out","w",stdout);

scanf("%d",&n);
fo(i,1,n) scanf("%d",&a[i]);

Pre();
tr_js(1,1,n);

LL ans=0;
fo(i,1,n)
{
if (i>1)
{
for(int l=i, r; l<nxt[i-1]; l=r+1)
{
pair<int,int> t=cx_bs(1,1,n,l,a[i-1]);
l=t.first;
if (l==-1 || l>=nxt[i-1]) break;
r=min(cx_r(1,1,n,l,t.second),nxt[i-1]-1);
(ans+=(cx_st(1,1,n,l,r)-ss[r-(i-1)]+mo+ss[l-(i-1)-1])%mo*t.second)%=mo;
tr_xg(1,1,n,l,r,make_pair(i,a[i-1]));
}
}
(ans+=cx_st(1,1,n,i,i)*(a[i]==0))%=mo;
}

printf("%lld\n",ans);
}