【JZOJ100019】A 题解

题目大意

  $n \le 10^5$

解法1

  点分。

  对于当前的分治,假设走到了点 $x$,那么 $x$ 的倍数和约数所代表的子树都不能走。
  然后用一个线段树维护当前哪些点能走,哪些不能走。离开这棵子树的时候,把不在这棵子树的标记撤销掉。

  由于要用到撤销,所以要用主席树。

解法2

  先求不合法的路径的数量。

  若 $(a,b)$ 这条路径不合法,则是它内部包含了形如 $(x,kx)$ 的路径。
  对于所有形如 $(x,kx)$ 的路径,假设 $x$ 是 dfs 序小的那个点,$y$ 是dfs 序大的那个点。$(a,b)$ 包含它当且仅当:
  1、若 $x$ 是 $y$ 的祖先,设 $g$ 是这条链上的 $x$ 的儿子,则 $dfn(a)<dfn(g),dfn(y)<dfn(b)<end(y)$ 或 $dfn(y)<dfn(a)<end(y),end(g)<dfn(b)$
  2、若 $x$ 不是 $y$ 的祖先,则 $dfn(x)<dfn(a)<end(x),dfn(y)<dfn(b)<end(y)$

  第一种情况是两个矩形,第二种情况是一个矩形,扫描线求面积并。

代码

//解法2

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#include<cstdio>
#include<algorithm>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)
using namespace std;

typedef long long LL;

const int maxn=1e5+5, maxrec=24e5+5, MX=17;

struct TRST{
int nmin,num;
TRST(int NMIN=0,int NUM=0) {nmin=NMIN, num=NUM;}
};

int n;

int tot,go[2*maxn],next[2*maxn],f1[maxn];
void ins(int x,int y)
{
go[++tot]=y;
next[tot]=f1[x];
f1[x]=tot;
}
int tt[2],rx[2][maxrec],ry[2][maxrec],nt[2][maxrec],fr[2][maxn];
void inr(int ty,int i,int x,int y)
{
rx[ty][++tt[ty]]=x;
ry[ty][tt[ty]]=y;
nt[ty][tt[ty]]=fr[ty][i];
fr[ty][i]=tt[ty];
}

int st[maxn],en[maxn],sum,fa[maxn][MX+5],deep[maxn];
void dfs_dfn(int k,int last)
{
deep[k]=deep[last]+1;
fa[k][0]=last;
fo(j,1,MX) fa[k][j]=fa[fa[k][j-1]][j-1];
st[k]=++sum;
for(int p=f1[k]; p; p=next[p]) if (go[p]!=last) dfs_dfn(go[p],k);
en[k]=sum;
}
int find(int x,int y)
{
fd(j,MX,0) if (deep[fa[y][j]]>deep[x]) y=fa[y][j];
return y;
}

TRST tr[4*maxn];
int bz[4*maxn];
void tr_js(int k,int l,int r)
{
tr[k].num=r-l+1;
if (l==r) return;
int t=k<<1, t1=(l+r)>>1;
tr_js(t,l,t1), tr_js(t+1,t1+1,r);
}
TRST merge(TRST a,TRST b)
{
if (a.nmin<b.nmin) return a;
else if (a.nmin>b.nmin) return b;
else return TRST(a.nmin,a.num+b.num);
}
void update(int k,int t)
{
if (!bz[k]) return;
tr[t].nmin+=bz[k], tr[t+1].nmin+=bz[k];
bz[t]+=bz[k], bz[t+1]+=bz[k];
bz[k]=0;
}
void tr_xg(int k,int l,int r,int x,int y,int z)
{
if (l==x && r==y)
{
tr[k].nmin+=z;
bz[k]+=z;
return;
}
int t=k<<1, t1=(l+r)>>1;
update(k,t);
if (y<=t1) tr_xg(t,l,t1,x,y,z);
else if (x>t1) tr_xg(t+1,t1+1,r,x,y,z);
else tr_xg(t,l,t1,x,t1,z), tr_xg(t+1,t1+1,r,t1+1,y,z);
tr[k]=merge(tr[t],tr[t+1]);
}

LL ans;
void Scanline()
{
tr_js(1,1,n);
fo(i,1,n)
{
for(int p=fr[0][i]; p; p=nt[0][p]) tr_xg(1,1,n,rx[0][p],ry[0][p],1);
ans+=n-tr[1].num;
for(int p=fr[1][i]; p; p=nt[1][p]) tr_xg(1,1,n,rx[1][p],ry[1][p],-1);
}
}

int main()
{
scanf("%d",&n);
fo(i,1,n-1)
{
int x,y;
scanf("%d %d",&x,&y);
ins(x,y), ins(y,x);
}

dfs_dfn(1,0);

fo(i,1,n)
for(int j=2*i; j<=n; j+=i)
{
int x=i, y=j;
if (st[x]>st[y]) swap(x,y);
if (st[y]<=en[x])
{
int g=find(x,y);
inr(0,1,st[y],en[y]), inr(1,st[g]-1,st[y],en[y]);
if (en[g]<n) inr(0,st[y],en[g]+1,n), inr(1,en[y],en[g]+1,n);
} else
{
inr(0,st[x],st[y],en[y]), inr(1,en[x],st[y],en[y]);
}
}

Scanline();

printf("%lld\n",(LL)n*(n-1)/2-ans);
}