老了…稍微麻烦一点的树形DP都想不到了。
题目描述
给定一棵树,边权是整数 (c_i) ,找出两条不相交的链(没有公共点),使得链长的乘积最大(链长定义为这条链上所有边的权值之和,如果这条链只有 (1) 个点则链长视为 (0))。
输入输出格式
输入格式:
第一行:一个 (n) 表示节点个数。
接下来 (n-1) 行每行三个整数 (u,v,c) 表示 (u,v) 之间有一条 (c) 的边。输出格式:
输出一个整数表示最大的乘积。
输入输出样例
输入样例:
5 1 2 8 2 3 -4 3 4 9 2 5 9
输出样例:
153
说明
(2le nle 4 imes 10^5,|c_i|le 10^9)
51nod 支持
int128
,用法:
__int128 n;
边权可能为负。
题解:
题目要求找出两条互不相交的链,使得两条链长乘积最大。
考虑到负负得正,因此需要分别找最长的两条链和最短的两条链。
trick:对于乘积最大可能产生贡献的分别是最大的两个值和最小的两个值的乘积。
考虑 DP 求直径的过程,对于两条不相交的链,我们同样可以用树形 DP,需要在每个点合并并统计答案。
首先 DFS 一遍,求出 (f[x],g[x]),同时,用带下划线的变量存相对最小值。则 (f[x],f\_[x]) 分别表示 (x) 的子树中连接到 (x) 的最长链和最短链(权值最小);(g[x],g\_[x]) 分别表示 (x) 的子树中最长链和最短链。
注意:在分析过程中只讨论最大值。编程时同时更新最小值即可。
然后我们进行“合并”操作的 DP。
考虑当前在 (x) 点,准备进入子树 (A),那么有以下几种情况:
针对上图而言,我们分析 (x o A) 这条边。认为产生答案的两条链不经过这条边,那么两条链就分别在 (A) 和 (Bcup Ccup {x}cup R) 中产生。
由于我们在第一次 DFS 中已经预处理过 (g[r_A]) 了((r_A) 表示子树 (A) 的根),因此我们需要计算另一半集合贡献的答案 (G)。
因此另一半集合产生的答案可能是形如 (f[r_B]+f[r_C]) 这样拼起来的;也有可能是现成的,形如 (g[r_B])。
此时我们需要求除了 (A) 以外的信息。根据贪心我们知道一定要选最大的两个(或最小的两个),但是不能选即将进去的 (A) 。如果我们遍历 (x) 除了 (A) 的子树的话,复杂度又不对了。因此我们需要对每一个 (x) 连出去的边维护前三大的 (f) 和前两大的 (g)。
另外,我发现了一种维护它们的方便做法,常数是 (10) 左右。假设我们用 (s[4]) 存下前三大值,那么每次更新时把要更新的值放在 (s[3]) 的位置,然后执行
std::sort(s,s+4)
,此时前三大值就在 (s[0]sim s[2]) 里了。当然这是一种代码量小的写法,可能会有一些常数更优秀的做法。
但是 (f[fa_x],g[fa_x](fa_xin R)) 等信息是没有处理的,因为它不是子树,那么我们就需要在第二次 dfs 的过程中在线更新 (x) 父亲方向的答案。
每次进入一棵新的子树时,(g[x]) 可以更新为子树外的答案,即上文提到的 (G)。而 (f[x]) 维护的是一条链,那么只能从子树外的 (f) 更新过来。
但是由于 (nle 400 000),容易爆栈,所以需要一点卡空间的姿势。或者从 (frac n2) 开始 dfs 可以解决某些玄学问题。
而且说起来简单,由于最大值最小值都要维护,代码量还是很大的。
时间复杂度 (O(n))。
Code:
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
#define lll __int128
#define maintain() sort(s[x],s[x]+4,cmp),sort(s_[x],s_[x]+4),sort(t[x],t[x]+3,cmp),sort(t_[x],t_[x]+3);
using std::sort;
bool cmp(ll x,ll y){return x>y;}
ll Max(ll x,ll y){return x>y?x:y;}
ll Min(ll x,ll y){return x<y?x:y;}
struct edge
{
int n,nxt,v;
edge(int n,int nxt,int v)
{
this->n=n;
this->nxt=nxt;
this->v=v;
}
edge(){}
}e[800100];
int head[400100],ecnt=-1;
void add(int from,int to,int v)
{
e[++ecnt]=edge(to,head[from],v);
head[from]=ecnt;
e[++ecnt]=edge(from,head[to],v);
head[to]=ecnt;
}
ll f[400100],g[400100];
ll f_[400100],g_[400100];
//第一次 dfs 求出 f,g
//f[x] 表示 x 子树中伸出来的最大值
//g[x] 表示 x 树内的最大值
//_ 表示最小值
void Dfs(int x,int from)
{
for(int i=head[x];~i;i=e[i].nxt)
if(e[i].n!=from)
{
Dfs(e[i].n,x);
ll p=f[e[i].n]+e[i].v,p_=f_[e[i].n]+e[i].v;
g[x]=Max(Max(g[x],f[x]+p),g[e[i].n]);
g_[x]=Min(Min(g_[x],f_[x]+p_),g_[e[i].n]);
f[x]=Max(f[x],p);
f_[x]=Min(f_[x],p_);
}
g[x]=g[x]>f[x]?g[x]:f[x];
g_[x]=g_[x]<f_[x]?g_[x]:f_[x];
}
lll ans=0;
ll s[400100][4],s_[400100][4];//最大两个 f
ll t[400100][3],t_[400100][3];//最大两个 g
ll Gg,Gg_;//维护子树中最大的 g
ll F,G,F_,G_;//相当于 f[fa[x]]
void dfs(int x,int from,int h)//h 是来边
{
ans=ans>(lll)G*g[x]?ans:(lll)G*g[x];
ans=ans>(lll)G_*g_[x]?ans:(lll)G_*g_[x];
//F 更新 G
ll gg=Max(G,F);//在本层递归中存的 G
ll gg_=Min(G_,F_);
F+=h;
F_+=h;
for(int i=0;i<=3;++i)
s[x][i]=s_[x][i]=0;
for(int i=0;i<=2;++i)
t[x][i]=t_[x][i]=0;//清零
s[x][3]=F;//更新父亲方向的信息
s_[x][3]=F_;
t[x][2]=G;
t_[x][2]=G_;
maintain();
ll ff=F,ff_=F_;
for(int i=head[x];~i;i=e[i].nxt)
if(e[i].n!=from)
{
s[x][3]=f[e[i].n]+e[i].v;
s_[x][3]=f_[e[i].n]+e[i].v;
t[x][2]=g[e[i].n];
t_[x][2]=g_[e[i].n];
maintain();
}
for(int i=head[x];~i;i=e[i].nxt)
if(e[i].n!=from)
{
if(g[e[i].n]==t[x][0])//更新子树中 g 的信息
Gg=t[x][1];
else
Gg=t[x][0];
if(g_[e[i].n]==t_[x][0])
Gg_=t_[x][1];
else
Gg_=t_[x][0];
if(f[e[i].n]+e[i].v==s[x][0])//需要排除即将进去的子树
{
G=Max(Max(gg,s[x][1]+s[x][2]),Gg);
F=Max(ff,s[x][1]);
}
else if(f[e[i].n]+e[i].v==s[x][1])
{
G=Max(Max(gg,s[x][0]+s[x][2]),Gg);
F=Max(ff,s[x][0]);
}
else
{
G=Max(Max(gg,s[x][0]+s[x][1]),Gg);
F=Max(ff,s[x][0]);
}
if(f_[e[i].n]+e[i].v==s_[x][0])//还要更新F
{
G_=Min(Min(gg_,s_[x][1]+s_[x][2]),Gg_);
F_=Min(ff_,s_[x][1]);
}
else if(f_[e[i].n]+e[i].v==s_[x][1])
{
G_=Min(Min(gg_,s_[x][0]+s_[x][2]),Gg_);
F_=Min(ff_,s_[x][0]);
}
else
{
G_=Min(Min(gg_,s_[x][0]+s_[x][1]),Gg_);
F_=Min(ff_,s_[x][0]);
}
dfs(e[i].n,x,e[i].v);
}
}
int main()
{
memset(head,-1,sizeof(head));
int n,u,v,w;
scanf("%d",&n);
for(int i=1;i<n;++i)
{
scanf("%d%d%d",&u,&v,&w);
add(u,v,w);
}
Dfs(n>>1,0);
dfs(n>>1,0,0);
if(!ans)//注意用栈输出时为 0 的情况
puts("0");
int stk[100],tp=0;
while(ans)
{
stk[++tp]=ans%10;
ans/=10;
}
while(tp)
printf("%d",stk[tp--]);
return 0;
}