【Hackerrank World9】【JZOJ5020】Box Operations 题解

题目大意

  给出一个长度为 $n$ 的序列。
  有 4 种操作:
  $1\ l\ r\ c$:给 $a_l,\cdots,a_r$ 加上 $c$;($c$ 可为负)
  $2\ l\ r\ d$:给 $a_l,\cdots,a_r$ 除以 $d$ 下取整;($\lfloor-0.5\rfloor=1$)
  $3\ l\ r$:求 $a_l,\cdots,a_r$ 的最小值;
  $4\ l\ r$:求 $a_l,\cdots,a_r$ 的和。

  $n, q \leq 10^5$
  $|a| \leq 10^9,\ |c| \leq 10^4,\ 2\leq d\leq 10^9$

题解

  这题的关键就是如何用线段树维护区间除法。

  首先,一个数除 log 次,就会变成 0(或者 -1)。所以我们在不断地做除法的过程中,会有很多可以同时操作的连续段。

  (想到这里之后我在考场上是这样做的:对于一个区间,如果他们除出来都是相同的,就打上除法标记,否则递归下去。然后发现除法标记瞬间爆 longlong……然后我就设个阈值,除法标记达到阈值之后强制下传,然后发现阈值居然要小于100……就狗带了)

  为了不打除法标记,我们这样操作:
  对于当前要做除法的区间 $[l,r]$:设区间最小值为 $x$,最大值为 $y$,要除以 $z$,若 $x-\lfloor\frac{x}{z}\rfloor=y-\lfloor\frac{y}{z}\rfloor$(即差量相同),那么直接给这个区间打上减法标记,否则递归下去。

  时间分析:
  对于加法操作:每次最多增加 $2$ 个颜色段,所以总的颜色段最多是 $3n$。
  对于除法操作:相邻的两个颜色段最多 $2\log$ 次操作就合并了。
  所以总的操作次数是 $O(n \log n)$,加上线段树就是 $O(n \log^2 n)$

  (也可以用势能分析:由于相邻的两个颜色段最多 $O(\log)$ 次操作就合并,所以我们把一个颜色段拆成 $\log$ 份来看,设势函数 $\Phi$ 表示当前这些颜色段的数量。一次加法操作就会使势差加 $2\log$,实际用时视为 $1$;一次除法操作假设遍历了 $x$ 个区间,那么势差减 $x$,实际用时也是 $x$。因此估价函数最后算出来是 $O(n \log n)$,加上线段树就是 $O(n \log^2 n)$)

相关题

  UOJ#228 基础数据结构练习题
  此题是区间开方运算,跟区间除法的思路是一样的。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
#include<cstdio>
#include<algorithm>
#define fo(i,a,b) for(int i=a;i<=b;i++)
using namespace std;

typedef long long LL;

const int maxn=1e5+5;

struct TR{
LL nmin,nmax,sum,len;
};

int n,a[maxn];

int ReadInt()
{
char ch=getchar();
int data=0, tag=1;
while ((ch<'0' || ch>'9') && ch!='-') ch=getchar();
do{
if (ch=='-') tag=-1; else data=data*10+ch-'0';
ch=getchar();
} while (ch>='0' && ch<='9' || ch=='-');
return data*tag;
}

LL DIV(LL x,LL y) {return x/y-(x<0 && x%y!=0);}

