zoukankan      html  css  js  c++  java
  • bzoj 4911: [Sdoi2017]切树游戏

    题目描述

    Solution

    考虑暴力DP:设 (f[x][i]) 表示 (x) 子树内, (x) 作为深度最小的点的连通块的数量
    (f[x][i]=f[x][j]*f[u][k]\,j igoplus k=i)
    这个过程可以用 (FWT) 优化

    由于有修改,用链分治动态维护这个DP
    按树链剖分的方法,把树分成若干条重链
    每一条重链看作一个序列 (P_L,...P_R),按照深度从 (L)(R) 递减的顺序排列,线段树维护

    分别记录以下东西:
    (sum[x][i]) 表示线段树中 (x) 所代表的区间的异或和为 (i) 的连通块的答案和
    (li[x][i]) 表示 线段树中 (x) 所代表的区间中包含左端点的异或和为 (i) 的连通块的答案和
    (ri[x][i]) 表示 线段树中 (x) 所代表的区间中包含右端点的异或和为 (i) 的连通块的答案和
    (siz[x][i]) 表示 线段树中 (x) 所代表的区间 ([L,R]) 这个完整的异或和为 (i) 的连通块的答案(也就是每一个位置权值的乘积)

    同一条链的转移十分简单,考虑链与链之间的转移:
    我们把这一条链直接当作 链顶的父亲 的权值就行了
    更新的时候在链上暴力跳就行了
    复杂度是 (log^2)

    考虑这个转移是需要 (FWT) 优化的,复杂度又多了个 (log)

    有一种方法优化:
    我们 (FWT) 时,是先 (FWT(a,1)),再做点值多项式乘法,再转回来的过程
    我们可以一开始就转好点值多项式,然后运算过程全程用点值多项式的值来代入,中间的运算过程就可以变成普通的点值乘法了
    在询问的时候再 (FWT) 回来就行了

    这样复杂度就是 (O(n*m*log^2)) 的了

    另外注意:
    (0) 没有逆元,由于会除以 (0),所以要定义一种新运算维护 (0) 的个数,重载一下乘除号就行了

    #include<bits/stdc++.h>
    #define pb push_back
    using namespace std;
    const int N=30005,M=130,mod=10007;
    int n,m,Q,a[N],sz[N],son[N],dep[N],head[N],nxt[N*2],to[N*2],num=0;
    int top[N],fa[N],inv[N],E[M][M],lis[N],tt=0,ans[M],re[M];
    vector<int>v[N];
    inline void link(int x,int y){nxt[++num]=head[x];to[num]=y;head[x]=num;}
    inline void dfs(int x){
    	sz[x]=1;
    	for(int i=head[x];i;i=nxt[i]){
    		int u=to[i];if(sz[u])continue;
    		dep[u]=dep[x]+1;fa[u]=x;dfs(u);
    		sz[x]+=sz[u];if(sz[u]>sz[son[x]])son[x]=u;
    	}
    }
    inline void dfs2(int x,int tp){
    	top[x]=tp;
    	if(son[x])dfs2(son[x],tp);
    	for(int i=head[x];i;i=nxt[i])
    		if(to[i]!=fa[x] && to[i]!=son[x])dfs2(to[i],to[i]);
    	v[tp].pb(x);
    }
    inline void fwt(int *A,int o){
    	for(int i=1;i<m;i<<=1)
    		for(int j=0;j<m;j+=i<<1)
    			for(int k=0;k<i;k++){
    				int x=A[j+k],y=A[j+k+i];
    				if(!o)A[j+k]=(x+y)%mod,A[j+k+i]=(x-y+mod)%mod;
    				else A[j+k]=(x+y)*inv[2]%mod,A[j+k+i]=(x-y+mod)*inv[2]%mod;
    			}
    }
    struct data{
    	int a,b;
    	inline void biu(int x){x%=mod;if(x)a=x,b=0;else a=1,b=1;}
    	inline int val(){return b?0:a;}
    	inline void operator *=(const int x){
    		if(!x)b++;
    		else a=a*x%mod;
    	}
    	inline void operator /=(const int x){
    		if(!x)b--;
    		else a=a*inv[x]%mod;
    	}
    }f[N][M];
    void priwork(){
    	inv[1]=1;
    	for(int i=2;i<mod;i++)inv[i]=(mod-(mod/i)*inv[mod%i]%mod)%mod;
    	int len;for(len=1;len<m;len<<=1);m=len;
    	for(int i=0;i<m;i++)E[i][i]=1,fwt(E[i],0);    //预处理出单位矩阵 E
           //因为我们是先把 f[i][a[i]]=1 赋为 1 再转点值表达式的,我们预处理出E[i]表示把 i 赋成1时的单位多项式
    	for(int i=1;i<=n;i++)
    		for(int j=0;j<m;j++)f[i][j].biu(E[a[i]][j]);
    }
    inline bool comp(int i,int j){return dep[i]>dep[j];}
    int ls[N*4],rs[N*4],rt[N],li[N*4][M],ri[N*4][M];
    int ft[N*4],sum[N*4][M],siz[N*4][M],id[N];
    inline void upd(int o){
    	for(int i=0;i<m;i++){
    		sum[o][i]=(sum[ls[o]][i]+sum[rs[o]][i]+ri[ls[o]][i]*li[rs[o]][i])%mod;
    		li[o][i]=(li[ls[o]][i]+li[rs[o]][i]*siz[ls[o]][i])%mod;
    		ri[o][i]=(ri[rs[o]][i]+ri[ls[o]][i]*siz[rs[o]][i])%mod;
    		siz[o][i]=siz[ls[o]][i]*siz[rs[o]][i]%mod;
    	}
    }
    inline void build(int &x,int l,int r,int t){
    	x=++tt;
    	if(l==r){
    		id[v[t][l]]=x;
    		for(int i=0;i<m;i++)
    			li[x][i]=ri[x][i]=sum[x][i]=siz[x][i]=f[v[t][l]][i].val();
    		return ;
    	}
    	int mid=(l+r)>>1;
    	build(ls[x],l,mid,t);build(rs[x],mid+1,r,t);
    	if(ls[x])ft[ls[x]]=x;if(rs[x])ft[rs[x]]=x;
    	upd(x);
    }
    inline void solve(int x){
    	int t=top[x];
    	if(fa[t])for(int i=0;i<m;i++)f[fa[t]][i]/=(ri[rt[t]][i]+E[0][i])%mod;
    	for(int i=0;i<m;i++)ans[i]=(ans[i]-sum[rt[t]][i]+mod)%mod;
    	int p=id[x];
    	for(int i=0;i<m;i++)
    		li[p][i]=ri[p][i]=sum[p][i]=siz[p][i]=f[x][i].val();
    	for(p=ft[p];p;p=ft[p])upd(p);
    	if(fa[t])for(int i=0;i<m;i++)f[fa[t]][i]*=(ri[rt[t]][i]+E[0][i])%mod;
    	for(int i=0;i<m;i++)ans[i]=(ans[i]+sum[rt[t]][i])%mod;
    }
    int main(){
      freopen("pp.in","r",stdin);
      freopen("pp.out","w",stdout);
      int x,y;char S[8];
      scanf("%d%d",&n,&m);
      for(int i=1;i<=n;i++)scanf("%d",&a[i]);
      for(int i=1;i<n;i++)scanf("%d%d",&x,&y),link(x,y),link(y,x);
      dep[1]=1;dfs(1);dfs2(1,1);
      priwork();
      int cnt=0;
      for(int i=1;i<=n;i++)if(top[i]==i)lis[++cnt]=i;
      sort(lis+1,lis+cnt+1,comp);
      for(int i=1;i<=cnt;i++){
    	  x=lis[i];
    	  build(rt[x],0,v[x].size()-1,x);
    	  if(fa[x])
    		  for(int j=0;j<m;j++)f[fa[x]][j]*=(ri[rt[x]][j]+E[0][j])%mod;
    	  for(int j=0;j<m;j++)ans[j]=(ans[j]+sum[rt[x]][j])%mod;
      }
      cin>>Q;
      while(Q--){
    	  scanf("%s%d",S,&x);
    	  if(S[0]=='Q'){
    		  for(int i=0;i<m;i++)re[i]=ans[i];
    		  fwt(re,1);
    		  printf("%d
    ",re[x]);
    	  }
    	  else{
    		  scanf("%d",&y);
    		  for(int i=0;i<m;i++)f[x][i]/=E[a[x]][i];
    		  a[x]=y;
    		  for(int i=0;i<m;i++)f[x][i]*=E[a[x]][i];
    		  for(;x;x=fa[top[x]])solve(x);
    	  }
      }
      return 0;
    }
    
    
  • 相关阅读:
    tensorflow入门(三)
    tensorflow入门(二)
    setTimeout
    PreResultListener
    sql 删除重复记录
    oracle dual表用途及结构详解
    oracle中的dual表
    Dubbo远程调用服务框架原理与示例
    struts2和spring的两种整合方式
    Timer和TimerTask详解
  • 原文地址:https://www.cnblogs.com/Yuzao/p/8576879.html
Copyright © 2011-2022 走看看