爬树
题解极为不负责任,啥也没看懂,看了好久好久的 (std) 才懂
首先发现如果有一段的 (-1) 的话可以用组合数推出来方案
令长度为 (len) ,给定的左右上界是 ([l,r])
那么方案就是至多选择 (r-l+1) 个位置将权值加一
[sum_{i=0} ^{r-l+1}inom {len} i
]
这东西用一些恒等变形可以转化一下
所以我们直接维护一段上的 (-1) 数量和上下界
然后对于每次的 ([a,b]) 的限制直接考虑对于最左侧和最右侧的信息即可
如果两边不行那么就没有方案
然后这题目主算法是树剖+线段树维护信息:(-1) 的段数,两侧的上下界,
当然,想的东西不算很多
主要是难写,巨难写
几个注意的点:
(1.) 树剖对于 (x o lca) 和 (lca o y) 的做法是不一样的
这里需要分别写两个函数
(2.) 不能把结构体和 (0) 直接 (push\_up) ,得记录是不是加上过了
(3.) 写代码的时候要全神贯注,不能手残啥的
如果像我这种手残脑子还不在的,比如:
(fac[i]=mul(fac[i-1],i-1))
就铁退役了
Code
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define reg register
namespace yspm{
inline int read()
{
int res=0,f=1; char k;
while(!isdigit(k=getchar())) if(k=='-') f=-1;
while(isdigit(k)) res=res*10+k-'0',k=getchar();
return res*f;
}
const int N=1e5+10,inf=0x3f3f3f3f,mod=1e9+7;
inline int add(int x,int y){return x+y>=mod?x+y-mod:x+y;}
inline int del(int x,int y){return x-y<0?x-y+mod:x-y;}
inline int mul(int x,int y){return x*y-x*y/mod*mod;}
int fac[N*10],inv[N*10];
inline int C(int n,int m){return (m>=0&&n>=m)?mul(fac[n],mul(inv[m],inv[n-m])):0;}
inline int ksm(int x,int y)
{
int res=1; for(;y;y>>=1,x=mul(x,x)) if(y&1) res=mul(res,x);
return res;
}
struct node{
int to,nxt;
}e[N<<1];
int head[N],sz[N],cnt,tim,dfn[N],ord[N],son[N],dep[N],fa[N],top[N],v[N],n,m;
inline void adde(int u,int v)
{
e[++cnt].to=v; e[cnt].nxt=head[u];
return head[u]=cnt,void();
}
inline void dfs1(int x,int fat)
{
fa[x]=fat; sz[x]=1; dep[x]=dep[fa[x]]+1;
for(reg int i=head[x];i;i=e[i].nxt)
{
int t=e[i].to; if(t==fat) continue;
dfs1(t,x); sz[x]+=sz[t];
if(sz[t]>sz[son[x]]) son[x]=t;
}
return ;
}
inline void dfs2(int x,int topf)
{
top[x]=topf; dfn[x]=++tim; ord[tim]=x;
if(!son[x]) return ; dfs2(son[x],topf);
for(int i=head[x];i;i=e[i].nxt)
{
int t=e[i].to; if(t==son[x]||t==fa[x]) continue;
dfs2(t,t);
} return ;
}
struct point{
int len1,len2,len,llen,rlen,l1,l2,r1,r2,l,r,sum,cnt;
}f1[N<<2],f2[N<<2];
inline point push_up(point a,point b)
{
point ans;
ans.len=a.len+b.len;//区间
ans.cnt=a.cnt+b.cnt-(a.rlen>0&&b.llen>0);//-1 的段数,如果两边是-1,那就合并
ans.llen=a.llen+(a.len==a.llen?b.llen:0);//左侧 -1 的长度
ans.rlen=b.rlen+(b.len==b.rlen?a.rlen:0);//右侧 -1 的长度
if(a.cnt) ans.len1=a.len1+(a.cnt==1&&a.rlen?b.llen:0); else ans.len1=b.len1;
if(b.cnt) ans.len2=b.len2+(b.cnt==1&&b.llen?a.rlen:0); else ans.len2=a.len2;
//左边段和右边段的长度
ans.l=a.len==a.llen?b.l:a.l;
ans.r=b.rlen==b.len?a.r:b.r;
//上下界
if(a.cnt)
{
ans.l1=a.l1;
if(a.cnt==1&&a.rlen) ans.r1=(~b.l)?b.l:inf;
else ans.r1=a.r1;
}
else
{
ans.l1=b.llen?a.r:b.l1;
ans.r1=b.r1;
}
if(b.cnt)
{
ans.r2=b.r2;
if(b.cnt==1&&b.llen) ans.l2=(~a.r)?a.r:-inf;
else ans.l2=b.l2;
}
else
{
ans.l2=a.l2;
ans.r2=a.rlen?b.l:a.r2;
}//两段上下界
if(a.llen==a.len||b.llen==b.len) ans.sum=mul(a.sum,b.sum);
else
{
int t1=b.l-a.r,t2=a.rlen+b.llen;
ans.sum=mul(mul(a.sum,b.sum),C(t1+t2,t2));
}//方案统计,这里就是中间的段合并起来
return ans;
}
inline void push_up(int p)
{
f1[p]=push_up(f1[p<<1],f1[p<<1|1]);
f2[p]=push_up(f2[p<<1|1],f2[p<<1]);
return ;
}
inline void build(int p,int l,int r)
{
if(l==r)
{
f1[p].l=f1[p].r=v[ord[l]];
f1[p].l1=f1[p].l2=-inf; f1[p].r1=f1[p].r2=inf;
f1[p].len=1;
f1[p].len1=f1[p].len2=f1[p].cnt=f1[p].llen=f1[p].rlen=(v[ord[l]]==-1)?1:0;
f1[p].sum=1;
return f2[p]=f1[p],void();
}int mid=(l+r)>>1;
build(p<<1,l,mid); build(p<<1|1,mid+1,r);
return push_up(p);
}
inline void update(int p,int l,int r,int pos,int val)
{
if(l==r)
{
f1[p].l=f1[p].r=val;
f1[p].l1=f1[p].l2=-inf; f1[p].r1=f1[p].r2=inf;
f1[p].len=1; f1[p].len1=f1[p].len2=f1[p].cnt=f1[p].llen=f1[p].rlen=(val==-1)?1:0;
f1[p].sum=1; f2[p]=f1[p];
return ;
}int mid=(l+r)>>1;
if(pos<=mid) update(p<<1,l,mid,pos,val);
else update(p<<1|1,mid+1,r,pos,val);
return push_up(p);
}
inline point ask(int p,int l,int r,int st,int ed,bool fl)
{
if(st<=l&&r<=ed) return fl?f2[p]:f1[p];
int mid=(l+r)>>1;
if(st>mid) return ask(p<<1|1,mid+1,r,st,ed,fl);
if(ed<=mid) return ask(p<<1,l,mid,st,ed,fl);
if(fl) return push_up(ask(p<<1|1,mid+1,r,st,ed,fl),ask(p<<1,l,mid,st,ed,fl));
else return push_up(ask(p<<1,l,mid,st,ed,fl),ask(p<<1|1,mid+1,r,st,ed,fl));
}
inline point query(int x,int y)
{
point ans,tmp;
bool fl=0,fr=0;
while(top[x]!=top[y])
{
if(dep[top[x]]>dep[top[y]])
{
if(!fl) fl=1,ans=ask(1,1,n,dfn[top[x]],dfn[x],1);
else ans=push_up(ans,ask(1,1,n,dfn[top[x]],dfn[x],1));
x=fa[top[x]];
}
else
{
if(!fr) fr=1,tmp=ask(1,1,n,dfn[top[y]],dfn[y],0);
else tmp=push_up(ask(1,1,n,dfn[top[y]],dfn[y],0),tmp);
y=fa[top[y]];
}
}
if(dep[x]<dep[y])
{
if(!fl) ans=ask(1,1,n,dfn[x],dfn[y],0);
else ans=push_up(ans,ask(1,1,n,dfn[x],dfn[y],0));
}
else
{
if(!fl) ans=ask(1,1,n,dfn[y],dfn[x],1);
else ans=push_up(ans,ask(1,1,n,dfn[y],dfn[x],1));
}if(fr) ans=push_up(ans,tmp);
return ans;
}
inline int calc(int x,int y,int a,int b)
{
point s=query(x,y);
int ans=s.sum;
if(s.cnt)
{
int r1,r2;
if(s.cnt==1)
{
if(!s.llen&&!s.rlen) ans=mul(ans,ksm(C(s.r1-s.l1+s.len1,s.len1),mod-2));
r1=min(s.r1,b)-max(s.l1,a); r2=s.len1;
ans=mul(ans,C(r1+r2,r2));
}
else
{
if(!s.llen) ans=mul(ans,ksm(C(s.r1-s.l1+s.len1,s.len1),mod-2));
if(!s.rlen) ans=mul(ans,ksm(C(s.len2+s.r2-s.l2,s.len2),mod-2));
r1=min(s.r1,b)-max(s.l1,a); r2=s.len1;
ans=mul(ans,C(r1+r2,r2));
r1=min(s.r2,b)-max(s.l2,a); r2=s.len2;
ans=mul(ans,C(r1+r2,r2));
}
} return ans;
}
signed main()
{
freopen("tree.in","r",stdin);
freopen("tree.out","w",stdout);
fac[1]=inv[1]=fac[0]=inv[0]=1;
for(reg int i=2;i<N*10;++i) fac[i]=mul(fac[i-1],i),inv[i]=del(mod,mul(inv[mod%i],mod/i));
for(reg int i=1;i<N*10;++i) inv[i]=mul(inv[i],inv[i-1]);
n=read(); m=read();
for(reg int i=1;i<=n;++i) v[i]=read();
for(reg int i=1;i<n;++i)
{
int x=read(),y=read();
adde(x,y); adde(y,x);
} dfs1(1,0); dfs2(1,1); build(1,1,n);
while(m--)
{
if(read()-1)
{
int x=read(),y=read(),a=read(),b=read();
printf("%lld
",calc(x,y,a,b));
}
else
{
int pos=read(),val=read();
update(1,1,n,dfn[pos],val);
}
}
return 0;
}
}
signed main(){return yspm::main();}
本来是应该放到 (October) 泛做的
但是印象太为深刻(题解不负责而且巨难写)就单拎出来了