TR tr[4*maxn];
LL bz[4*maxn];
void update(int k,int ls,int rs)
{
if (bz[k]==0) return;

tr[ls].nmin+=bz[k];
tr[ls].nmax+=bz[k];
tr[ls].sum+=bz[k]*tr[ls].len;
bz[ls]+=bz[k];

tr[rs].nmin+=bz[k];
tr[rs].nmax+=bz[k];
tr[rs].sum+=bz[k]*tr[rs].len;
bz[rs]+=bz[k];

bz[k]=0;
}
void merge(int k,int ls,int rs)
{
tr[k].nmin=(tr[ls].nmin<tr[rs].nmin) ?tr[ls].nmin :tr[rs].nmin;
tr[k].nmax=(tr[ls].nmax>tr[rs].nmax) ?tr[ls].nmax :tr[rs].nmax;
tr[k].sum=tr[ls].sum+tr[rs].sum;
}
void tr_js(int k,int l,int r)
{
tr[k].len=r-l+1;
if (l==r)
{
tr[k].nmin=tr[k].nmax=tr[k].sum=a[l];
return;
}
int t=k<<1, t1=(l+r)>>1;
tr_js(t,l,t1), tr_js(t+1,t1+1,r);
merge(k,t,t+1);
}
void xg_ad(int k,int l,int r,int x,int y,LL z)
{
if (l==x && r==y)
{
tr[k].nmin+=z;
tr[k].nmax+=z;
tr[k].sum+=tr[k].len*z;
bz[k]+=z;
return;
}
int t=k<<1, t1=(l+r)>>1;
update(k,t,t+1);
if (y<=t1) xg_ad(t,l,t1,x,y,z);
else if (x>t1) xg_ad(t+1,t1+1,r,x,y,z);
else xg_ad(t,l,t1,x,t1,z), xg_ad(t+1,t1+1,r,t1+1,y,z);
merge(k,t,t+1);
}
void re_dv(int k,int l,int r,LL z)
{
LL tm=tr[k].nmin-DIV(tr[k].nmin,z);
if (tm==tr[k].nmax-DIV(tr[k].nmax,z))
{
tr[k].nmin-=tm;
tr[k].nmax-=tm;
tr[k].sum-=tm*tr[k].len;
bz[k]-=tm;
return;
}
int t=k<<1, t1=(l+r)>>1;
update(k,t,t+1);
re_dv(t,l,t1,z), re_dv(t+1,t1+1,r,z);
merge(k,t,t+1);
}
void xg_dv(int k,int l,int r,int x,int y,LL z)
{
if (l==x && r==y)
{
re_dv(k,l,r,z);
return;
}
int t=k<<1, t1=(l+r)>>1;
update(k,t,t+1);
if (y<=t1) xg_dv(t,l,t1,x,y,z);
else if (x>t1) xg_dv(t+1,t1+1,r,x,y,z);
else xg_dv(t,l,t1,x,t1,z), xg_dv(t+1,t1+1,r,t1+1,y,z);
merge(k,t,t+1);
}
LL cx_min(int k,int l,int r,int x,int y)
{
if (l==x && r==y) return tr[k].nmin;
int t=k<<1, t1=(l+r)>>1;
update(k,t,t+1);
if (y<=t1) return cx_min(t,l,t1,x,y);
else if (x>t1) return cx_min(t+1,t1+1,r,x,y);
else
{
LL rel=cx_min(t,l,t1,x,t1), rer=cx_min(t+1,t1+1,r,t1+1,y);
return (rel<rer) ?rel :rer;
}
}
LL cx_sum(int k,int l,int r,int x,int y)
{
if (l==x && r==y) return tr[k].sum;
int t=k<<1, t1=(l+r)>>1;
update(k,t,t+1);
if (y<=t1) return cx_sum(t,l,t1,x,y);
else if (x>t1) return cx_sum(t+1,t1+1,r,x,y);
else return cx_sum(t,l,t1,x,t1)+cx_sum(t+1,t1+1,r,t1+1,y);
}

int q;
int main()
{
scanf("%d %d",&n,&q);
fo(i,1,n) a[i]=ReadInt();
tr_js(1,1,n);

while (q--)
{
int ty=ReadInt(), l=ReadInt(), r=ReadInt();
if (ty==1)
{
LL d=ReadInt();
xg_ad(1,1,n,l,r,d);
} else if (ty==2)
{
LL d=ReadInt();
xg_dv(1,1,n,l,r,d);
} else if (ty==3)
{
printf("%lld\n",cx_min(1,1,n,l,r));
} else
{
printf("%lld\n",cx_sum(1,1,n,l,r));
}
}
}