zoukankan      html  css  js  c++  java
  • LOJ#6289. 花朵 树链剖分+分治NTT

    本来以为这道题会非常难调,但是没想到调了不到 5 分钟就 A 了.  

    由于基于多项式的运算都可以方便地进行封装,所以细节就不是很多(或者说几乎没有细节)   

    题意:给定一棵树,每个点有点权,求对于所有大小为 $m$ 的独立集的点权之积的和.     

    数据范围:$n,m leqslant 8 imes 10^4$.  

    先考虑一个十分显然的 $O(n^2)$ 暴力:

    令 $f[x][i],g[x][i]$ 分别表示点 $x$ 选/不选的情况下独立集大小为 $i$ 的点积 之和.  

    考虑将 $x$ 与 $x$ 的一个儿子 $y$ 合并:$f[x][i+j]=f[x][i] imes f[y][j]$,$g$ 同理.  

    然后 $x$ 的初始值是:$f[x][1]=w[x],g[x][0]=1$.    

    树形DP 卡一下上界复杂度是 $O(n^2)$ 的.  

    不难发现,上述 $f[x][i+j] = f[x][i] imes f[y][j]$ 是一个卷积的形式.  

    如果是菊花图或者链的话可以直接用 NTT/分治NTT 来做.   

    正解的话考虑进行轻重路径剖分:   

    对于一条重链来说,先求出该重链中每个点轻儿子为根的多项式 $f,g$,然后对于重链中每个点都将其轻儿子与该点合并.   

    最后对于一条重链进行分治,求出该重链链顶为根的多项式.   

    分析一下时间复杂度: 

    考虑一条重链链顶为根的子树会被卷多少次:其祖先中每一条重链都会将其贡献一次.  

    那么树链剖分中一个点有 $O(log n)$ 个祖先,而每次卷积的时候对链分治的复杂度是 $O(n log^2 n)$.  

    总复杂度就是 $O(n log^3 n)$,但是由于树链剖分的常数比较小,跑的并不慢.   

    code:  

    #include <queue>
    #include <cstdio>   
    #include <vector>
    #include <cstring> 
    #include <algorithm>  
    #define N 1000009 
    #define ll long long 
    #define mod 998244353 
    #define pb push_back
    #define setIO(s) freopen(s".in","r",stdin)  
    using namespace std;  
    int m; 
    int A[N<<2],B[N<<2];      
    int tim,edges,n; 
    int size[N],son[N],top[N],hd[N],to[N<<1],nex[N<<1],fa[N],dep[N]; 
    int dfn[N],bu[N],si[N],val[N];   
    void add(int u,int v) { 
        nex[++edges]=hd[u],hd[u]=edges,to[edges]=v;  
    }
    int ADD(int x,int y) { 
        return (ll)(x+y)%mod; 
    }  
    int DEC(int x,int y) { 
        return (ll)(x-y+mod)%mod; 
    }  
    int MUL(int x,int y) { 
        return (ll)x*y%mod; 
    }
    int qpow(int x,int y) { 
        int tmp=1; 
        for(;y;y>>=1,x=(ll)x*x%mod) {   
            if(y&1) tmp=(ll)tmp*x%mod; 
        }  
        return tmp; 
    }
    int get_inv(int x) { 
        return qpow(x,mod-2); 
    }
    void NTT(int *a,int len,int op) { 
        for(int i=0,k=0;i<len;++i) { 
            if(i>k) { 
                swap(a[i],a[k]); 
            }  
            for(int j=len>>1;(k^=j)<j;j>>=1); 
        }  
        for(int l=1;l<len;l<<=1) { 
            int wn=qpow(3,(mod-1)/(l<<1));  
            if(op==-1) wn=get_inv(wn);  
            for(int i=0;i<len;i+=l<<1) { 
                int w=1;  
                for(int j=0;j<l;++j) { 
                    int x=a[i+j],y=(ll)w*a[i+j+l]%mod;  
                    a[i+j]=(ll)(x+y)%mod;  
                    a[i+j+l]=(ll)(x-y+mod)%mod;  
                    w=(ll)w*wn%mod; 
                }
            }
        }
        if(op==-1) { 
            int iv=get_inv(len); 
            for(int i=0;i<len;++i) { 
                a[i]=(ll)a[i]*iv%mod;   
            }
        }
    }
    struct poly { 
        int len;
        vector<int>a;  
        poly() { len=0,a.clear(); } 
        void push(int x) { 
            a.pb(x),++len;
        }
        void resize(int x) {
            a.resize(x),len=x;    
        }                       
        poly operator*(const poly &b) const { 
            int lim;
            for(lim=1;lim<len+b.len-1;lim<<=1); 
            for(int i=0;i<lim;++i) A[i]=B[i]=0;
            for(int i=0;i<len;++i) A[i]=a[i];
            for(int i=0;i<b.len;++i) B[i]=b.a[i];
            NTT(A,lim,1),NTT(B,lim,1);
            for(int i=0;i<lim;++i) {    
                A[i]=(ll)A[i]*B[i]%mod;
            }
            NTT(A,lim,-1);
            poly c;
            for(int i=0;i<len+b.len-1;++i) { 
                c.push(A[i]); 
            }
            if(c.len>m+1) c.resize(m+1);
            return c;   
        }
        poly operator+(const poly &b) const {
            poly c; 
            c.resize(max(len,b.len));  
            for(int i=0;i<c.len;++i) c.a[i]=0; 
            for(int i=0;i<c.len;++i) {    
                if(i<len) c.a[i]=ADD(c.a[i],a[i]); 
                if(i<b.len) c.a[i]=ADD(c.a[i],b.a[i]);  
            }
            return c;   
        }
        poly operator-(const poly &b) const {    
            poly c;  
            c.resize(max(len,b.len));    
            for(int i=0;i<c.len;++i) c.a[i]=0;
            for(int i=0;i<c.len;++i) { 
                if(i<len) c.a[i]=ADD(c.a[i],a[i]); 
                if(i<b.len) c.a[i]=DEC(c.a[i],b.a[i]);  
            }  
            return c;  
        }
    }f0[N],f1[N],g[2][N];      
    struct data {
        poly f00,f01,f10,f11;           
        data operator+(const data &b) const { 
            data c;   
            c.f00=(f01*b.f00)+(f00*(b.f00+b.f10));   
            c.f11=(f11*b.f01)+(f10*(b.f11+b.f01));    
            c.f01=(f01*b.f01)+(f00*(b.f01+b.f11));      
            c.f10=(f11*b.f00)+(f10*(b.f10+b.f00));    
            return c;  
        }
    }tmp;  
    void dfs1(int x,int ff) {  
        fa[x]=ff,dep[x]=dep[ff]+1,size[x]=1;  
        for(int i=hd[x];i;i=nex[i]) { 
            int y=to[i];  
            if(y==ff) continue;  
            dfs1(y,x);
            size[x]+=size[y];
            if(size[y]>size[son[x]]) son[x]=y;
        }
    }
    void dfs2(int x,int tp) { 
        top[x]=tp;  
        dfn[x]=++tim;
        bu[tim]=x;
        ++si[tp];  
        if(son[x]) {  
            dfs2(son[x],tp); 
        }
        for(int i=hd[x];i;i=nex[i]) {    
            if(to[i]!=fa[x]&&to[i]!=son[x]) { 
                dfs2(to[i],to[i]);  
            }
        }
    }
    poly calc(int l,int r,int d) {     
        if(l==r) {   
            return g[d][l];  
        }
        int mid=(l+r)>>1;  
        return calc(l,mid,d)*calc(mid+1,r,d);  
    }
    data solve(int l,int r) {   
        if(l==r) {      
            int u=bu[l];   
            data e;   
            e.f00=f0[u];  
            e.f11=f1[u];  
            return e;  
        }
        int mid=(l+r)>>1;       
        return solve(l,mid)+solve(mid+1,r);  
    }
    int main() { 
        // setIO("input");  
        int x,y,z; 
        scanf("%d%d",&n,&m);    
        for(int i=1;i<=n;++i) scanf("%d",&val[i]);
        for(int i=1;i<n;++i) {
            scanf("%d%d",&x,&y); 
            add(x,y),add(y,x); 
        }
        dfs1(1,0),dfs2(1,1);       
        for(int i=1;i<=n;++i) {
            f0[i].push(1);  
            f1[i].push(0);  
            f1[i].push(val[i]);    
        }        
        for(int i=n;i>=1;--i) {
            int p=bu[i]; 
            if(top[p]==p) {
                for(int j=dfn[p];j<=dfn[p]+si[p]-1;++j) { 
                    x=bu[j];         
                    int p0=0,p1=0;      
                    for(int e=hd[x];e;e=nex[e]) {
                        y=to[e];  
                        if(y==son[x]||y==fa[x]) continue;            
                        g[0][++p0]=f0[y]+f1[y];   
                        g[1][++p1]=f0[y];  
                    }     
                    if(p0) f0[x]=calc(1,p0,0);  
                    if(p1) f1[x]=f1[x]*calc(1,p1,1); 
                } 
                tmp=solve(dfn[p],dfn[p]+si[p]-1);       
                f0[p]=tmp.f01+tmp.f00;  
                f1[p]=tmp.f10+tmp.f11;             
            }
        }   
        f0[1].resize(m+1); 
        f1[1].resize(m+1);  
        printf("%d
    ",(ll)(f0[1].a[m]+f1[1].a[m])%mod);  
        return 0; 
    }
    

      

  • 相关阅读:
    python+requests——定制请求头——cookie
    python+requests——高级用法——上传文件
    彻底搞定C指针例题
    static_cast, dynamic_cast, reinterpret_cast, const_cast区别比较
    单链表的基本操作
    new int[10]()
    用人单位给计算机系学生的一封信(超长评论版)
    指向二维数组的指针
    《windows程序设计》第一章学习心得
    VS2010编译Lua程序
  • 原文地址:https://www.cnblogs.com/guangheli/p/13375471.html
Copyright © 2011-2022 走看看