Portal -->broken qwq
Description
现在有一个树根,标号为(1),我们要加入一些节点使得它变成一棵树
接下来加入的每一个节点都有一个能量值为(v[i]),我们定义(d[i])表示一个节点的儿子数(+1),(son[i])表示这个节点的儿子列表,一个节点的好看度为(w[i]=d[i]*(v[i]+sum w[son[i]]))
现在给出(q)个操作:
类型1:读入(fa,v)表示新插入一个叶子节点,其父亲为(fa)(保证存在),能量值为(v)
类型2:读入(x)表示询问节点(x)的好看度(w[x])
对于每一个类型2,输出答案
范围: 对于100%的数据:q<=200000,v[i]<=10^9
Solution
我觉得很气愤的事情是。。因为这题的题面出了一些偏差导致我看题看了1h+各种大力猜测还是没有看懂。。最后只能去看题解来反推这个题目到底是什么意思。。==
这题的关键在于,我们要思考一下这个(w[i])的组成
会发现其实(w[i])可以看成(i)为根的子树内的每个节点(x)的(v[x])乘上某个系数再加起来
而这个系数其实就是(i)到(x)一路上的(d)值的乘积
然后因为没有强制在线,我们可以离线处理每一个操作,因为插入的话感觉有点麻烦(其实好像也不会。。但是当时受题解的影响就直接写了转化成删除的版本了。。)所以我们考虑反过来处理,把插入转化成删除操作
那所以我们考虑用两棵线段树来分别维护根节点(1)到每一个点的这个系数(我们记为(mul[i])好了),以及(sum mul[i]*v[i]),下标按照(dfn)序来排
那么查询就十分方便了,直接先查(x)子树内的(sum mul[i]*v[i]),然后再查一下(fa[x])到(1)的一路上的(d)的乘积(也就是在第一棵线段树上查(dfn[fa[x]])这个位置的值),然后把这乘积除掉就是答案了
至于修改的话,影响到的应该就是(fa[x])子树内的所有节点(因为删除(x)这个节点相当于将(d[fa[x]]-1)),所以我们要对(fa[x])子树范围内的所有位置乘上一个(inv(d[fa[x]])*(d[fa[x]]-1)),其中(inv())表示的是在(\% 10^9+7)意义下的逆元,然后因为(x)这个节点删掉了,所以我们相当于(x)的子树范围内的所有值直接都变成(0)(断开了嘛就不算了)
然后就。。两棵线段树做完啦ovo为了好好熟悉标记永久化写法这里强制自己写的是标记永久化
代码大概长这个样子
#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
using namespace std;
const int N=200010,MOD=1e9+7,SEG=N*4*2;
struct Q{
int op,x,val;
}rec[N];
struct xxx{
int y,nxt;
}a[N*2];
int d[N],v[N],w[N],pre[N],st[N],ed[N];
int h[N],mul[N],lis[N];
int ans[N];
int n,m,tot,dfn_t;
namespace Seg{/*{{{*/
int ch[SEG][2],sum[SEG],tag[SEG],rt[2];
int tot,n;
void pushup(int x){
sum[x]=1LL*sum[ch[x][0]]*tag[ch[x][0]]%MOD;
sum[x]=(1LL*sum[x]+1LL*sum[ch[x][1]]*tag[ch[x][1]]%MOD)%MOD;
}
void _build(int x,int l,int r,int which){
sum[x]=0; tag[x]=1;
if (l==r){
sum[x]=which==0?mul[lis[l]]:1LL*mul[lis[l]]*v[lis[l]]%MOD;
return;
}
int mid=l+r>>1;
ch[x][0]=++tot; _build(ch[x][0],l,mid,which);
ch[x][1]=++tot; _build(ch[x][1],mid+1,r,which);
pushup(x);
}
void build(int which,int _n){rt[which]=++tot; n=_n; _build(rt[which],1,n,which);}
void _update(int x,int l,int r,int lx,int rx,int delta){
if (l<=lx&&rx<=r){
tag[x]=1LL*tag[x]*delta%MOD; return;
}
int mid=lx+rx>>1;
if (r<=mid) _update(ch[x][0],l,r,lx,mid,delta);
else if (l>mid) _update(ch[x][1],l,r,mid+1,rx,delta);
else{
_update(ch[x][0],l,mid,lx,mid,delta);
_update(ch[x][1],mid+1,r,mid+1,rx,delta);
}
pushup(x);
}
void update(int which,int l,int r,int delta){_update(rt[which],l,r,1,n,delta);}
int _query(int x,int l,int r,int lx,int rx){
if (l<=lx&&rx<=r) return 1LL*sum[x]*tag[x]%MOD;
int mid=lx+rx>>1;
if (r<=mid) return 1LL*tag[x]*_query(ch[x][0],l,r,lx,mid)%MOD;
else if (l>mid) return 1LL*tag[x]*_query(ch[x][1],l,r,mid+1,rx)%MOD;
else{
int tmp=1LL*tag[x]*_query(ch[x][0],l,mid,lx,mid)%MOD;
int tmp2=1LL*tag[x]*_query(ch[x][1],mid+1,r,mid+1,rx)%MOD;
return (tmp+tmp2)%MOD;
}
}
int query(int which,int l,int r){return _query(rt[which],l,r,1,n);}
}/*}}}*/
int add(int x,int y){a[++tot].y=y; a[tot].nxt=h[x]; h[x]=tot;}
void dfs(int fa,int x,int now){
int u,sz;
d[x]=1; mul[x]=now; pre[x]=fa;
st[x]=++dfn_t; lis[dfn_t]=x;
for (int i=h[x];i!=-1;i=a[i].nxt) ++d[x];
mul[x]=1LL*mul[x]*d[x]%MOD;
for (int i=h[x];i!=-1;i=a[i].nxt){
u=a[i].y;
if (u==fa) continue;
dfs(x,u,mul[x]);
}
ed[x]=dfn_t;
}
int ksm(int x,int y){
int ret=1,base=x;
for (;y;y>>=1,base=1LL*base*base%MOD)
if (y&1) ret=1LL*ret*base%MOD;
return ret;
}
void del(int x){
int fa=pre[x],inv;
inv=ksm(d[fa],MOD-2);
inv=1LL*(d[fa]-1)*inv%MOD;
--d[fa];
Seg::update(0,st[fa],ed[fa],inv);
Seg::update(1,st[fa],ed[fa],inv);
Seg::update(0,st[x],ed[x],0);
Seg::update(1,st[x],ed[x],0);
}
int query(int x){
int ret=Seg::query(1,st[x],ed[x]),tmp;
if (pre[x]){
tmp=Seg::query(0,st[pre[x]],st[pre[x]]);
tmp=ksm(tmp,MOD-2);
}
else tmp=1;
return 1LL*ret*tmp%MOD;
}
int main(){
#ifndef ONLINE_JUDGE
freopen("a.in","r",stdin);
#endif
int x,y,op;
scanf("%d%d",&v[1],&m);
n=1;
memset(h,-1,sizeof(h));
tot=0;
for (int i=1;i<=m;++i){
scanf("%d",&op);
if (op==1){
scanf("%d%d",&x,&y);
++n; v[n]=y;
add(x,n);
rec[i].x=n; rec[i].op=op; rec[i].val=y;
}
else{
scanf("%d",&x);
rec[i].x=x; rec[i].op=op;
}
}
dfn_t=0;
dfs(0,1,1);
Seg::build(0,n);
Seg::build(1,n);
for (int i=m;i>=1;--i){
if (rec[i].op==1)
del(rec[i].x);
else{
ans[i]=query(rec[i].x);
}
}
for (int i=1;i<=m;++i)
if (rec[i].op==2) printf("%d
",ans[i]);
}
/*
input1
2 5
1 1 3
1 2 5
1 3 7
1 4 11
3 1
output1
344
input2
5 5
1 1 4
1 2 3
2 2
1 2 7
2 1
output2
14
94
*/