题意:有一棵n个点的有根树,每条边上有一个边权。给定P,从i跳到它的祖先j的费用是距离的平方+P,问所有点中到根节点1的总花费最大值
n<=1e5,p<=1e6,w<=1e2
思路:对于根节点到每个点i的路径上是一个下凸壳,是经典的斜率优化
考虑在dfs时维护这个下凸壳,在斜率优化加入与删除点时记录下时间戳和操作的类型,dfs结束时恢复即可
1 #include<cstdio> 2 #include<cstring> 3 #include<iostream> 4 #include<algorithm> 5 #include<cmath> 6 typedef long long ll; 7 using namespace std; 8 #define N 210000 9 #define oo 10000000 10 #define MOD 1000000007 11 12 struct node 13 { 14 int t,x,y; 15 }stk[N]; 16 17 ll dp[N],s[N],P; 18 int dep[N],head[N],vet[N],nxt[N],len[N],q[N],flag[N],n,top,tot,tim,t,w; 19 20 int add(int a,int b,int c) 21 { 22 nxt[++tot]=head[a]; 23 vet[tot]=b; 24 len[tot]=c; 25 head[a]=tot; 26 } 27 28 ll sqr(ll x) 29 { 30 return x*x; 31 } 32 33 ll calc(int i,int j) 34 { 35 return dp[j]+sqr(s[i]-s[j])+P; 36 } 37 38 int cmp(int x,int y,int z) 39 { 40 ll x1=dp[x]-dp[y]+sqr(s[x])-sqr(s[y]); 41 ll y1=s[x]-s[y]; 42 ll x2=dp[y]-dp[z]+sqr(s[y])-sqr(s[z]); 43 ll y2=s[y]-s[z]; 44 return x1*y2>=x2*y1; 45 } 46 47 void dfs(int u) 48 { 49 tim++; 50 flag[u]=1; 51 if(u==1) 52 { 53 t=1; w=1; dp[u]=-P; q[1]=1; 54 } 55 else 56 { 57 while(t<w&&calc(u,q[t])>=calc(u,q[t+1])) 58 { 59 stk[++top].t=tim; stk[top].x=1; stk[top].y=q[t]; 60 t++; 61 } 62 dp[u]=calc(u,q[t]); 63 while(t<w&&cmp(q[w-1],q[w],u)) 64 { 65 stk[++top].t=tim; stk[top].x=2; stk[top].y=q[w]; 66 w--; 67 } 68 q[++w]=u; 69 stk[++top].t=tim; stk[top].x=3; 70 } 71 72 int tmp=tim; 73 int e=head[u]; 74 while(e) 75 { 76 int v=vet[e]; 77 if(!flag[v]) 78 { 79 s[v]=s[u]+len[e]; 80 dfs(v); 81 } 82 e=nxt[e]; 83 } 84 while(stk[top].t==tmp) 85 { 86 if(stk[top].x==1) q[--t]=stk[top].y; 87 if(stk[top].x==2) q[++w]=stk[top].y; 88 if(stk[top].x==3) w--; 89 top--; 90 } 91 } 92 93 int main() 94 { 95 int cas; 96 scanf("%d",&cas); 97 while(cas--) 98 { 99 int n; 100 scanf("%d%d",&n,&P); 101 s[1]=0; 102 tot=0; 103 for(int i=1;i<=n;i++) head[i]=flag[i]=0; 104 for(int i=1;i<=n-1;i++) 105 { 106 int x,y,z; 107 scanf("%d%d%d",&x,&y,&z); 108 add(x,y,z); 109 add(y,x,z); 110 } 111 tim=0; 112 t=1; w=0; top=0; 113 dfs(1); 114 ll ans=0; 115 for(int i=2;i<=n;i++) ans=max(ans,dp[i]); 116 printf("%I64d ",ans); 117 } 118 return 0; 119 } 120