zoukankan      html  css  js  c++  java
  • Contest 高数题 樹的點分治 樹形DP

    高数题

    HJA最近在刷高数题,他遇到了这样一道高数题。这道高数题里面有一棵N个点的树,树上每个点有点权,每条边有颜色。一条路径的权值是这条路径上所有点的点权和,一条合法的路径需要满足该路径上任意相邻的两条边颜色都不相同。问这棵树上所有合法路径的权值和是多少

    输入第一行一个整数N,代表树上有多少个点。
    接下来一行N个整数,代表树上每个点的权值。
    接下来N-1行,每行三个整数S、E、C,代表S与E之间有一条颜色为C的边。输出一行一个整数,代表所求的值。样例输入

    6
    6 2 3 7 1 4
    1 2 1
    1 3 2
    1 4 3
    2 5 1
    2 6 2

    样例输出

    134

    提示

    对与30%的数据,1≤N≤1000。 
    对于另外20%的数据,可用的颜色数不超过109且随机数据。
    对于另外20%的数据,树的形态为一条链。
    对于100%的数据,1≤N≤3*105,可用的颜色数不超过109,所有点权的大小不超过105。

    這道題簡單的說是一個樹形DP,由下至上分別統計路徑條數,而考場上我想都沒想就開始寫樹的點分治,當天狀態不佳,加之本身的不熟練,沒能按時把點分治寫出來。

    唯一需要注意的是n=3*10^5,dfs足以使程序崩掉,以後最好改寫bfs

    #include<iostream>
    #include<cstdio>
    #include<cstdlib>
    #include<cstring>
    #include<ctime>
    #include<cmath>
    #include<algorithm>
    #include<set>
    #include<map>
    #include<vector>
    #include<string>
    #include<queue>
    #include<stack>
    using namespace std;
    #ifdef WIN32
    #define LL "%I64d"
    #else
    #define LL "%lld"
    #endif
    #define MAXN 410000
    #define MAXV MAXN*2
    #define MAXE MAXV*2
    #define INF 0x3f3f3f3f
    #define PROB "gaoshu"
    typedef long long qword;
    int nextInt()
    {
            char ch;
            int x=0;
            while (ch=getchar(),ch < '0' || ch > '9' );
            //        cout<<ch;
            do
                    x=x*10+ch-'0';
            while (ch=getchar(),ch<='9' && ch>='0');
            return x;
    }
    int n;
    struct Edge
    {
            int np,col;
            int val;
            pair<qword,qword> val2;
            Edge *next,*neg;
            int disable;
    }E[MAXE],*V[MAXV];
    qword val[MAXN];
    int tope=-1;
    int bad[MAXN][3];
    int size[MAXN];
    //int fa[MAXN],depth[MAXN];
    //int jump[20][MAXN];
    
    void addedge(int x,int y,int z)
    {
    //        cout<<"Add:"<<x<<" "<<y<<endl;
            E[++tope].np=y;
            E[tope].col=z;
            E[tope].next=V[x];
            V[x]=&E[tope];
    
            E[++tope].np=x;
            E[tope].col=z;
            E[tope].next=V[y];
            V[y]=&E[tope];
    
            E[tope].neg=&E[tope-1];
            E[tope-1].neg=&E[tope];
    }
    /*
    void dfs1(int now,int d)
    {
            size[now]=1;depth[now]=d;
            Edge *ne;
            for (ne=V[now];ne;ne=ne->next)
            {
                    if (ne->np==fa[now])continue;
                    fa[ne->np]=now;
                    dfs1(ne->np,d+1);
                    size[now]+=size[ne->np];
            }
    }
    void init_lca()
    {
            int i,j;
            for (i=1;i<=n;i++)
            {
                    jump[0][i]=fa[i];
            }
            for (j=1;j<20;j++)
            {
                    for (i=1;i<=n;i++)
                    {
                            jump[j][i]=jump[j-1][jump[j-1][i]];
                    }
            }
    }
    void swim(int &now,int len)
    {
            int i=0;
            while (len)
            {
                    if (len&1)now=jump[i][now];
                    i++;
            }
    }
    int lca(int x,int y)
    {
            if (depth[x]>depth[y])
            {
                    swim(x,depth[x]-depth[y]);
            }else
            {
                    swim(y,depth[y]-depth[x]);
            }
            int i;
            if (x==y)return x;
            for (i=19;i>=0;i--)
            {
                    if (jump[i][x]!=jump[i][y])
                    {
                            x=jump[i][x];
                            y=jump[i][y];
                    }
            }
            return fa[x];
    }*/
    int bcore,vcore;
    int size2[MAXN];
    int get_core_sizt;
    int gc_sizf[MAXN];
    int gc_mxsz[MAXN];
    int gc_col[MAXN];
    int gc_now[MAXN];
    Edge *nel[MAXN];
    int get_core(int dep=1)
    {
            gc_sizf[dep]=get_core_sizt-1;
            gc_mxsz[dep]=0;
            size2[gc_now[dep]]=1;
            for (nel[dep]=V[gc_now[dep]];nel[dep];nel[dep]=nel[dep]->next)
            {
                    if (nel[dep]->np==gc_now[dep-1] || nel[dep]->disable)continue;
                    gc_col[dep+1]=nel[dep]->col;
                    gc_now[dep+1]=nel[dep]->np;
                    get_core(dep+1);
                    size2[gc_now[dep]]+=size2[nel[dep]->np];
                    gc_sizf[dep]-=size2[nel[dep]->np];
                    gc_mxsz[dep]=max(gc_mxsz[dep],size2[nel[dep]->np]);
            }
            gc_mxsz[dep]=max(gc_mxsz[dep],gc_sizf[dep]);
            if (gc_mxsz[dep]<vcore)
            {
                    vcore=gc_mxsz[dep];
                    bcore=gc_now[dep];
            }
    }
    pair<qword,qword> dfs2_ret[MAXN],dfs2_tt[MAXN];
    int dfs2_now[MAXN],dfs2_col[MAXN];
    pair<qword,qword> dfs2(int dep=1)
    {
    //        pair<qword,qword> ret,tt;
            dfs2_ret[dep]=make_pair(val[dfs2_now[dep]],1);
            for (nel[dep]=V[dfs2_now[dep]];nel[dep];nel[dep]=nel[dep]->next)
            {
                    if (nel[dep]->np==dfs2_now[dep-1] || nel[dep]->disable)continue;
                    dfs2_now[dep+1]=nel[dep]->np;
                    dfs2_col[dep+1]=nel[dep]->col;
                    dfs2_tt[dep]=dfs2(dep+1);
                    dfs2_ret[dep].second+=dfs2_tt[dep].second*(nel[dep]->col!=dfs2_col[dep]);
                    dfs2_ret[dep].first+=(dfs2_tt[dep].first+dfs2_tt[dep].second*val[dfs2_now[dep]]) *(nel[dep]->col!=dfs2_col[dep]);
            }
            return dfs2_ret[dep];
    }
    qword solve(int root,int siz)
    {
            if (siz==1)return 0;
            vcore=INF;
            gc_now[0]=root;
            gc_col[1]=INF;
            gc_now[1]=root;
            get_core_sizt=siz;
            get_core(1);
            int core=bcore;
            gc_now[0]=core;
            gc_now[1]=core;
            get_core(1);
            qword ans=0;
            Edge *ne;
            for (ne=V[core];ne;ne=ne->next)
            {
                    if (ne->disable)continue;
                    ne->disable=core;
                    ne->neg->disable=core;
                    ne->val=size[ne->np];
            }
            for (ne=V[core];ne;ne=ne->next)
            {
                    if (ne->disable!=core)continue;
                    ans+=solve(ne->np,size2[ne->np]);
            }
            for (ne=V[core];ne;ne=ne->next)
            {
                    if (ne->disable!=core)continue;
                    dfs2_now[0]=dfs2_now[1]=ne->np;
                    dfs2_col[1]=ne->col;
                    ne->val2=dfs2(1);
            }
            Edge *ne2;
            int t=0;
            map<int,qword> mp;
            pair<qword,qword> sum=make_pair(0,0);
            for (ne=V[core];ne;ne=ne->next)
            {
                    if (ne->disable!=core)continue;    
                    sum.first+=ne->val2.first;
                    sum.second+=ne->val2.second;
                    mp[ne->col]+=ne->val2.second;
            }
            qword ans2=0;
            for (ne=V[core];ne;ne=ne->next)
            {
                    if (ne->disable!=core)continue;
                    ans+=ne->val2.first*(sum.second-mp[ne->col]);//分居兩邊
                    ans2+=val[core]*ne->val2.second*(sum.second-mp[ne->col]);//分局兩邊,中心貢獻
            }
            ans+=ans2/2;
            for (ne=V[core];ne;ne=ne->next)
            {
                    if (ne->disable!=core)continue;
                    t+=ne->val2.second;//中心出發條數
                    ans+=ne->val2.first;//中心出發,外點貢獻
            }
            ans+=val[core]*t;
            for (ne=V[core];ne;ne=ne->next)
            {
                    if (ne->disable==core)
                            ne->disable=ne->neg->disable=0;
            }
            return ans;
    }
    int main()
    {
            //freopen("input.txt","r",stdin);
            //freopen("output.txt","w",stdout);
            freopen(PROB".in","r",stdin);
            freopen(PROB".out","w",stdout);
            int i,j,k;
            int x,y,z;
            //scanf("%d",&n);
            n=nextInt();
            for (i=1;i<=n;i++)
                    val[i]=nextInt();//scanf("%d",&val[i]);
            for (i=1;i<n;i++)
            {
                    //scanf("%d%d%d",&x,&y,&z);
                    x=nextInt();
                    y=nextInt();
                    z=nextInt();
                    addedge(x,y,z);
            }
    //        fa[1]=1;
    //        dfs1(1,0);
    //        init_lca();
            qword ans=solve(1,n);
            printf(LL "
    ",ans);
            return 0;
    
    }
    by mhy12345(http://www.cnblogs.com/mhy12345/) 未经允许请勿转载

    本博客已停用,新博客地址:http://mhy12345.xyz

  • 相关阅读:
    bzoj3675 [Apio2014]序列分割
    bzoj4010 [HNOI2015]菜肴制作
    bzoj4011 [HNOI2015]落忆枫音
    bzoj100题
    JSP—内置对象
    集合框架—常用的map集合
    集合框架—HashMap
    集合框架—代码—用各种集合进行排序
    集合框架—2种排序比较器
    array
  • 原文地址:https://www.cnblogs.com/mhy12345/p/4001131.html
Copyright © 2011-2022 走看看