题面
题解
妈呀调了我整整一天……
题解太长了不写了可以去看(shadowice)巨巨的
//minamoto
#include<bits/stdc++.h>
// `register` was deprecated in C++11 and removed as a storage-class specifier
// in C++17, so R now expands to nothing. Every existing `R int x` declaration
// keeps compiling and behaves the same (the keyword was only a hint).
#define R
// NOTE: despite the name, this is a 32-bit unsigned type; hash arithmetic
// declared with it wraps modulo 2^32.
#define ull unsigned int
// fp(i,a,b): ascending loop, i = a..b inclusive. The bound is evaluated once
// and stored in a loop-local variable named I.
#define fp(i,a,b) for(R int i=(a),I=(b)+1;i<I;++i)
// fd(i,a,b): descending loop, i = a down to b inclusive (same one-shot bound).
#define fd(i,a,b) for(R int i=(a),I=(b)-1;i>I;--i)
// go(u): walk the adjacency list of node u in the graph pointed to by T;
// inside the body, i is the edge index and v is the neighbour.
#define go(u) for(int i=T->head[u],v=T->e[i].v;i;i=T->e[i].nx,v=T->e[i].v)
// gg(u): walk the set bits of the mask u, lowest first; inside the body,
// v is the index of the current bit (looked up in lg[]). The mask is copied
// into i up front, so the body may modify the original expression safely.
#define gg(u) for(int i=u,v=lg[i&-i];i;i-=i&-i,v=lg[i&-i])
using namespace std;
// Input buffer for the hand-rolled stdin reader: [p1, p2) is the unread part.
char buf[1<<21],*p1=buf,*p2=buf;
// Returns the next input byte, refilling buf from stdin when it is exhausted.
// Yields EOF (narrowed to char) once stdin has nothing more to give.
inline char getc(){
    if(p1==p2){
        p1=buf;
        p2=buf+fread(buf,1,1<<21,stdin);
        if(p1==p2)return EOF;
    }
    return *p1++;
}
// Parses the next decimal integer from stdin.
// Non-digit bytes before the number are skipped; seeing a '-' anywhere in that
// skipped prefix makes the result negative. Returns 0 if the input ends before
// any digit is found (the previous version looped forever in that case).
int read(){
    const char eof=static_cast<char>(EOF);
    int sign=1;
    char ch=getc();
    while(ch<'0'||ch>'9'){
        if(ch==eof)return 0;
        if(ch=='-')sign=-1;
        ch=getc();
    }
    int res=ch-'0';
    for(ch=getc();ch>='0'&&ch<='9';ch=getc())res=res*10+(ch-'0');
    return res*sign;
}
// N: bound for small tables (masks, tree nodes); M: bound for graph arrays;
// P: the prime modulus; inv2: modular inverse of 2 (2*inv2 == 1 mod P).
const int N=5005,M=1e7+5,P=998244353,inv2=499122177;
// (x + y) mod P, for x and y already in [0, P).
inline int add(int x,int y){
    const int sum=x+y;
    return sum>=P?sum-P:sum;
}
// (x - y) mod P, for x and y already in [0, P).
inline int dec(int x,int y){
    const int diff=x-y;
    return diff<0?diff+P:diff;
}
// (x * y) mod P; the product is formed in 64 bits so it cannot overflow.
inline int mul(int x,int y){
    return static_cast<int>(1ll*x*y%P);
}
// x^y mod P by binary exponentiation; y must be non-negative.
int ksm(int x,int y){
    int res=1;
    while(y){
        if(y&1)res=mul(res,x);
        x=mul(x,x);
        y>>=1;
    }
    return res;
}
// Multipliers used by Tr when hashing subtrees (Tr::dfs and Tr::calc).
const ull base1=998,base2=244;
// n, k: the two integers read first in main (n is passed to T->Pre and n-1
// edges follow). ans: running answer, kept reduced modulo P.
int n,k,ans;
// deg: node degrees (filled while reading edges in main, recomputed by Gr::calc).
// s: per-node scratch sums used by Gr::calc.
// C[i][j]: binomial coefficients modulo P for j <= 14; fac[i]: i! modulo P.
// Both tables are filled in main.
int deg[M],s[M],C[N][15],fac[15];
struct Gr{
struct eg{int v,nx;}e[M];int head[M],tot,n;
inline void Add(R int u,R int v){e[++tot]={v,head[u]},head[u]=tot;}
inline void Pre(R int x){n=x,tot=1;fp(i,1,n)head[i]=0;}
int calc(){
int res=0;fp(i,1,n)deg[i]=s[i]=0;
fp(i,2,tot)++deg[e[i].v];
fp(u,1,n)for(int i=head[u];i;i=e[i].nx)s[u]+=deg[u]+deg[e[i].v]-3;
fp(i,1,n)res=add(res,mul(s[i],s[i]));
res=mul(res,inv2);
for(int i=2;i<=tot;i+=2){
int d=deg[e[i].v]+deg[e[i^1].v]-2;
res=(res+P-(d-1)*(d-1))%P;
res=(res+1ll*d*(d-1)*(d-2)/2)%P;
}
return res;
}
}pool[3],*T,*E[2];
// lg[x]: floor(log2(x)); pc[x]: number of set bits of x. Both are filled in main for x < N.
int lg[N],pc[N];
// A small rooted tree on nodes 0..n-1 (root 0) stored as per-node child bitmasks,
// together with subtree hashes used to recognise isomorphic shapes.
struct Tr{
// f[u]: hash of the subtree rooted at u; q: scratch for sorting child hashes;
// mdzz: hash of the whole tree (f[0] as left by Pre).
// siz[u]: subtree size; to[u]: bitmask of u's children; n: node count.
// jc[u]: modular inverse of the product of m! over each group of m children
// of u that share the same hash.
ull f[15],q[15],mdzz;int siz[15],jc[15],to[15],n;
// Marks v as a child of u. (`register` dropped from the parameters: it is ill-formed in C++17.)
inline void Add(int u,int v){to[u]|=(1<<v);}
// Clears the child masks of the current nodes and empties the tree.
inline void clr(){fp(i,0,n-1)to[i]=0;n=0;}
// Computes siz[], f[] and jc[] for the subtree of u.
void dfs(int u){
siz[u]=1;
gg(to[u])dfs(v),siz[u]+=siz[v];
// Children are finished, so q can now be reused for this node's child hashes.
int t=0;gg(to[u])q[++t]=f[v];
// q[0] and q[t+1] are sentinels guaranteed to differ from their neighbours.
sort(q+1,q+1+t),q[t+1]=q[t]+1,q[0]=q[1]+1,f[u]=siz[u],jc[u]=1;
fp(i,1,t)f[u]=f[u]*base1+q[i];
// Walk runs of equal hashes: s becomes (run length)! and is folded into
// jc[u] whenever a run ends (the trailing sentinel flushes the last run).
for(int i=1,s=1,cnt=1;i<=t+1;++i)
q[i]==q[i-1]?(++cnt,s=mul(s,cnt)):(jc[u]=mul(jc[u],s),s=cnt=1);
jc[u]=ksm(jc[u],P-2),f[u]*=base2*siz[u];
}
// Hashes the whole tree from the root and remembers the root hash in mdzz.
inline void Pre(){dfs(0),mdzz=f[0];}
// sz[u]: subtree size restricted to the node set used by calc; vis: nodes reached by calc.
int sz[15],vis;
// Same hash as dfs, but only following children that belong to the mask S.
// Overwrites f[] for every node it reaches and records those nodes in vis.
void calc(int u,int S){
sz[u]=1,vis|=(1<<u);
gg(to[u]&S)calc(v,S),sz[u]+=sz[v];
int t=0;gg(to[u]&S)q[++t]=f[v];
sort(q+1,q+1+t),f[u]=sz[u];
fp(i,1,t)f[u]=f[u]*base1+q[i];
f[u]*=base2*sz[u];
}
// Picks the member of S with the largest siz (starting from the lowest set
// bit of S), hashes the part of S reachable from it, and returns that node
// if all of S was reached; returns -1 otherwise.
int ck(int S){
int mx=lg[S&-S];
fp(i,0,n-1)((S>>i&1)&&siz[i]>siz[mx])?mx=i:0;
vis=0,calc(mx,S);return vis==S?mx:-1;
}
}qwq,*ntr;
// Builds the line graph of *a into *b: every undirected edge of a (slot pair
// i, i^1, identified by i>>1) becomes a node of b, and two such nodes are
// joined once for every endpoint the corresponding edges share in a.
void Line(Gr *a,Gr *b){
    const int edgeCount=(a->tot-1)>>1;
    b->Pre(edgeCount);
    const int nodes=a->n;
    for(int u=1;u<=nodes;++u){
        // Every unordered pair of edge slots leaving u yields one edge in b.
        for(int first=a->head[u];first;first=a->e[first].nx){
            const int x=first>>1;
            for(int second=a->e[first].nx;second;second=a->e[second].nx){
                const int y=second>>1;
                b->Add(x,y);
                b->Add(y,x);
            }
        }
    }
}
// Memo of per-shape results, keyed by the tree hash (written and read in calc()).
map<ull,int>mp;
// Starting from the graph in E[0], applies the line-graph construction k-4
// times (ping-ponging between E[0] and E[1]) and returns Gr::calc() of the
// final graph. Callers are expected to pass k >= 4.
int calcnode(int k){
    int cur=0;
    for(int remaining=k-4;remaining!=0;--remaining){
        Line(E[cur],E[cur^1]);
        cur^=1;
    }
    return E[cur]->calc();
}
int calc(){
ntr->Pre(),E[0]->Pre(ntr->n);
fp(u,0,ntr->n-1)gg(ntr->to[u])E[0]->Add(u+1,v+1),E[0]->Add(v+1,u+1);
int res=calcnode(k);
fp(i,1,(1<<ntr->n)-2){
int p=ntr->ck(i);
((~p)&&mp.find(ntr->f[p])!=mp.end())?res=dec(res,mp[ntr->f[p]]):0;
}
return mp[ntr->mdzz]=res;
}
int g[N],las[15],leaf[15],f[N][15];bool L[15];
void dfs(int u,int fa){
int sz=0;
go(u)v!=fa?(dfs(v,u),++sz):0;
fp(i,0,ntr->n-1){
if(L[i])continue;
if(sz<pc[ntr->to[i]]){f[u][i]=0;continue;}
for(int j=ntr->to[i];j;j=(j-1)&ntr->to[i])g[j]=0;
g[0]=1;
for(int j=T->head[u],v=T->e[j].v;j;j=T->e[j].nx,v=T->e[j].v)if(v!=fa)
for(int k=ntr->to[i];k;k=(k-1)&ntr->to[i])
for(int p=k;p;p-=p&-p)
g[k]=add(g[k],1ll*g[k^(p&-p)]*f[v][lg[p&-p]]%P);
f[u][i]=1ll*g[ntr->to[i]]*ntr->jc[i]%P*C[sz-pc[ntr->to[i]]][leaf[i]]%P*fac[leaf[i]]%P;
}
}
int calctimes(){
fp(i,0,ntr->n-1)las[i]=ntr->to[i],leaf[i]=L[i]=0;
fp(u,0,ntr->n-1)gg(ntr->to[u])ntr->siz[v]==1?(++leaf[u],ntr->to[u]^=(1<<v)):0;
fp(i,0,ntr->n-1)ntr->siz[i]==1?L[i]=true:0;
dfs(1,0);
fp(u,0,ntr->n-1)ntr->to[u]=las[u];
int res=0;
fp(i,1,n)res=add(res,f[i][0]);
return res;
}
// fa[v]: parent of shape node v while a shape is being decoded in solve().
int fa[25];
// st[1..]: the current step sequence (+1 = descend to a new child, -1 = return
// to the parent); vis: hashes of shapes that were already processed.
int st[N<<1],top;set<ull>vis;
// Enumerates all step sequences of length 2*(lim-1) describing a rooted tree
// with lim nodes. x: next position to fill, c: current depth, sp: number of
// nodes created so far below the root. For each complete sequence the shape is
// decoded into *ntr, duplicates (by hash) are skipped, and
// calctimes() * calc() is added to ans.
void solve(int x,int c,int sp,int lim){
    if(x!=(lim-1)*2+1){
        if(c){
            st[x]=-1;
            solve(x+1,c-1,sp,lim);
        }
        if(sp<lim-1){
            st[x]=1;
            solve(x+1,c+1,sp+1,lim);
        }
        return;
    }
    ntr->clr();
    int cur=0,label=0;
    for(int i=1;i<x;++i){
        if(st[i]==1){
            ++label;
            ntr->to[cur]|=(1<<label);
            fa[label]=cur;
            cur=label;
        }else{
            cur=fa[cur];
        }
    }
    ntr->n=lim;
    ntr->Pre();
    // insert() reports whether the hash was new; an old hash means this shape was handled before.
    if(!vis.insert(ntr->mdzz).second)return;
    ans=add(ans,1ll*calctimes()*calc()%P);
}
void spj(){
switch(k){
case 1:printf("%d
",n-1);break;
case 2:{
fp(i,1,n)ans=add(ans,1ll*deg[i]*(deg[i]-1)>>1);
printf("%d
",ans);
break;
}
case 3:{
for(int i=2;i<T->tot;i+=2){
int d=deg[T->e[i].v]+deg[T->e[i^1].v]-2;
ans=add(ans,1ll*d*(d-1)>>1);
}
printf("%d
",ans);
break;
}
case 4:printf("%d
",T->calc());break;
default:{
fp(i,1,k+1)solve(1,0,0,i);
printf("%d
",ans);
break;
}
}
}
int main(){
// freopen("testdata.in","r",stdin);
T=&pool[0],E[0]=&pool[1],E[1]=&pool[2],ntr=&qwq;
n=read(),k=read(),T->Pre(n);
for(R int i=1,u,v;i<n;++i)u=read(),v=read(),++deg[u],++deg[v],T->Add(u,v),T->Add(v,u);
fp(i,2,N-1)lg[i]=lg[i>>1]+1;
fp(i,1,N-1)pc[i]=pc[i>>1]+(i&1);
fp(i,0,N-1)C[i][0]=1;fac[0]=1;
fp(i,1,14)fac[i]=mul(fac[i-1],i);
fp(i,1,14)fp(j,1,i)C[i][j]=add(C[i-1][j],C[i-1][j-1]);
fp(i,15,N-1)fp(j,1,14)C[i][j]=add(C[i-1][j],C[i-1][j-1]);
spj();
return 0;
}