题目
题目链接:https://www.luogu.com.cn/problem/P5305
给定一棵 (n) 个点的有根树,节点标号 (1 sim n),(1) 号节点为根。
给定常数 (k)。
给定 (Q) 个询问,每次询问给定 (x,y)。
求:
[sumlimits_{i le x} ext{depth}( ext{lca}(i,y))^k
]
( ext{lca}(x,y)) 表示节点 (x) 与节点 (y) 在有根树上的最近公共祖先。
( ext{depth}(x)) 表示节点 (x) 的深度,根节点的深度为 (1)。
由于答案可能很大,你只需要输出答案模 (998244353) 的结果。
(n,Qleq 5 imes 10^4;1leq kleq 10^9)。
思路
和 洛谷P4211 LCA 这道题十分相似,唯一的区别就是在 ( ext{dep}) 外面套上了一个 (k) 次方。
原题的做法是离线然后从小到大考虑 (i),树剖+线段树把根节点到 (i) 的路径全部加一,询问根节点到 (r) 的权值和减去根节点到 (l-1) 的权值和。
那么依然考虑是否能给每一个节点一个权值,这样从 (x) 到根节点的路径权值和恰好等于 ( ext{dep}(x)^k)。
那么显然对于一个点 (x),我们把它的权值设为 ( ext{dep}(x)^k-( ext{dep}(x)-1)^k) 即可。
那么其他部分依然一样,只不过线段树上一个区间 ([l,r]) 的权值和就变成了区间内点的权值和乘区间加一的次数。依然可以轻松维护。
时间复杂度 (O(Qlog^2 n))。
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=50010,MOD=998244353;
int n,m,Q,tot,head[N],ans[N],top[N],son[N],siz[N],dep[N],fa[N],id[N],rk[N];
struct edge
{
int next,to;
}e[N];
void add(int from,int to)
{
e[++tot]=(edge){head[from],to};
head[from]=tot;
}
struct node
{
int x,y,id;
}a[N];
bool cmp(node x,node y)
{
return x.x<y.x;
}
void dfs1(int x)
{
dep[x]=dep[fa[x]]+1; siz[x]=1;
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
dfs1(v);
siz[x]+=siz[v];
if (siz[v]>siz[son[x]]) son[x]=v;
}
}
void dfs2(int x,int tp)
{
top[x]=tp; id[x]=++tot; rk[tot]=x;
if (son[x]) dfs2(son[x],tp);
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (v!=son[x]) dfs2(v,v);
}
}
ll fpow(ll x,ll k)
{
ll res=1;
for (;k;k>>=1,x=x*x%MOD)
if (k&1) res=res*x%MOD;
return res;
}
struct SegTree
{
int sum[N*4],ans[N*4],lazy[N*4];
void pushup(int x)
{
sum[x]=(sum[x*2]+sum[x*2+1])%MOD;
ans[x]=(ans[x*2]+ans[x*2+1])%MOD;
}
void pushdown(int x)
{
if (lazy[x])
{
ans[x*2]=(ans[x*2]+1LL*sum[x*2]*lazy[x])%MOD;
ans[x*2+1]=(ans[x*2+1]+1LL*sum[x*2+1]*lazy[x])%MOD;
lazy[x*2]=(lazy[x*2]+lazy[x])%MOD;
lazy[x*2+1]=(lazy[x*2+1]+lazy[x])%MOD;
lazy[x]=0;
}
}
void build(int x,int l,int r)
{
if (l==r)
{
int d=dep[rk[l]];
sum[x]=(fpow(d,m)-fpow(d-1,m)+MOD)%MOD;
return;
}
int mid=(l+r)>>1;
build(x*2,l,mid); build(x*2+1,mid+1,r);
pushup(x);
}
void update(int x,int l,int r,int ql,int qr)
{
if (ql<=l && qr>=r)
{
lazy[x]++; ans[x]=(ans[x]+sum[x])%MOD;
return;
}
pushdown(x);
int mid=(l+r)>>1;
if (ql<=mid) update(x*2,l,mid,ql,qr);
if (qr>mid) update(x*2+1,mid+1,r,ql,qr);
pushup(x);
}
int query(int x,int l,int r,int ql,int qr)
{
if (ql<=l && qr>=r) return ans[x];
pushdown(x);
int mid=(l+r)>>1,res=0;
if (ql<=mid) res+=query(x*2,l,mid,ql,qr);
if (qr>mid) res+=query(x*2+1,mid+1,r,ql,qr);
return res%MOD;
}
}seg;
void upd(int x)
{
for (;x;x=fa[top[x]])
seg.update(1,1,n,id[top[x]],id[x]);
}
int query(int x)
{
int res=0;
for (;x;x=fa[top[x]])
res=(res+seg.query(1,1,n,id[top[x]],id[x]))%MOD;
return res;
}
int main()
{
memset(head,-1,sizeof(head));
scanf("%d%d%d",&n,&Q,&m);
for (int i=2;i<=n;i++)
{
scanf("%d",&fa[i]);
add(fa[i],i);
}
for (int i=1;i<=Q;i++)
{
scanf("%d%d",&a[i].x,&a[i].y);
a[i].id=i;
}
sort(a+1,a+1+Q,cmp);
tot=0; dfs1(1); dfs2(1,1);
seg.build(1,1,n);
for (int i=1,j=1;i<=Q;i++)
{
for (;j<=a[i].x;j++) upd(j);
ans[a[i].id]=query(a[i].y);
}
for (int i=1;i<=Q;i++)
cout<<ans[i]<<"
";
return 0;
}