description
洛谷
给一棵树,要求支持单点权值修改,以及询问树上有多少个连通块的权值异或和恰好为(k)。
答案对(1e4+7)取模。
data range
设(2^m,q1,q2)分别为最大权值,修改次数和询问次数。
solution
首先当然推荐你去看immortalCO的博客
我们显然可以想到一个朴素的(DP)式:
设(f[u][s])表示异或和为(s),且以(u)为深度最浅的点的连通块个数,那么枚举树边((u,v)(v
ot=fa))有
其中(oplus)表示二进制异或。
我们对于每一次修改都这样(DP)一遍,那么复杂度为(O(qn2^{2m}))
可以发现异或卷积可以使用(FWT)优化(考过哦),复杂度降为(O(qnm2^m))
我们发现(FWT)后的数组更新是直接按位相乘/相加,于是我们只要对于初始化数组(FWT)一遍,之后
对于答案数组再(FWT)回去即可,复杂度将为(O(nm2^m+q(n+m)2^m))
考虑优化我们的(DP)式。
考虑(f[u])的生成函数(f_u(x)=sum_{i=0}^{2^m-1}a_ix^i),那么
其中多项式的积定义为异或卷积。
我们知道每次修改只有一条链的(DP)值会被修改,于是考虑用树链剖分维护。
考虑新开一个(g_i(x))表示(i)的子树的(f(x))之和,那么答案即为(g_1(x))。
维护轻儿子((f(x)+1))的乘积(LF_i)和(g(x))的和(LG_i),考虑一条从根到底的链(c_1,c_2,...,c_k),
那么我们有
这是什么?递推式啊!
对于递推式,我们可以使用矩阵快速幂进行优化。
然后树链剖分+线段树维护矩阵就可以了。
时间复杂度为(O(n(m+logn)2^m+q(m+log^2n)2^m))
树剖(log^2n)跑不满,可以过。
最后注意维护(LF_i)的时候答案可能会除(0),维护一下(0)的个数就可以了。
Code
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<iomanip>
#include<cstring>
#include<complex>
#include<vector>
#include<cstdio>
#include<string>
#include<bitset>
#include<ctime>
#include<cmath>
#include<queue>
#include<stack>
#include<map>
#include<set>
#define Cpy(x,y) memcpy(x,y,sizeof(x))
#define Set(x,y) memset(x,y,sizeof(x))
#define FILE "4911"
#define mp make_pair
#define pb push_back
#define RG register
#define il inline
using namespace std;
typedef unsigned long long ull;
typedef vector<int>VI;
typedef long long ll;
typedef double dd;
const int N=30010;
const int M=1e7+10;
const int base=26;
const dd eps=1e-6;
const int inf=1e9;
const ll INF=1ll<<60;
const ll P=100000;
#define mod (10007)
il ll read(){
RG ll data=0,w=1;RG char ch=getchar();
while(ch!='-'&&(ch<'0'||ch>'9'))ch=getchar();
if(ch=='-')w=-1,ch=getchar();
while(ch<='9'&&ch>='0')data=data*10+ch-48,ch=getchar();
return data*w;
}
il void file(){
srand(time(NULL)+rand());
freopen(FILE".in","r",stdin);
freopen(FILE".out","w",stdout);
}
int n,m,q,inv[mod],a[N],tmp[128],I[128];
int head[N],nxt[N<<1],to[N<<1],cnt;
int fa[N],sz[N],son[N],dep[N],top[N],bot[N],w[N],fw[N],cntw;
il void add(int u,int v){to[++cnt]=v;nxt[cnt]=head[u];head[u]=cnt;}
il void upd(int &a,int b){a+=b;if(a>=mod)a-=mod;}
il void dec(int &a,int b){if(b)upd(a,mod-b);}
il void fwt(int *a,int n,int opt){
for(RG int i=1;i<n;i<<=1)
for(RG int j=0,p=i<<1;j<n;j+=p)
for(RG int k=0;k<i;k++){
RG int x=a[j+k],y=a[i+j+k];a[i+j+k]=x;
upd(a[j+k],y);dec(a[i+j+k],y);
if(opt==-1)
a[j+k]=1ll*a[j+k]*inv[2]%mod,a[i+j+k]=1ll*a[i+j+k]*inv[2]%mod;
}
}
struct int0{int v,z;il void init(int x){v=x?x:1;z=x?0:1;}};
int0 operator *(int0 a,int b){b?a.v=1ll*a.v*b%mod:a.z++;return a;}
int0 operator /(int0 a,int b){b?a.v=1ll*a.v*inv[b]%mod:a.z--;return a;}
int0 operator *(int0 a,int0 b){a.v=1ll*a.v*b.v%mod;a.z+=b.z;return a;}
int0 operator /(int0 a,int0 b){a.v=1ll*a.v*inv[b.v]%mod;a.z-=b.z;return a;}
il int ret(int0 a){return a.z?0:a.v;}
int0 LF[N][128];int LH[N][128];
struct matrix{int s[4][128];int* operator [](int x){return s[x];}};
matrix operator *(matrix x,matrix y){
matrix z;
for(RG int i=0;i<m;i++){
z.s[0][i]=1ll*x.s[0][i]*y.s[0][i]%mod;
z.s[1][i]=1ll*x.s[0][i]*y.s[1][i]%mod;upd(z.s[1][i],x.s[1][i]);
z.s[2][i]=1ll*x.s[2][i]*y.s[0][i]%mod;upd(z.s[2][i],y.s[2][i]);
z.s[3][i]=1ll*x.s[2][i]*y.s[1][i]%mod;
upd(z.s[3][i],x.s[3][i]);upd(z.s[3][i],y.s[3][i]);
}
return z;
}
#define ls (i<<1)
#define rs (i<<1|1)
#define mid ((l+r)>>1)
matrix sum[N<<2];
il void update(int i){sum[i]=sum[rs]*sum[ls];}
void insert(int i,int l,int r,int p){
if(l==r){
memset(tmp,0,sizeof(tmp));tmp[a[fw[l]]]=1;fwt(tmp,m,1);
for(RG int j=0,x;j<m;j++){
x=1ll*ret(LF[l][j])*tmp[j]%mod;
sum[i][0][j]=sum[i][1][j]=sum[i][2][j]=sum[i][3][j]=x;
upd(sum[i][3][j],LH[l][j]);
}
return;
}
if(p<=mid)insert(ls,l,mid,p);else insert(rs,mid+1,r,p);update(i);
}
matrix query(int i,int l,int r,int x,int y){
if(x<=l&&r<=y)return sum[i];
if(y<=mid)return query(ls,l,mid,x,y);if(mid<x)return query(rs,mid+1,r,x,y);
return query(rs,mid+1,r,x,y)*query(ls,l,mid,x,y);
}
void dfs1(int u,int ff){
fa[u]=ff;sz[u]=1;son[u]=0;dep[u]=dep[ff]+1;
for(RG int i=head[u];i;i=nxt[i]){
RG int v=to[i];if(v==ff)continue;
dfs1(v,u);sz[u]+=sz[v];if(sz[son[u]]<sz[v])son[u]=v;
}
}
void dfs2(int u,int tp){
top[u]=tp;w[u]=++cntw;fw[cntw]=u;bot[u]=u;
if(son[u]){dfs2(son[u],tp);bot[u]=bot[son[u]];}
for(RG int j=0;j<m;j++)LF[w[u]][j].init(I[j]);
RG matrix r;
for(RG int i=head[u];i;i=nxt[i]){
RG int v=to[i];if(v==fa[u]||v==son[u])continue;
dfs2(v,v);r=query(1,1,n,w[v],w[bot[v]]);
for(RG int j=0;j<m;j++)upd(r.s[2][j],I[j]);
for(RG int j=0;j<m;j++)
LF[w[u]][j]=LF[w[u]][j]*r.s[2][j],upd(LH[w[u]][j],r.s[3][j]);
}
insert(1,1,n,w[u]);
}
il void change(int x,int y){
RG matrix r;
for(RG int u=top[x],ff;fa[u];u=top[fa[u]]){
ff=w[fa[u]];r=query(1,1,n,w[u],w[bot[u]]);
for(RG int j=0;j<m;j++)upd(r.s[2][j],I[j]);
for(RG int j=0;j<m;j++)
LF[ff][j]=LF[ff][j]/r.s[2][j],dec(LH[ff][j],r.s[3][j]);
}
a[x]=y;insert(1,1,n,w[x]);
for(RG int u=top[x],ff;fa[u];u=top[fa[u]]){
ff=w[fa[u]];r=query(1,1,n,w[u],w[bot[u]]);
for(RG int j=0;j<m;j++)upd(r.s[2][j],I[j]);
for(RG int j=0;j<m;j++)
LF[ff][j]=LF[ff][j]*r.s[2][j],upd(LH[ff][j],r.s[3][j]);
insert(1,1,n,ff);
}
}
int main()
{
n=read();m=read();inv[1]=1;I[0]=1;fwt(I,m,1);
for(RG int i=2;i<mod;i++)inv[i]=mod-1ll*(mod/i)*inv[mod%i]%mod;
for(RG int i=1;i<=n;i++)a[i]=read();
for(RG int i=1,u,v;i<n;i++){u=read();v=read();add(u,v);add(v,u);}
dfs1(1,0);dfs2(1,1);RG matrix r;q=read();
for(RG int i=1,c,x,y;i<=q;i++){
c=0;while(c!='Q'&&c!='C')c=getchar();
if(c=='Q'){
x=read();r=query(1,1,n,w[1],w[bot[1]]);
memcpy(tmp,r.s[3],sizeof(tmp));fwt(tmp,m,-1);
printf("%d
",tmp[x]);
}
else{x=read();y=read();change(x,y);}
}
return 0;
}