zoukankan      html  css  js  c++  java
  • 洛谷 P5276 模板题(uoi)

    这题挺恶心的。

    首先考虑一棵树的情况:用点分治加卷积统计答案。注意合并子树时要按深度从小到大合并,否则复杂度会退化。
    我偷懒按子树大小(size)从小到大合并,复杂度应该仍是两个 log。

    然后考虑万恶的环。
    先随便删掉环上一条边,按照树统计一下答案。
    然后考虑必须经过被删掉的那条环边、但又不经过整个环的路径的答案。

    考虑再钦定环上另一条边不经过,计算这部分的答案。

    然后递归做就行了。

    最后加上经过整个环的答案。
    时间复杂度 $O(n\log^2 n)$。

    // luogu-judger-enable-o2
    // luogu-judger-enable-o2
    #include <bits/stdc++.h>
    
    using namespace std;
    
    // Polynomial: coefficient list where index = exponent, values taken mod M.
    using poly = std::vector<int>;
    using ll = long long;
    // Scratch buffers that operator* reuses for the NTT path.
    poly a, b;
    // P: capacity of the transform tables.
    // M: NTT-friendly prime (119 * 2^23 + 1).
    // G: generator used to build the roots of unity.
    const int P = 1 << 17;
    const int M = 998244353;
    const int G = 3;
    // rev: bit-reversal permutation for the current transform length.
    // w:   twiddle table; for each power of two h >= 2, w[h + j] = r^j where r = G^((M-1)/(2h)).
    int rev[P], w[P];
    namespace {
        // (x + y) mod M for x, y in [0, M).
        int add(int x, int y) {
            const int s = x + y;
            return s >= M ? s - M : s;
        }
        // (x - y) mod M for x, y in [0, M).
        int sub(int x, int y) {
            const int d = x - y;
            return d < 0 ? d + M : d;
        }
        // (x * y) mod M, computed in 64-bit to avoid overflow.
        int mul(int x, int y) {
            return static_cast<int>(static_cast<ll>(x) * y % M);
        }
        // base^exp mod M by binary exponentiation (exp must be non-negative).
        int fp(int base, int exp) {
            int result = 1;
            while (exp != 0) {
                if (exp & 1) result = mul(result, base);
                exp >>= 1;
                base = mul(base, base);
            }
            return result;
        }
    }
    // inv2[k] = (2^k)^{-1} mod M, filled by init() for every 2^k <= len.
    int inv2[30];
    // Fills w[] for all power-of-two half-lengths below len and inv2[] up to len.
    // len is expected to be a power of two no larger than P.
    void init(int len) {
        for (int half = 1; half < len; half <<= 1) {
            w[half] = 1;
            if (half == 1) continue;  // length-2 butterflies only need the twiddle 1
            const int root = fp(G, (M - 1) / (half << 1));
            w[half + 1] = root;
            for (int j = 2; j < half; ++j) w[half + j] = mul(w[half + j - 1], root);
        }
        inv2[0] = 1;
        inv2[1] = 499122177;  // inverse of 2 mod M, i.e. (M + 1) / 2
        for (int k = 2; (1 << k) <= len; ++k) inv2[k] = mul(inv2[k - 1], inv2[1]);
    }
    
    void NTT(int *a,int len){
        for (int i=0; i<len; ++i) if (i<rev[i]) swap(a[i],a[rev[i]]);
        for (int i=1; i<len; i<<=1){
            for (int j=0; j<len; j+=(i<<1)){
                int *l=a+j,*b=l+i,*ww=w+i;
                for (int k=0; k<i; ++k){
                    int y=mul(*b,*(ww++));
                    (*b)=(*l)-y;
                    (*b)+=((*b)>>31)&M;
                    ++b;
                    (*l)+=y-M;
                    (*l)+=((*l)>>31)&M;
                    ++l;
                }
            } 
        }
    } 
    
    // Inverse transform of a[0..len), in place, with len == 2^bit.
    // Reversing a[1..len-1] turns the forward transform into the unnormalized inverse;
    // every entry is then scaled by 1/len = inv2[bit].
    void INTT(int *a, int len, int bit) {
        std::reverse(a + 1, a + len);
        NTT(a, len);
        const int scale = inv2[bit];
        std::for_each(a, a + len, [scale](int &x) { x = mul(x, scale); });
    }
    
    // Product of two polynomials with coefficients mod M.
    // An empty operand is the zero polynomial and yields an empty result; previously the
    // u.size()+v.size()-1 length underflowed.
    // Small products use schoolbook multiplication. Larger ones use NTT through the global
    // scratch buffers a/b and tables rev/w, so this operator is not reentrant. The padded
    // transform length must not exceed P, the table capacity.
    poly operator *(const poly &u,const poly &v){
        if (u.empty() || v.empty()) return poly();
        const size_t outSize = u.size() + v.size() - 1;
        // Heuristic: schoolbook is cheaper while |u|*|v| stays within 30x of |u|+|v|.
        if (static_cast<ll>(u.size()) * static_cast<ll>(v.size()) <=
            static_cast<ll>(u.size() + v.size()) * 30) {
            poly ret(outSize);
            for (size_t i = 0; i < u.size(); ++i)
                for (size_t j = 0; j < v.size(); ++j)
                    ret[i + j] = add(ret[i + j], mul(u[i], v[j]));
            return ret;
        }
        // Smallest power of two len = 2^bit that holds the whole product.
        int len = 1;
        int bit = 0;
        while (static_cast<size_t>(len) < outSize) {
            len <<= 1;
            ++bit;
        }
        assert(len <= P && "NTT length exceeds the precomputed table size P");
        a = u;
        b = v;
        a.resize(len);
        b.resize(len);
        for (int i = 0; i < len; ++i) rev[i] = rev[i >> 1] >> 1 | ((i & 1) ? len >> 1 : 0);
        NTT(a.data(), len);
        NTT(b.data(), len);
        for (int i = 0; i < len; ++i) a[i] = mul(a[i], b[i]);
        INTT(a.data(), len, bit);
        a.resize(outSize);
        return a;
    }
    // Coefficient-wise sum (mod M) of u and v; the shorter operand is treated as zero-padded,
    // so the result has the length of the longer one.
    poly operator +(const poly &u,const poly &v){
        poly ret(u);
        if (ret.size() < v.size()) ret.resize(v.size());
        for (size_t i = 0; i < ret.size(); ++i)
            ret[i] = add(ret[i], i < v.size() ? v[i] : 0);
        return ret;
    }
    // In-place coefficient-wise sum u += v (mod M); u is zero-extended when v is longer.
    void operator +=(poly &u,const poly &v){
        u.resize(std::max(u.size(), v.size()));
        std::transform(v.begin(), v.end(), u.begin(), u.begin(),
                       [](int rhs, int lhs) { return add(lhs, rhs); });
    }
    
    ostream& operator <<(ostream& out,const poly &a){
        for (auto i:a) out<<i<<" ";
        return out<<endl;
    }
    
    void test(){
        poly a({1,2}),b({2,3,2333});
        a=a*b;
        cerr<<a;
    } 
    
    int n, m;                  // number of vertices / edges, read in main
    constexpr int N = 100010;  // capacity of the per-vertex arrays
    poly ans;                  // ans[d]: accumulated count for distance d, printed by Output
    namespace solve1{
    
        vector<int> e[N];
        int sz[N],tmp[N],rt;
        void Dfs(int x,int fa){
            sz[x]=1;
            for (auto i:e[x])
                if (i!=fa){
                    Dfs(i,x);
                    sz[x]+=sz[i];
                }
        }
        int calc(int y,int x){
            return max(y-sz[x],tmp[x]);
        }
        void Getrt(int x,int fa,const int totsize){
            //cerr<<"Getrt"<<x<<" "<<fa<<endl;
            tmp[x]=0;
            for (auto i:e[x])
                if (i!=fa){
                    tmp[x]=max(sz[i],tmp[x]);
                    Getrt(i,x,totsize);
                }
            if (calc(totsize,x)<calc(totsize,rt)) rt=x;
        }
        
        void Getdeep(int x,int fa,poly &a,int nowdis){
            //cerr<<"Getdeep"<<x<<" "<<fa<<endl;
            ++a[nowdis];
            for (auto i:e[x])
                if (i!=fa){
                    //cerr<<"???"<<i<<endl;
                    Getdeep(i,x,a,nowdis+1);
                }
        } 
        void df(int x){
            //int t=clock();
            Dfs(x,0);
            rt=x;
            int bbb=sz[rt];
            //cerr<<"bbb"<<bbb<<endl;
            Getrt(x,0,sz[x]);
            //cerr<<"rt"<<rt<<" "<<sz[rt]<<endl;
            //getchar();
            for (auto i:e[rt])
                if (sz[i]>sz[rt]) sz[i]=bbb-sz[rt];
            sort(e[rt].begin(),e[rt].end(),[&](int x,int y){
                return sz[x]<sz[y];
            });
            //cerr<<"???"<<endl;
            poly c,b(1,1);
            for (auto i:e[rt]){
                //cerr<<"son"<<i<<" "<<sz[i]<<endl;
                c.clear();
                c.resize(sz[i]+1);
                Getdeep(i,rt,c,1);
                //cerr<<"Gend"<<c<<endl;
                //cerr<<"mulend"<<c.size()<<endl;
                ans+=b*c;
                //cerr<<"AAAA"<<endl;
                b+=c;
            }
            //cerr<<"ans"<<ans<<endl;
            int fkrt=rt;
            for (auto i:e[fkrt]){
                e[i].erase(find(e[i].begin(),e[i].end(),fkrt));
                df(i);
            }
            //cerr<<"dend"<<endl;
        }
        
        void main(int *fa){
            for (int i=1; i<=n; ++i)
                if (fa[i]){
                    //cerr<<"faf"<<i<<" "<<fa[i]<<endl;
                    e[fa[i]].push_back(i);
                    e[i].push_back(fa[i]);
                }
            df(1);
            ans[0]=n;
        }	
    }
    int vis[N];             // visit marks: noloop sets 1, the cycle search in main sets 2
    int fa[N];              // DFS-tree parent written by noloop (0 = none)
    int k, f;               // output bound and "print all values" flag, read in main
    std::vector<int> g[N];  // adjacency lists of the input graph
    // Recursive DFS from x over g[]. Marks reached vertices with vis = 1 and records each newly
    // reached vertex's parent in fa[]. Edges to already-visited vertices are skipped, so fa[]
    // describes a spanning tree of x's component.
    void noloop(int x){
        vis[x] = 1;
        for (const int to : g[x]) {
            if (vis[to]) continue;
            fa[to] = x;
            noloop(to);
        }
    }
    // Truncates or zero-pads a to exactly k+1 entries (modifying the caller's vector).
    // Prints their sum mod M and, when f is non-zero, the entries themselves on the next line.
    void Output(poly &a,int k,int f){
        a.resize(k + 1);
        const int total = accumulate(a.begin(), a.end(), 0,
                                     [](int acc, int x) { return add(acc, x); });
        cout << total << endl;
        if (f) cout << a;
    }
    
    int main(){
        init(1<<17);
        ios::sync_with_stdio(false);
        cin.tie(nullptr);
        // Skip the first input line (presumably a test/subtask header) before reading n m k f.
        // Fixed: the delimiter is the single character '\n'; the literal had been split across two
        // lines, which does not compile.
        cin.ignore(233,'\n');
        cin>>n>>m>>k>>f;
        for (int i=1; i<=m; ++i){
            int x,y;
            cin>>x>>y;
            g[x].push_back(y);
            g[y].push_back(x);
        }
        // Count pairs by distance on the DFS spanning tree rooted at vertex 1.
        noloop(1);
        solve1::main(fa);
        if (m==n-1){
            Output(ans,k,f);
            return 0;
        }
        // Otherwise there is one extra edge, i.e. exactly one cycle.
        // NOTE(review): findloop skips every edge back to the DFS parent by vertex id, so a cycle
        // formed by two parallel edges (or a self-loop) is not found; assumes the input has neither.
        poly s;    // DFS stack; after trimming, the cycle vertices in order
        int pp=0;  // vertex reached again by a non-parent edge (0 = not found yet)
        function<void(int,int)> findloop=[&](int x,int par){
            vis[x]=2;
            s.push_back(x);
            for (auto i:g[x])
                if (i!=par){
                    if (vis[i]!=2) findloop(i,x);
                    else pp=i;
                    if (pp) return;  // cycle found: unwind without popping, keeping the stack
                }
            s.pop_back();
        };
        findloop(1,0);
        s.erase(s.begin(),find(s.begin(),s.end(),pp));
        // Cut the closing edge s.back()-s.front(); the cycle becomes the path s[0..|s|-1].
        g[s.front()].erase(find(g[s.front()].begin(),g[s.front()].end(),s.back()));
        g[s.back()].erase(find(g[s.back()].begin(),g[s.back()].end(),s.front()));
        // c[d] += number of vertices at distance d from x in x's current component (c grows as needed).
        function<void(int,int,poly&,int)> ddd=[&](int x,int from,poly &c,int dis){
            if (static_cast<size_t>(dis)>=c.size()) c.resize(dis+1);
            ++c[dis];
            for (auto j:g[x])
                if (j!=from) ddd(j,x,c,dis+1);
        };
        // u[i+len] += v[i] for every i (u grows as needed): adds v shifted right by len.
        auto Fakeadd=[&](poly &u,const poly &v,int len){
            if (u.size()<v.size()+len) u.resize(v.size()+len);
            for (size_t i=0; i<v.size(); ++i) u[i+len]=add(u[i+len],v[i]);
        };
        // Number of path edges between cycle positions x <= y.
        auto waylength=[&](int x,int y){
            return y-x;
        };
        // solve(l, r, nowlen): s[l..r] is still a connected path. nowlen is the length of the route
        // from s[r] back to s[l] through cycle edges that were already cut.
        // Cutting s[mid]-s[mid+1] separates the component of s[l] from that of s[r]. Their depth
        // histograms are convolved and added at offset nowlen; then each half recurses with its own
        // outer-route length.
        function<void(int,int,int)> solve=[&](int l,int r,int nowlen){
            if (l==r) return;
            int mid=(l+r)>>1;
            g[s[mid]].erase(find(g[s[mid]].begin(),g[s[mid]].end(),s[mid+1]));
            g[s[mid+1]].erase(find(g[s[mid+1]].begin(),g[s[mid+1]].end(),s[mid]));
            poly c,d;
            ddd(s[l],0,c,0);
            ddd(s[r],0,d,0);
            Fakeadd(ans,c*d,nowlen);
            solve(l,mid,waylength(mid,r)+nowlen);
            solve(mid+1,r,waylength(l,mid+1)+nowlen);
        };
        solve(0,s.size()-1,1);
        // Every cycle edge is cut now. Add the contribution of going once around the whole cycle
        // (length s.size()): each non-cycle vertex of every hanging tree at its depth, plus 1 at
        // exactly s.size().
        for (auto i:s){
            poly c;
            ddd(i,0,c,0);
            c[0]=0;
            Fakeadd(ans,c,s.size());
        }
        ans[s.size()]=add(ans[s.size()],1);
        Output(ans,k,f);
    }
    
  • 相关阅读:
    java 8新特性 匿名内部类的使用
    java 8新特性
    jmeter 性能测试
    idea 背景颜色设置
    SpringBoot yaml的配置及使用
    idea 类图显示
    SpringSecurity 获取认证信息 和 认证实现
    MySQL-慢查询日志
    微信小程序领取卡券
    ThinkPhp5-PHPExcel导出|导入 数据
  • 原文地址:https://www.cnblogs.com/Yuhuger/p/10621956.html
Copyright © 2011-2022 走看看