Luogu P3203 [HNOI2010]弹飞绵羊 题解

Description

题目链接

某天,Lostmonkey 发明了一种超级弹力装置,为了在他的绵羊朋友面前显摆,他邀请小绵羊一起玩个游戏。

游戏一开始,Lostmonkey 在地上沿着一条直线摆上 $n$ 个装置,每个装置设定初始弹力系数 $k_i$​,当绵羊达到第 $i$ 个装置时,它会往后弹 $k_i$​ 步,达到第 $i+k_i$ 个装置,若不存在第 $i+k_i$​ 个装置,则绵羊被弹飞。

绵羊想知道当它从第 $i$ 个装置起步时,被弹几次后会被弹飞。为了使得游戏更有趣,Lostmonkey 可以修改某个弹力装置的弹力系数,任何时候弹力系数均为正整数。

Solution

构建一个虚点 $n+1$,如果绵羊会被弹飞就直接连一条边 $(x,n+1)$,否则直接连向下一个点 $(x,x+a[x])$。

针对每个询问操作,只需要查询下 $n+1$ 的子树大小即可,直接 $split(x,n+1)$,输出 $tr[n+1].sz-1$ 即可(不包括 $n+1$ 这个节点)。

针对每个修改操作,只需要把原先的那条边断掉再重新按照上述方法连一条新边即可。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
#include<cstdio>
#define N 200010
#define min(x,y) ((x)>(y)?(y):(x))
#define ls tr[x].ch[0]
#define rs tr[x].ch[1]
int n,m,a[N];
struct node{int fa,ch[2],sz,tag;}tr[N];
inline int get(int x){return (tr[tr[x].fa].ch[0]==x?0:(tr[tr[x].fa].ch[1]==x?1:-1));}
inline void add(int x,int y,int k){tr[x].fa=y,(~k&&(tr[y].ch[k]=x,0));}
inline void pushup(int x){tr[x].sz=tr[ls].sz+tr[rs].sz+1;}
inline void pushdown(int x){if(!tr[x].tag) return ;ls^=rs^=ls^=rs;tr[ls].tag^=1,tr[rs].tag^=1,tr[x].tag=0;}
inline void pushall(int x){if(~get(x)) pushall(tr[x].fa);pushdown(x);}
inline void rotate(int x){int y=tr[x].fa,z=tr[y].fa,kx=get(x),ky=get(y);add(tr[x].ch[kx^1],y,kx),add(y,x,kx^1),add(x,z,ky),pushup(y),pushup(x);}
inline void splay(int x){pushall(x);while(~get(x)){if(~get(tr[x].fa)) rotate(get(x)^get(tr[x].fa)?x:tr[x].fa);rotate(x);}}
inline void access(int x){for(int y=0;x;x=tr[y=x].fa) splay(x),rs=y,pushup(x);}
inline void makeroot(int x){access(x),splay(x),tr[x].tag^=1,pushdown(x);}
inline void split(int x,int y){makeroot(x),access(y),splay(y);}
inline int findroot(int x){access(x),splay(x);while(ls) pushdown(x),x=ls;splay(x);return x;}
inline void link(int x,int y){makeroot(x);tr[x].fa=y;}
inline void cut(int x,int y){split(x,y);tr[x].fa=tr[y].ch[0]=0,pushup(x);}
int main(){
scanf("%d",&n);
for(int i=1;i<=n+1;i++) tr[i].sz=1;
for(int i=1;i<=n;i++) scanf("%d",&a[i]),link(i,min(i+a[i],n+1));
scanf("%d",&m);
for(int op,x,y,i=1;i<=m;i++){
scanf("%d",&op);
if(op==1) scanf("%d",&x),x++,split(x,n+1),printf("%d\n",tr[n+1].sz-1);
else scanf("%d%d",&x,&y),x++,cut(x,min(x+a[x],n+1)),link(x,min(x+y,n+1)),a[x]=y;
}
}