题目大意:有一棵n个节点的树,给每个节点分配一个非负整数,使得权值和为m,求出所有方案的 标号最小的带权重心的 标号之和。一个点是带权重心当且仅当以它为根的子树中,所有子树的权值和小于等于m除以2下取整。n=200000,m=5000000,对质数取模。
思考:
首先比较容易发现的一点是 带权重心一定组成一条链。对于m是奇数的情况,带权重心仅有一个,因为从某个带权重心移动必然会导致某一个子树的权值和大于m/2。我们枚举成为带权重心的点,并用容斥的方法求出这个点在多少种方案中出现。令$get(x,y)$表示将x个无区别的球放入y个有区别的盒子并且没有数量限制的方案数,那么$get(x,y)= binom{x+y-1}{x}$。注意到对于u号点不合法的情况,当且仅当某一个子树中的权值和大于等于$frac{m+1}{2}$,于是我们在u号点的方案数,等于$get(m,n)-sum_{vin son}{sum_{i=frac{m+1}{2}}^{m}{get(i,s_{v})*get(m-i,n-s_{v})}}$,其中$s_{v}$是以u为根时,子树v的大小。
我们换个角度思考后面的和式(这一步直接抛弃了前面的式子):如果我们在子树中先放入$frac{m+1}{2}$个球,那么剩下的球任意放都能被统计到答案里。因此我们可以先确定第$frac{m+1}{2}$个球放在了哪个盒子(相当于节点)里,那么方案数就为$get(m,n)-sum_{vin son}{sum_{i=1}^{s_{v}}{get(frac{m+1}{2}-1,i)*get(frac{m+1}{2}-1,n-i+1)}}$。这个式子把s拿了出来!因此可以算出它的前缀和,在O(m)的时间内完成。
对于m为偶数的情况,我们先将m-1带入奇数算法,这样算出的是没有标号最小限制的方案和。对于一条链(长度大于等于2),它上面的标号不是最小的点会被记重若干次,这个值只和链的两个端点的子树大小有关,因此考虑点分治来解决。每次算贡献时要计算路径上点权最小的点,把当前分治到的联通块的点权从小到大保存下来就不需要用数据结构了。复杂度O(nlogn+m)。
1 #include<bits/stdc++.h> 2 #define mod 998244353 3 using namespace std; 4 typedef long long int ll; 5 const int maxn=2E5+5; 6 const int limit=6E6+5; 7 int n,m; 8 ll ans,fac[limit+5],inv[limit+5],preF[maxn]; 9 int size,head[maxn]; 10 int root,sum[maxn],maxp[maxn],fa[maxn],preSum[maxn]; 11 struct edge 12 { 13 int to,next; 14 }E[maxn*2]; 15 inline void add(int u,int v) 16 { 17 E[++size].to=v; 18 E[size].next=head[u]; 19 head[u]=size; 20 } 21 inline ll qpow(ll x,ll y) 22 { 23 ll ans=1,base=x; 24 while(y) 25 { 26 if(y&1) 27 ans=ans*base%mod; 28 base=base*base%mod; 29 y>>=1; 30 } 31 return ans; 32 } 33 inline ll C(int x,int y) 34 { 35 if(x<y||x<0||y<0) 36 return 0; 37 return fac[x]*inv[y]%mod*inv[x-y]%mod; 38 } 39 void dfs(int u,int F) 40 { 41 fa[u]=F; 42 preSum[u]=1; 43 for(int i=head[u];i;i=E[i].next) 44 { 45 int v=E[i].to; 46 if(v==F) 47 continue; 48 dfs(v,u); 49 preSum[u]+=preSum[v]; 50 preF[u]=(preF[u]-C(m/2+preSum[v]-1,m/2))%mod; 51 } 52 preF[u]=(preF[u]-C(m/2+n-preSum[u]-1,m/2))%mod; 53 } 54 inline void init() 55 { 56 fac[0]=1; 57 for(int i=1;i<=limit;++i) 58 fac[i]=fac[i-1]*i%mod; 59 inv[limit]=qpow(fac[limit],mod-2); 60 for(int i=limit-1;i>=0;--i) 61 inv[i]=inv[i+1]*(i+1)%mod; 62 } 63 inline ll getF(int u,int v) 64 { 65 if(v==fa[u]) 66 return (preF[u]+C(m/2+n-preSum[u]-1,m/2)+C(m/2+preSum[u]-1,m/2))%mod; 67 return (preF[u]+C(m/2+preSum[v]-1,m/2)+C(m/2+n-preSum[v]-1,m/2))%mod; 68 } 69 namespace work1 70 { 71 ll f[maxn]; 72 inline ll get(int x,int y) 73 { 74 return C(x+y-1,x); 75 } 76 void init(int u,int F) 77 { 78 sum[u]=1; 79 for(int i=head[u];i;i=E[i].next) 80 { 81 int v=E[i].to; 82 if(v==F) 83 continue; 84 init(v,u); 85 sum[u]+=sum[v]; 86 } 87 } 88 void dfs(int u,int F,int tot) 89 { 90 vector<int>wait; 91 int g=0; 92 for(int i=head[u];i;i=E[i].next) 93 { 94 int v=E[i].to; 95 if(v==F) 96 continue; 97 g+=sum[v]; 98 wait.push_back(sum[v]); 99 } 100 wait.push_back(tot); 101 ll s=get(m,n); 102 for(int i=0;i<wait.size();++i) 103 s-=f[wait[i]]; 104 s%=mod; 105 ans=(ans+s*u)%mod; 106 for(int i=head[u];i;i=E[i].next) 107 { 108 int v=E[i].to; 109 if(v==F) 110 continue; 111 dfs(v,u,tot+g-sum[v]+1); 112 } 113 } 114 inline void main(int l) 115 { 116 for(int i=1;i<=n;++i) 117 f[i]=(f[i-1]+get(l-1,i)*get(m-l,n-i+1))%mod;// !!!!!!! 118 init(1,0); 119 dfs(1,0,0); 120 } 121 } 122 namespace work2 123 { 124 bool vis[maxn]; 125 int TI,visT[maxn],fa[maxn]; 126 ll f[maxn]; 127 void get(int u,int F) 128 { 129 fa[u]=F; 130 sum[u]=1; 131 for(int i=head[u];i;i=E[i].next) 132 { 133 int v=E[i].to; 134 if(vis[v]||v==F) 135 continue; 136 get(v,u); 137 sum[u]+=sum[v]; 138 } 139 } 140 void getRoot(int u,int F,int tot) 141 { 142 maxp[u]=0; 143 for(int i=head[u];i;i=E[i].next) 144 { 145 int v=E[i].to; 146 if(v==F||vis[v]) 147 continue; 148 getRoot(v,u,tot); 149 maxp[u]=max(maxp[u],sum[v]); 150 } 151 maxp[u]=max(maxp[u],tot-sum[u]); 152 if(maxp[root]>maxp[u]) 153 root=u; 154 } 155 int totC,bel[maxn]; 156 ll sumF[maxn]; 157 void getFF(int u,int F,int c) 158 { 159 bel[u]=c; 160 f[u]=getF(u,F); 161 for(int i=head[u];i;i=E[i].next) 162 { 163 int v=E[i].to; 164 if(v==F||vis[v]) 165 continue; 166 getFF(v,u,c); 167 } 168 sumF[c]=(sumF[c]+f[u])%mod; 169 } 170 void cut(int u,int F,ll base,ll now) 171 { 172 ans=(ans-base*f[u]%mod*now)%mod; 173 for(int i=head[u];i;i=E[i].next) 174 { 175 int v=E[i].to; 176 if(vis[v]||v==F) 177 continue; 178 cut(v,u,base,now+v); 179 } 180 } 181 ll fill(int u,int F,int c) 182 { 183 if(visT[u]==c) 184 return 0; 185 visT[u]=c; 186 ll s=f[u]; 187 for(int i=head[u];i;i=E[i].next) 188 { 189 int v=E[i].to; 190 if(v==F||vis[v]||visT[v]==c) 191 continue; 192 s+=fill(v,u,c); 193 } 194 sumF[bel[u]]=(sumF[bel[u]]-f[u])%mod; 195 return s%mod; 196 } 197 vector<int>wait[maxn]; 198 int what[maxn]; 199 void solve(int u,vector<int>D) 200 { 201 vis[u]=1; 202 ++TI; 203 ll totF=0; 204 sum[u]=0; 205 for(int i=head[u];i;i=E[i].next) 206 { 207 int v=E[i].to; 208 if(vis[v]) 209 continue; 210 ++totC; 211 get(v,u); 212 getFF(v,u,totC); 213 totF=(totF+sumF[totC])%mod; 214 what[totC]=v; 215 sum[u]+=sum[v]; 216 } 217 ll s=0; 218 for(int i=head[u];i;i=E[i].next) 219 { 220 int v=E[i].to; 221 if(vis[v]) 222 continue; 223 ll now=getF(u,v); 224 cut(v,u,(totF-sumF[bel[v]]+now)%mod,v); 225 ans=(ans-s*sumF[bel[v]]%mod*u)%mod; 226 ans=(ans-now*u%mod*sumF[bel[v]])%mod; 227 s=(s+sumF[bel[v]])%mod; 228 } 229 int now=n+1; 230 for(int i=0;i<D.size();++i) 231 { 232 int pos=D[i]; 233 if(pos==u) 234 { 235 now=u; 236 continue; 237 } 238 totF-=sumF[bel[pos]]; 239 ll x=fill(pos,fa[pos],TI); 240 ans=(ans+totF*x%mod*min(pos,now))%mod; 241 242 ans=(ans+x*min(pos,now)%mod*getF(u,what[bel[pos]]))%mod; 243 totF+=sumF[bel[pos]]; 244 totF%=mod; 245 wait[bel[pos]].push_back(pos); 246 } 247 for(int i=head[u];i;i=E[i].next) 248 { 249 int v=E[i].to; 250 if(vis[v]) 251 continue; 252 root=0; 253 get(v,u); 254 getRoot(v,u,sum[v]); 255 solve(root,wait[bel[v]]); 256 } 257 } 258 inline void main() 259 { 260 dfs(1,0); 261 maxp[0]=n+1; 262 get(1,0); 263 getRoot(1,0,n); 264 vector<int>D; 265 for(int i=1;i<=n;++i) 266 D.push_back(i); 267 solve(root,D); 268 } 269 } 270 inline void solve() 271 { 272 work1::main(m/2+1); 273 if(m%2==0) 274 work2::main(); 275 ans=(ans%mod+mod)%mod; 276 cout<<ans<<endl; 277 } 278 int main() 279 { 280 ios::sync_with_stdio(false); 281 cin>>n>>m; 282 init(); 283 for(int i=2;i<=n;++i) 284 { 285 int x,y; 286 cin>>x>>y; 287 add(x,y); 288 add(y,x); 289 } 290 solve(); 291 return 0; 292 }