高数题
HJA最近在刷高数题,他遇到了这样一道高数题。这道高数题里面有一棵N个点的树,树上每个点有点权,每条边有颜色。一条路径的权值是这条路径上所有点的点权和,一条合法的路径需要满足该路径上任意相邻的两条边颜色都不相同。问这棵树上所有合法路径的权值和是多少
输入第一行一个整数N,代表树上有多少个点。
接下来一行N个整数,代表树上每个点的权值。
接下来N-1行,每行三个整数S、E、C,代表S与E之间有一条颜色为C的边。输出一行一个整数,代表所求的值。样例输入
6 6 2 3 7 1 4 1 2 1 1 3 2 1 4 3 2 5 1 2 6 2
样例输出
134
提示
对与30%的数据,1≤N≤1000。
对于另外20%的数据,可用的颜色数不超过109且随机数据。
对于另外20%的数据,树的形态为一条链。
对于100%的数据,1≤N≤3*105,可用的颜色数不超过109,所有点权的大小不超过105。
這道題簡單的說是一個樹形DP,由下至上分別統計路徑條數,而考場上我想都沒想就開始寫樹的點分治,當天狀態不佳,加之本身的不熟練,沒能按時把點分治寫出來。
唯一需要注意的是n=3*10^5,dfs足以使程序崩掉,以後最好改寫bfs
#include<iostream> #include<cstdio> #include<cstdlib> #include<cstring> #include<ctime> #include<cmath> #include<algorithm> #include<set> #include<map> #include<vector> #include<string> #include<queue> #include<stack> using namespace std; #ifdef WIN32 #define LL "%I64d" #else #define LL "%lld" #endif #define MAXN 410000 #define MAXV MAXN*2 #define MAXE MAXV*2 #define INF 0x3f3f3f3f #define PROB "gaoshu" typedef long long qword; int nextInt() { char ch; int x=0; while (ch=getchar(),ch < '0' || ch > '9' ); // cout<<ch; do x=x*10+ch-'0'; while (ch=getchar(),ch<='9' && ch>='0'); return x; } int n; struct Edge { int np,col; int val; pair<qword,qword> val2; Edge *next,*neg; int disable; }E[MAXE],*V[MAXV]; qword val[MAXN]; int tope=-1; int bad[MAXN][3]; int size[MAXN]; //int fa[MAXN],depth[MAXN]; //int jump[20][MAXN]; void addedge(int x,int y,int z) { // cout<<"Add:"<<x<<" "<<y<<endl; E[++tope].np=y; E[tope].col=z; E[tope].next=V[x]; V[x]=&E[tope]; E[++tope].np=x; E[tope].col=z; E[tope].next=V[y]; V[y]=&E[tope]; E[tope].neg=&E[tope-1]; E[tope-1].neg=&E[tope]; } /* void dfs1(int now,int d) { size[now]=1;depth[now]=d; Edge *ne; for (ne=V[now];ne;ne=ne->next) { if (ne->np==fa[now])continue; fa[ne->np]=now; dfs1(ne->np,d+1); size[now]+=size[ne->np]; } } void init_lca() { int i,j; for (i=1;i<=n;i++) { jump[0][i]=fa[i]; } for (j=1;j<20;j++) { for (i=1;i<=n;i++) { jump[j][i]=jump[j-1][jump[j-1][i]]; } } } void swim(int &now,int len) { int i=0; while (len) { if (len&1)now=jump[i][now]; i++; } } int lca(int x,int y) { if (depth[x]>depth[y]) { swim(x,depth[x]-depth[y]); }else { swim(y,depth[y]-depth[x]); } int i; if (x==y)return x; for (i=19;i>=0;i--) { if (jump[i][x]!=jump[i][y]) { x=jump[i][x]; y=jump[i][y]; } } return fa[x]; }*/ int bcore,vcore; int size2[MAXN]; int get_core_sizt; int gc_sizf[MAXN]; int gc_mxsz[MAXN]; int gc_col[MAXN]; int gc_now[MAXN]; Edge *nel[MAXN]; int get_core(int dep=1) { gc_sizf[dep]=get_core_sizt-1; gc_mxsz[dep]=0; size2[gc_now[dep]]=1; for (nel[dep]=V[gc_now[dep]];nel[dep];nel[dep]=nel[dep]->next) { if (nel[dep]->np==gc_now[dep-1] || nel[dep]->disable)continue; gc_col[dep+1]=nel[dep]->col; gc_now[dep+1]=nel[dep]->np; get_core(dep+1); size2[gc_now[dep]]+=size2[nel[dep]->np]; gc_sizf[dep]-=size2[nel[dep]->np]; gc_mxsz[dep]=max(gc_mxsz[dep],size2[nel[dep]->np]); } gc_mxsz[dep]=max(gc_mxsz[dep],gc_sizf[dep]); if (gc_mxsz[dep]<vcore) { vcore=gc_mxsz[dep]; bcore=gc_now[dep]; } } pair<qword,qword> dfs2_ret[MAXN],dfs2_tt[MAXN]; int dfs2_now[MAXN],dfs2_col[MAXN]; pair<qword,qword> dfs2(int dep=1) { // pair<qword,qword> ret,tt; dfs2_ret[dep]=make_pair(val[dfs2_now[dep]],1); for (nel[dep]=V[dfs2_now[dep]];nel[dep];nel[dep]=nel[dep]->next) { if (nel[dep]->np==dfs2_now[dep-1] || nel[dep]->disable)continue; dfs2_now[dep+1]=nel[dep]->np; dfs2_col[dep+1]=nel[dep]->col; dfs2_tt[dep]=dfs2(dep+1); dfs2_ret[dep].second+=dfs2_tt[dep].second*(nel[dep]->col!=dfs2_col[dep]); dfs2_ret[dep].first+=(dfs2_tt[dep].first+dfs2_tt[dep].second*val[dfs2_now[dep]]) *(nel[dep]->col!=dfs2_col[dep]); } return dfs2_ret[dep]; } qword solve(int root,int siz) { if (siz==1)return 0; vcore=INF; gc_now[0]=root; gc_col[1]=INF; gc_now[1]=root; get_core_sizt=siz; get_core(1); int core=bcore; gc_now[0]=core; gc_now[1]=core; get_core(1); qword ans=0; Edge *ne; for (ne=V[core];ne;ne=ne->next) { if (ne->disable)continue; ne->disable=core; ne->neg->disable=core; ne->val=size[ne->np]; } for (ne=V[core];ne;ne=ne->next) { if (ne->disable!=core)continue; ans+=solve(ne->np,size2[ne->np]); } for (ne=V[core];ne;ne=ne->next) { if (ne->disable!=core)continue; dfs2_now[0]=dfs2_now[1]=ne->np; dfs2_col[1]=ne->col; ne->val2=dfs2(1); } Edge *ne2; int t=0; map<int,qword> mp; pair<qword,qword> sum=make_pair(0,0); for (ne=V[core];ne;ne=ne->next) { if (ne->disable!=core)continue; sum.first+=ne->val2.first; sum.second+=ne->val2.second; mp[ne->col]+=ne->val2.second; } qword ans2=0; for (ne=V[core];ne;ne=ne->next) { if (ne->disable!=core)continue; ans+=ne->val2.first*(sum.second-mp[ne->col]);//分居兩邊 ans2+=val[core]*ne->val2.second*(sum.second-mp[ne->col]);//分局兩邊,中心貢獻 } ans+=ans2/2; for (ne=V[core];ne;ne=ne->next) { if (ne->disable!=core)continue; t+=ne->val2.second;//中心出發條數 ans+=ne->val2.first;//中心出發,外點貢獻 } ans+=val[core]*t; for (ne=V[core];ne;ne=ne->next) { if (ne->disable==core) ne->disable=ne->neg->disable=0; } return ans; } int main() { //freopen("input.txt","r",stdin); //freopen("output.txt","w",stdout); freopen(PROB".in","r",stdin); freopen(PROB".out","w",stdout); int i,j,k; int x,y,z; //scanf("%d",&n); n=nextInt(); for (i=1;i<=n;i++) val[i]=nextInt();//scanf("%d",&val[i]); for (i=1;i<n;i++) { //scanf("%d%d%d",&x,&y,&z); x=nextInt(); y=nextInt(); z=nextInt(); addedge(x,y,z); } // fa[1]=1; // dfs1(1,0); // init_lca(); qword ans=solve(1,n); printf(LL " ",ans); return 0; }