zoukankan      html  css  js  c++  java
  • HDU4670 Cube number on a tree 树分治

        人生的第一道树分治,要是早点学我南京赛就不用那么挫了,树分治的思路其实很简单,就是对子树找到一个重心(Centroid),实现重心分解,然后递归的解决分开后的树的子问题,关键是合并,当要合并跨过重心的两棵子树的时候,需要有一个接近O(n)的方法,因为f(n)=kf(n/k)+O(n)解出来才是O(nlogn).在这个题目里其实就是将第一棵子树的集合里的每个元素,判下有没符合条件的,有就加上,然后将子树集合压进大集合,然后继续搞第二棵乃至第n棵.我的过程用了map,合并是nlogn的所以代码速度颇慢,大概6s,题目时限10s,可以改成hash应该会快许多,毕竟用map实在太慢,用vector也可以,具体可以参见挑战程序设计竞赛代码.下面的代码查找重心用了挑战的代码.

    #pragma comment(linker, "/STACK:102400000,102400000")
    #include<iostream>
    #include<cstring>
    #include<string>
    #include<cstdio>
    #include<algorithm>
    #include<map>
    #include<vector>
    #define maxv 50000
    #define ll long long
    using namespace std;
    
    int n,k;
    vector<int> G[maxv+50];
    ll val[maxv+50];
    ll prime[maxv+50];
    ll convert_three(ll v)
    {
        ll bas=1;ll res=0;
        for(int i=0;i<k;++i){
            int num=0;
            while(v%prime[i]==0){
                v/=prime[i];
                num++;
            }
            num%=3;res+=num*bas;
            bas*=3;
        }
        return res;
    }
    
    ll xor(ll x,ll y)
    {
        ll res=0;ll bas=1;
        for(int i=0;i<k;++i){
            res+=((x%3)+(y%3))%3*bas;
            x/=3;y/=3;
            bas*=3;
        }
        return res;
    }
    
    ll inv(ll x)
    {
        ll res=0;ll bas=1;
        for(int i=0;i<k;++i){
            res+=((3-(x%3))%3)*bas;
            x/=3;
            bas*=3;
        }
        return res;
    }
    
    void print(ll x){
        while(x){
            cout<<x%3;
            x/=3;
        }
        cout<<endl;
    }
    
    bool centroid[maxv+50];
    int ssize[maxv+50];
    int ans;
    
    map<ll,int> sta;
    map<ll,int>::iterator it;
    int compute_ssize(int v,int p)
    {
        int c=1;
        for(int i=0;i<G[v].size();++i){
            int w=G[v][i];
            if(w==p||centroid[w]) continue;
            c+=compute_ssize(G[v][i],v);
        }
        ssize[v]=c;
        return c;
    }
    
    pair<int,int> search_centroid(int v,int p,int t)
    {
        pair<int,int> res=make_pair(INT_MAX,-1);
        int s=1,m=0;
        for(int i=0;i<G[v].size();++i){
            int w=G[v][i];
            if(w==p||centroid[w]) continue;
            res=min(res,search_centroid(w,v,t));
            m=max(m,ssize[w]);
            s+=ssize[w];
        }
        m=max(m,t-s);
        res=min(res,make_pair(m,v));
        return res;
    }
    
    void enumerate_mul(int v,int p,ll d,map<ll,int> &ds)
    {
        if(ds.count(d)) ds[d]++;
        else ds[d]=1;
        for(int i=0;i<G[v].size();++i){
            int w=G[v][i];
            if(w==p||centroid[w]) continue;
            enumerate_mul(w,v,xor(d,val[w]),ds);
        }
    }
    
    void solve(int v)
    {
        compute_ssize(v,-1);
        int s=search_centroid(v,-1,ssize[v]).second;
        centroid[s]=true;
        for(int i=0;i<G[s].size();++i){
            if(centroid[G[s][i]]) continue;
            solve(G[s][i]);
        }
        sta.clear();
        sta[val[s]]=1;map<ll,int> tds;
        for(int i=0;i<G[s].size();++i){
            if(centroid[G[s][i]]) continue;
            tds.clear();
            enumerate_mul(G[s][i],s,val[G[s][i]],tds);
            it=tds.begin();
            while(it!=tds.end()){
                ll rev=inv((*it).first);
                if(sta.count(rev)){
                    ans+=sta[rev]*(*it).second;
                }
                ++it;
            }
            it=tds.begin();
            while(it!=tds.end()){
                ll  vv=xor((*it).first,val[s]);
                if(sta.count(vv)){
                    sta[vv]+=(*it).second;
                }
                else{
                    sta[vv]=(*it).second;
                }
                ++it;
            }
        }
        centroid[s]=false;
    }
    
    int main()
    {
        while(cin>>n>>k){
            ans=0;
            for(int i=0;i<k;++i){
                scanf("%I64d",&prime[i]);
            }
            G[0].clear();
            for(int i=1;i<=n;++i){
                scanf("%I64d",&val[i]);
                val[i]=convert_three(val[i]);
                if(val[i]==0) ans++;
                //print(val[i]);
                G[i].clear();
            }
            int u,v;
            for(int i=0;i<n-1;++i){
                scanf("%d%d",&u,&v);
                G[u].push_back(v);
                G[v].push_back(u);
            }
            memset(centroid,0,sizeof(centroid));
            solve(1);
            printf("%d
    ",ans);
        }
        return 0;
    }
    
  • 相关阅读:
    ajax的原理及实现方式
    在linux中添加环境变量
    ftp简单命令
    linux命令之scp
    java中创建对象的方法
    10个调试技巧
    java读取.properties配置文件的几种方法
    Java对象和XML转换
    Java Float类型 减法运算时精度丢失问题
    Java内存分配全面浅析
  • 原文地址:https://www.cnblogs.com/chanme/p/3411639.html
Copyright © 2011-2022 走看看