CXLV.[九省联考2018]秘密袭击coat
首先先讲一种暴力但能过的方法。
很容易就会往每个值各被计算几次的方向去想。于是我们枚举每个节点,计算有多少种可能下该节点是目标节点。
为了避免相同的值的影响,我们在值相同的点间也决出一种顺序,即,若两个值相同的点在作比较,依照上文定下的那种顺序决定。
于是我们考虑从该枚举的点 \(x\) 出发,遍历整棵子树同时DP。设 \(f_{i,j}\) 表示 \(i\) 子树中有 \(j\) 个点的危险程度 \(\geq d_x\)。于是就直接背包转移就行了。
看上去复杂度是 \(O(n^3)\),但是加上下述两个优化就可以过了:
-
第二维最大只枚举到 \(m\)(这里的 \(m\) 即题面中的 \(k\),因为 \(k\) 这个字母我们接下来还要用)
-
第二维最大只枚举到子树大小 \(sz\)。
然后就过了,跑的还比正解都要快。
代码:
#include<bits/stdc++.h>
using namespace std;
const int mod=64123;
int n,m,W,d[2010],f[2010][2010],p[2010],q[2010],res,sz[2010];
vector<int>v[2010];
void dfs(int x,int fa,int lim){
for(int i=0;i<=sz[x];i++)f[x][i]=0;
f[x][sz[x]=(q[x]>=lim)]=1;
for(auto y:v[x]){
if(y==fa)continue;
dfs(y,x,lim);
// printf("%d:",x);for(int i=0;i<=m;i++)printf("%d ",f[x][i]);puts("");
// printf("%d:",y);for(int i=0;i<=m;i++)printf("%d ",f[y][i]);puts("");
for(int i=sz[x];i>=0;i--)for(int j=min(m-i,sz[y]);j>=0;j--)(f[x][i+j]+=1ll*f[x][i]*f[y][j]%mod)%=mod;
sz[x]=min(sz[x]+sz[y],m);
// printf("%d:",x);for(int i=0;i<=m;i++)printf("%d ",f[x][i]);puts("\n");
}
}
int main(){
scanf("%d%d%d",&n,&m,&W);
for(int i=1;i<=n;i++)scanf("%d",&d[i]),p[i]=i;
sort(p+1,p+n+1,[](int u,int v){return d[u]<d[v];});
for(int i=1;i<=n;i++)q[p[i]]=i;
for(int i=1,x,y;i<n;i++)scanf("%d%d",&x,&y),v[x].push_back(y),v[y].push_back(x);
for(int i=1;i<=n;i++){
// printf("%d\n",q[i]);
if(q[i]>n-m+1)continue;
dfs(i,0,q[i]);
(res+=1ll*d[i]*f[i][m]%mod)%=mod;
}
printf("%d\n",res);
return 0;
}
然后是正解。
我们要求 \(\sum\limits_{\mathbb S\subseteq\mathbb T}\text{Kth of }\mathbb S\)。
考虑枚举该 \(\text{Kth}\) 的值为 \(i\),则要求 \(\sum\limits_{i=1}i\sum\limits_{\mathbb S\subseteq\mathbb T}[\text{Kth of }\mathbb S=i]\)。
考虑让每个 \(i\) 拆分成在所有的 \(j\leq i\) 的位置上各被计算一次,则要求 \(\sum\limits_{i=1}\sum\limits_{\mathbb S\subseteq\mathbb T}[\text{Kth of }\mathbb S\geq i]\)。(这同时也是期望DP的经典套路)
考虑令 \(cnt_i\) 表示 \(\geq i\) 的数的总数。则要求 \(\sum\limits_{i=1}\sum\limits_{\mathbb S\subseteq\mathbb T}[cnt_i\geq m]\)。(注意这里的 \(m\) 即为 \(k\))
考虑对于每个连通块,在树上最高处计算它的贡献。设 \(f_{i,j,k}\) 表示以 \(i\) 为根的子树内,当前统计的是 \(cnt_j\),且 \(cnt_j=k\) 的方案数。转移是裸的背包卷积。
考虑如何求答案。因为我们是在最高处计算贡献,所以就要求 \(\sum\limits_{i=1}^n\sum\limits_{j=1}^W\sum\limits_{k=m}^nf_{i,j,k}\)。
因为我们是卷积,所以考虑FFT转移。又因为它一直在卷,所以我们干脆考虑压根不把它复原成系数式,就纯粹用点集表示。
更准确地说,因为 \(f_{i,j}\) 的生成函数是 \(\sum\limits_{k=1}^nf_{i,j,k}x^k\),一个 \(n\) 次多项式,所以我们直接枚举 \(x\in[1,n+1]\),然后分别求出这时的生成函数的值,最后拉格朗日插值一下就插回系数式了。
则,在合并父亲 \(x\) 和儿子 \(y\) 的 \(f\) 数组时,因为是点集式,所以直接对应位置相乘就行了。
但是就算有了这么伟大的思路,我们算算复杂度,还是 \(O(n^3)\) 的。有没有更棒的复杂度?
我们发现,最终要对所有 \(x\) 求 \(f\) 数组的和,倒不如在正常处理的过程中就顺便维护了。于是我们设 \(g_{i,j}=\sum\limits_{k\in\text{subtree}_i}f_{k,j}\),则最终要求的就是 \(\sum\limits_{i=m}^ng_{1,i}\)。当然,依据分配律,我们还是可以直接一股脑求出 \(\sum\limits_{i=1}^ng_{1,i}\),待插出系数式后再取 \(\geq m\) 的项。
我们思考当合并 \(f_{x,i}\) 与 \(f_{y,i}\) 时会发生什么:
\(f_{x,i}\rightarrow f_{x,i}\times(f_{y,i}+1)\)(采用 \(f_{y,i}\) 或不用)
\(g_{x,i}\rightarrow g_{x,i}+g_{y,i}\)
在全体结束后,再有 \(g_{x,i}\rightarrow g_{x,i}+f_{x,i}\)。
同时,为了便于合并,我们采用线段树合并来维护DP数组。(这种操作被称作整体DP)
我们考虑初始化,发现假如我们最外层枚举的点值是 \(X\),则所有 \(\forall i\leq d_x\),\(f_{x,i}\) 在结束时都要乘上一个 \(X\)。
明显这个初始状态大体是区间的,非常适合线段树合并。
但是,就算这样,暴力合并的复杂度仍然 \(O(n^3)\),必须考虑在区间上把它做掉。
于是,一个天才般的想法诞生了:
观察到每次的变换都是 \((f,g)\rightarrow(af+b,cf+d+g)\) 的形式。
而这个变换,可以用四元组 \((a,b,c,d)\) 唯一刻画。
同时,展开式子,就会发现这个变换具有结合律(虽然很遗憾的是,大部分情形下它不具有交换律)。
假如我们初始令 \(b=f,d=g\) 的话,就会发现,做一次上述操作,它就自动帮你更新了 \(f\) 与 \(g\)!
于是,我们把它看作区间的 tag
,然后线段树合并就非常简单了。
同时,要注意的是,其单位元是 \((1,0,0,0)\)。
我们来总结一下操作:当合并的时候,我们希望 \(f_x\rightarrow f_x\times(f_y+1)\),而 \(f_y+1\) 可以通过在遍历完 \(y\) 的子树后打上全体 \(+1\) 的 tag
解决,当然这里不需要额外增加其它的 tag
,我们发现 \((1,1,0,0)\) 刚好胜任了这个操作。于是现在 \(f_x\rightarrow f_{x}\times f_y\),\((f_y,0,0,0)\) 的 tag
即可(需要注意的是,\(f_y\) 是线段树上 \(y\) 处的 \(b\))。\(g_{x}\rightarrow g_x+g_y\),\((1,0,0,g_y)\) 即可,而 \(g_y\) 则是 \(d\)。两个乘起来,就是使用 \((b,0,0,d)\)。
最后合并 \(f\) 与 \(g\) 的时候,则要使用 \((1,0,1,0)\),意义通过展开即可得到就是将 \(f\) 加到 \(g\)。而乘上 \(X\) 的操作,使用 \((X,0,0,0)\) 即可。
需要注意的是,这里并不能标记永久化,主要是因为从四元组中抽出 \(b\) 和 \(d\) 的操作并非线性变换,不能打到 tag
上去,在线段树合并的时候要先一路下传到某一方已经没有叶子了再合并。
同时,使用 unsigned int
可以刚好把 64123
卡进一次乘法内不爆。
代码(一些比较疑惑的地方已经加了注释):
#include<bits/stdc++.h>
using namespace std;
#define int unsigned int
const int mod=64123;
int n,m,W,d[2010],X,cnt,bin[5010000],tp,rt[2010],a[2010];
struct dat{//(f,g)->(af+b,cf+d+g)
int a,b,c,d;
dat(){a=1,b=c=d=0;}
dat(int A,int B,int C,int D){a=A,b=B,c=C,d=D;}
friend dat operator*(const dat&u,const dat&v){return dat((u.a*v.a)%mod,(u.b*v.a+v.b)%mod,(u.a*v.c+u.c)%mod,(u.b*v.c+u.d+v.d)%mod);}
void operator*=(const dat&v){(*this)=(*this)*v;}
void print()const{printf("(%u %u %u %u)\n",a,b,c,d);}
};
#define mid ((l+r)>>1)
int newnode(){return tp?bin[tp--]:++cnt;}
struct SegTree{
int lson,rson;
dat tag;
}seg[5010000];
void erase(int &x){if(x)seg[x].tag=dat(),erase(seg[x].lson),erase(seg[x].rson),bin[++tp]=x,x=0;}//erase all the subtree of x.
void pushdown(int x){
if(!seg[x].lson)seg[x].lson=newnode();
if(!seg[x].rson)seg[x].rson=newnode();
seg[seg[x].lson].tag*=seg[x].tag,seg[seg[x].rson].tag*=seg[x].tag,seg[x].tag=dat();
}
void modify(int &x,int l,int r,int L,int R,dat val){
if(l>R||r<L)return;
if(!x)x=newnode();
if(L<=l&&r<=R){seg[x].tag*=val;return;}
pushdown(x),modify(seg[x].lson,l,mid,L,R,val),modify(seg[x].rson,mid+1,r,L,R,val);
}
void merge(int &x,int &y){
if(!seg[x].lson&&!seg[x].rson)swap(x,y);
if(!seg[y].lson&&!seg[y].rson){seg[x].tag*=dat(seg[y].tag.b,0,0,seg[y].tag.d);return;}
pushdown(x),pushdown(y),merge(seg[x].lson,seg[y].lson),merge(seg[x].rson,seg[y].rson);
}
int query(int x,int l,int r){
if(l==r)return seg[x].tag.d;
pushdown(x);
return (query(seg[x].lson,l,mid)+query(seg[x].rson,mid+1,r))%mod;
}
void iterate(int x,int l,int r){
if(!x)return;
printf("%u:[%u,%u]\n",x,l,r);seg[x].tag.print();
iterate(seg[x].lson,l,mid),iterate(seg[x].rson,mid+1,r);
}
vector<int>v[2010];
void dfs(int x,int fa){
modify(rt[x],1,W,1,W,dat(0,1,0,0));//set all into (0,1,0,0),which means only f=1.
for(auto y:v[x])if(y!=fa)dfs(y,x),merge(rt[x],rt[y]),erase(rt[y]);
modify(rt[x],1,W,1,d[x],dat(X,0,0,0));//those <=d[x] are multiplied by an X
modify(rt[x],1,W,1,W,dat(1,1,1,0));
//product of (1,0,1,0) and (1,1,0,0), first means add f to g(to calculate the real g), second means add 1 to f (stands for x itself not chosen at x's father)
}
int all[2010],tmp[2010],res;
int ksm(int x,int y=mod-2){int z=1;for(;y;y>>=1,x=x*x%mod)if(y&1)z=z*x%mod;return z;}
void Lagrange(){
all[0]=1;
for(int i=1;i<=n+1;i++)for(int j=i-1;j<=i;j--)(all[j+1]+=all[j])%=mod,(all[j]*=mod-i)%=mod;//note that j is unsigned!!!
// for(int i=0;i<=n+1;i++)printf("%u ",all[i]);puts("");
for(int i=1;i<=n+1;i++){
int inv=ksm(mod-i),sum=0;
for(int j=0;j<=n;j++)tmp[j]=all[j];
for(int j=0;j<=n;j++)(tmp[j]*=inv)%=mod,(tmp[j+1]+=mod-tmp[j])%=mod;
// if(i>=1410){for(int j=0;j<=n;j++)printf("%u ",tmp[j]);puts("");}
for(int j=m;j<=n;j++)sum+=tmp[j];sum%=mod;
for(int j=1;j<=n+1;j++)if(j!=i)(sum*=ksm((i-j+mod)%mod))%=mod;
res+=sum*a[i]%mod;
}
res%=mod;
}
signed main(){
scanf("%u%u%u",&n,&m,&W);
for(int i=1;i<=n;i++)scanf("%u",&d[i]);
for(int i=1,x,y;i<n;i++)scanf("%u%u",&x,&y),v[x].push_back(y),v[y].push_back(x);
for(X=1;X<=n+1;X++)dfs(1,0),a[X]=query(rt[1],1,W),erase(rt[1]);
// for(int i=1;i<=n+1;i++)printf("%u ",a[i]);puts("");
Lagrange();printf("%u\n",res);
return 0;
}