Noip 2014 senior 复赛 联合权值(link)
【问题描述】
无向连通图G有n个点,n-1条边。点从1到n依次编号,编号为i的点的权值为Wi,每条边的长度均为 1。图上两点(u, v)的距离定义为u点到v点的最短距离。对于图 G 上的点对(u, v),若它们的距离为2,则它们之间会产生的联合权值。
请问图G上所有可产生联合权值的有序点对中,联合权值最大的是多少?所有联合权值之和是多少?
【输入】
输入文件名为 link.in。
第一行包含1个整数 n。
接下来n-1行,每行包含2个用空格隔开的正整数u、v,表示编号为u和编号为v的点之间有边相连。
最后1行,包含n个正整数,每两个正整数之间用一个空格隔开,其中第i个整数表示图 G 上编号为 i 的点的权值为 Wi。
【输出】
输出文件名为 link.out。
输出共 1 行,包含 2 个整数, 之间用一个空格隔开, 依次为图 G 上联合权值的最大值
和所有联合权值之和。 由于所有联合权值之和可能很大,输出它时要对 10007 取余。
这道题由于模型是在树上,所以我们想到使用树形DP。定义f(i)表示以i为根的子树中的点对权值乘积之和,g(i)表示以i为根的子树中的点对权值乘积的最大值,sums1(i)表示i结点的儿子的权值之和,sums2(i)表示i结点的儿子的权值平方之和,sumg(i)表示i结点的孙子的权值之和,max1(i)表示结点i儿子权值的最大值,max2(i)表示结点i儿子权值的次大值,max3(i)表示结点i孙子权值的最大值。则有:
所以求解这个问题就比较清晰了。
还有就是在一个较优算法代码中(枚举,因为没有透彻理解我就不做过多解释了,只给出代码)发现了使用连续4次异或等于( ^= ),通过李兄的详细讲解知道这是swap的意思,只不过不是很稳定。
算了,还是给出枚举的大致思路吧。
维护max1(i),max2(i)分别表示与i相连结点最大权值和次大权值,sum1(i)表示与i相连结点权值之和,sum2(i)表示与i相连结点权值平方和。任意长度为2的点对(i,j)有一个中间结点k,中间结点为k的最大权值之积为max1(k)*max2(k),与k相连结点为x1,x2...xm,则权值乘积之和 W_x1*(W_x2+W_x3+…+W_xm)+W_x2*(W_x1+W_x3+…+W_xm)+…+W_xm*(W_x1+W_x2+…+W_xm-1)=sum1(k)*sum1(k)-sum2(k)枚举k,更新答案即可。时间复杂度为O(n),。
#include<iostream> //dynamic programming #include<cstdio> #include<algorithm> #include<vector> #define pi acos(-1) #define inf 0x7fffffff using namespace std; #define N 200010 #define mod 10007 inline long long read(){ long long data=0,w=1; char ch=0; while(ch!='-' && (ch<'0' || ch>'9')) ch=getchar(); if(ch=='-') w=-1,ch=getchar(); while(ch>='0' && ch<='9') data=data*10+ch-'0',ch=getchar(); return data*w; } inline void write(long long x) { if(x<0) putchar('-'),x=-x; if(x>9) write(x/10); putchar(x%10+'0'); } int n; long long w[N],maxn[N],sum[N]; long long ans_max[N],ans_sum[N]; vector <int> e[N]; inline void Dynamic_Programming(int u,int father) { for(int i=0;i<e[u].size();i++) { int v = e[u][i]; if(v == father) continue; Dynamic_Programming(v,u); maxn[u] = max(maxn[u],w[v]); sum[u] = (sum[u] + w[v]) % mod; ans_max[u] = max(ans_max[u],w[u] * maxn[v]); ans_max[u] = max(ans_max[u],ans_max[v]); ans_sum[u] = (ans_sum[u] + w[u] * sum[v]) % mod; ans_sum[u] = (ans_sum[u] + ans_sum[v]) % mod; } long long aider1 = 0, aider2 = 0; for(int i=0;i<e[u].size();i++) { int v = e[u][i]; if(v == father) continue; ans_max[u] = max(ans_max[u],aider1 * w[v]); ans_sum[u] = (ans_sum[u] + aider2 * w[v]) % mod; aider1 = max(aider1,w[v]); aider2 = (aider2 + w[v]) % mod; } } int main() { n = read(); for(int i=1;i<n;i++) { int u = read(), v = read(); e[u].push_back(v); e[v].push_back(u); } for(int i=1;i<=n;i++) { w[i] = read(); w[i] %= mod; } Dynamic_Programming(1,0); write(ans_max[1]); putchar(' '); write(ans_sum[1] * 2 % mod); return 0; }
#include<iostream> //enumerate #include<cstdio> using namespace std; #define mod 10007 #define N 200010 inline int read(){ int data=0,w=1; char ch=0; while(ch!='-' && (ch<'0' || ch>'9')) ch=getchar(); if(ch=='-') w=-1,ch=getchar(); while(ch>='0' && ch<='9') data=data*10+ch-'0',ch=getchar(); return data*w; } inline void write(int x) { if(x<0) putchar('-'),x=-x; if(x>9) write(x/10); putchar(x%10+'0'); } int n,u,v; int d[N]; struct Edge{ int v,next; }edge[N << 1]; int cnt,head[N]; inline void addedge(int u,int v) { edge[++cnt].v = v; edge[cnt].next = head[u]; head[u] = cnt; } int pre[N],father[N]; int ans_max,ans_sum; int visit[N]; void dfs(int u) { visit[u] = 1; int sum = 0,maxn[3] = {0,0,0}; for(int i=head[u];i;i=edge[i].next) { if(!visit[edge[i].v]) { dfs(edge[i].v); father[edge[i].v] = u; sum = (sum + d[edge[i].v]) % mod; ans_sum = (ans_sum - d[edge[i].v] * d[edge[i].v]) % mod; if(maxn[0] != 2) maxn[0]++; if(maxn[1] < d[edge[i].v]) maxn[1] = d[edge[i].v]; if(maxn[2] < maxn[1]) swap(maxn[1],maxn[2]); } } ans_sum = (ans_sum + sum * sum) % mod; if(maxn[0] == 2) ans_max = max(ans_max,maxn[1] * maxn[2]); } void work() { dfs(1); for(int i=1;i<=n;i++) pre[i] = father[father[i]]; for(int i=1;i<=n;i++) { if(pre[i]) { int aider = d[i] * d[pre[i]]; ans_max = max(ans_max,aider); ans_sum = (ans_sum + aider * 2) % mod; } } if(ans_sum < 0) ans_sum += mod; write(ans_max); putchar(' '); write(ans_sum); } int main() { // freopen("link.in","r",stdin); // freopen("link.out","w",stdout); n = read(); // scanf("%d",&n); for(int i=1;i<n;i++) { u = read(); v = read(); // scanf("%d%d",&u,&v); addedge(u,v); addedge(v,u); } for(int i=1;i<=n;i++) /*scanf("%d",&d[i]);*/ d[i] = read(); work(); return 0; }