题意:给你一棵(n)个节点的树,(q)个询问,每次询问读入(u,v,k,op),需要满足树上有(k)对点的简单路径交都等于(u,v)之间的简单路径,(op=1)表示(k)对点中每个点只能存在于一个点对中,否则每个点可以存在于多个点对中,问那k对点有多少种选法,答案对(998244353)取模。
数据范围:对于(100%)的数据,保证 (1≤n≤10^5,1≤u,v≤n,u
e v,1≤k≤min(n,500),op∈{0,1}),保证每个节点的度数不超过(500)。
我们抓住“两两路径之交是((u,v))”这条性质。 可以发现(u,v)是独立的。我们等价于要求:在(u)的子树中选(k)个点使它们两两(lca)是(u)的方案数,对(v)也求同样的东西,再把两者相乘。如果(u,v)存在祖孙关系,不妨设(u)是(v)的祖先,那么(u)的子树就要改为以(v)的方向作为根方向前提下的子树。 显然为了使两两(lca)是(u),在(u)的每一个儿子中就至多只能选一个点。然后这题就差不多了。
设(g[x][i])为在(x)的子树里选(i)个点的方案数,(ans)为最后的答案,(u,v)为读入的(u,v),钦定(u)为深度小的那个点,(tmp)为(u-v)路径上最靠近(u)的那个点
那么有:
[if(lca(u,v)==u)ans=tmp[k]*g[v][k]
]
[else ans=g[u][k]*g[v][k]
]
注意:(k)个点对是不等价的,比如说我们可以选((i,j))为第一个点对和选((i,j))为第二个点对是两种方案。
代码:
#include<cstdio>
#include<algorithm>
int n,q,cnt,fac[501],inv[501],facinv[501],pre[200001],nxt[200001],h[100001],f[100001][20],size[100001],dep[100001],mod=998244353;
struct oo{
int d[601],du;oo(){d[du=0]=1;}
void add(int x){du++;for(int i=du;i;i--)d[i]=(d[i]+1ll*d[i-1]*x)%mod;}
void del(int x){for(int i=1;i<=du;i++)d[i]=((d[i]-1ll*d[i-1]*x)%mod+mod)%mod;du--;}
int cal(int x,int op){int ans=0;for(int i=op?x-1:0;i<=x;i++)ans=(ans+1ll*d[i]*facinv[x-i])%mod;return (1ll*ans*fac[x])%mod;}
}g[100001];
void add(int x,int y){
pre[++cnt]=y;nxt[cnt]=h[x];h[x]=cnt;
pre[++cnt]=x;nxt[cnt]=h[y];h[y]=cnt;}
void dfs(int x){size[x]=1;
for(int i=1;i<20;i++){if(dep[x]<(1<<i))break;f[x][i]=f[f[x][i-1]][i-1];}
for(int i=h[x];i;i=nxt[i])if(pre[i]!=f[x][0]){dep[pre[i]]=dep[x]+1,f[pre[i]][0]=x,dfs(pre[i]),size[x]+=size[pre[i]];g[x].add(size[pre[i]]);}}
int lca(int x,int y){
if(dep[x]>dep[y])std::swap(x,y);int poor=dep[y]-dep[x];
for(int i=19;i>=0;i--)if(poor&(1<<i))y=f[y][i];
for(int i=19;i>=0;i--)if(f[x][i]!=f[y][i])x=f[x][i],y=f[y][i];
if(x==y)return x;return f[x][0];}
int get(int x,int y){int poor=dep[x]-dep[y]-1;for(int i=19;i>=0;i--)if(poor&(1<<i))x=f[x][i];return x;}
int main(){
scanf("%d%d",&n,&q);inv[1]=fac[0]=facinv[0]=1;for(int i=2;i<=500;i++)inv[i]=1ll*inv[mod%i]*(mod-mod/i)%mod;
for(int i=1;i<=500;i++)fac[i]=1ll*fac[i-1]*i%mod,facinv[i]=1ll*facinv[i-1]*inv[i]%mod;
for(int i=1,x,y;i<n;i++)scanf("%d%d",&x,&y),add(x,y);dfs(1);
for(int i=1,u,v,k,op;i<=q;i++){
scanf("%d%d%d%d",&u,&v,&k,&op);if(dep[u]>dep[v])std::swap(u,v);
if(lca(u,v)==u){
int now=get(v,u);oo s=g[u];s.del(size[now]),s.add(n-size[u]);
printf("%d
",(1ll*s.cal(k,op)*g[v].cal(k,op))%mod);}
else printf("%d
",(1ll*g[u].cal(k,op)*g[v].cal(k,op))%mod);}}