Description
如题,已知一棵包含N个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:
操作1: 格式: 1 x y z 表示将树从x到y结点最短路径上所有节点的值都加上z
操作2: 格式: 2 x y 表示求树从x到y结点最短路径上所有节点的值之和
操作3: 格式: 3 x z 表示将以x为根节点的子树内所有节点值都加上z
操作4: 格式: 4 x 表示求以x为根节点的子树内所有节点值之和
Input
第一行包含4个正整数N、M、R、P,分别表示树的结点个数、操作个数、根节点序号和取模数(即所有的输出结果均对此取模)。
接下来一行包含N个非负整数,分别依次表示各个节点上初始的数值。
接下来N-1行每行包含两个整数x、y,表示点x和点y之间连有一条边(保证无环且连通)
接下来M行每行包含若干个正整数,每行表示一个操作,格式如下:
操作1: 1 x y z
操作2: 2 x y
操作3: 3 x z
操作4: 4 x
Output
输出包含若干行,分别依次表示每个操作2或操作4所得的结果(对P取模)
Sample Input
5 5 2 24
7 3 7 8 0
1 2
1 5
3 1
4 1
3 4 2
3 2 2
4 5
1 5 1 3
2 1 3
Sample Output
2
21
Hint
时空限制:1s,128M
数据规模:
对于30%的数据: N leq 10, M leq 10 N≤10,M≤10
对于70%的数据: N leq {10}^3, M leq {10}^3 N≤10
3
,M≤10
3
对于100%的数据: N leq {10}^5, M leq {10}^5 N≤10
5
,M≤10
5
(其实,纯随机生成的树LCA+暴力是能过的,可是,你觉得可能是纯随机的么233 )
样例说明:
树的结构如下:
各个操作如下:
故输出应依次为2、21(重要的事情说三遍:记得取模)
题解
树链剖分模板
#include<cstdio>
#include<cstring>
#include<iostream>
using namespace std;
typedef long long LL;
LL read()
{
LL x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
int N,M,R,P,A[100050];
int ecnt,place,head[100050];
int deep[100050],parent[100050],size[100050];
int tid[100050],top[100050],rnk[100050],end[100050];
struct edge{int to,nxt;}e[200050];
struct segmentnode{int sum,tag;}seg[400050];
inline void addedge(int u,int v)
{
e[ecnt]=(edge){v,head[u]},head[u]=ecnt++;
e[ecnt]=(edge){u,head[v]},head[v]=ecnt++;
}
inline void pushup(int root)
{
seg[root].sum=(seg[root<<1].sum+seg[root<<1|1].sum)%P;
}
inline void pushdown(int root,int l,int r)
{
int x=seg[root].tag;seg[root].tag=0;
if(x==0)return;
int mid=(l+r)>>1;
(seg[root<<1].sum+=(mid-l+1)*x%P)%=P,(seg[root<<1].tag+=x)%=P;
(seg[root<<1|1].sum+=(r-mid)*x%P)%=P,(seg[root<<1|1].tag+=x)%=P;
}
void getdeep(int root,int step,int fa)
{
deep[root]=step,parent[root]=fa,size[root]=1;
for(int i=head[root];~i;i=e[i].nxt)
{
int v=e[i].to;if(v==fa)continue;
getdeep(v,step+1,root);
size[root]+=size[v];
}
}
void devide(int root,int chain,int fa)
{
tid[root]=end[root]=++place,top[root]=chain,rnk[place]=root;
int k=0;
for(int i=head[root];~i;i=e[i].nxt)
{
int v=e[i].to;if(v==fa)continue;
if(size[v]>size[k])k=v;
}
if(k==0)return;
devide(k,chain,root),end[root]=max(end[root],end[k]);
for(int i=head[root];~i;i=e[i].nxt)
{
int v=e[i].to;if(v==fa||v==k)continue;
devide(v,v,root);end[root]=max(end[root],end[v]);
}
}
void build(int root,int l,int r)
{
seg[root].tag=0;
if(l==r){seg[root].sum=A[rnk[l]];return;}
int mid=(l+r)>>1;
build(root<<1,l,mid),build(root<<1|1,mid+1,r);
pushup(root);
}
void updata(int root,int l,int r,int a,int b,int val)
{
if(l==a&&r==b)
{
(seg[root].sum+=(r-l+1)*val%P)%=P;
(seg[root].tag+=val)%=P;
return;
}
int mid=(l+r)>>1; pushdown(root,l,r);
if(b<=mid)updata(root<<1,l,mid,a,b,val);
else if(a>mid)updata(root<<1|1,mid+1,r,a,b,val);
else
{
updata(root<<1,l,mid,a,mid,val);
updata(root<<1|1,mid+1,r,mid+1,b,val);
}
pushup(root);
}
int getsum(int root,int l,int r,int a,int b)
{
if(l==a&&r==b)return seg[root].sum;
int mid=(l+r)>>1;pushdown(root,l,r);
if(b<=mid)return getsum(root<<1,l,mid,a,b);
else if(a>mid)return getsum(root<<1|1,mid+1,r,a,b);
else
{
int res=getsum(root<<1,l,mid,a,mid);
(res+=getsum(root<<1|1,mid+1,r,mid+1,b))%=P;
return res;
}
}
void updata(int a,int b,int val)
{
while(top[a]!=top[b])
{
if(deep[top[a]]>deep[top[b]])swap(a,b);
updata(1,1,N,tid[top[b]],tid[b],val);
b=parent[top[b]];
}
if(deep[a]>deep[b])swap(a,b);
updata(1,1,N,tid[a],tid[b],val);
}
int getsum(int a,int b)
{
int res=0;
while(top[a]!=top[b])
{
if(deep[top[a]]>deep[top[b]])swap(a,b);
(res+=getsum(1,1,N,tid[top[b]],tid[b]))%=P;
b=parent[top[b]];
}
if(deep[a]>deep[b])swap(a,b);
(res+=getsum(1,1,N,tid[a],tid[b]))%=P;
return res;
}
int main()
{
memset(head,-1,sizeof(head));
N=read(),M=read(),R=read(),P=read();
for(int i=1;i<=N;i++)A[i]=read()%P;
for(int i=1;i<N;i++)
{
int u=read(),v=read();
addedge(u,v);
}
getdeep(R,1,0),devide(R,R,0),build(1,1,N);
for(int i=1;i<=M;i++)
{
int op=read();
if(op==1)
{
int x=read(),y=read(),z=read();
updata(x,y,z);
}
if(op==2)
{
int x=read(),y=read();
printf("%d
",getsum(x,y));
}
if(op==3)
{
int x=read(),z=read();
updata(1,1,N,tid[x],end[x],z);
}
if(op==4)
{
int x=read();
printf("%d
",getsum(1,1,N,tid[x],end[x]));
}
}
return 0;
}