题解
这个听起来很毒瘤的想法写起来却非常休闲,理解起来可能很费劲
例如,我们首先到猜到答案是个下凸包
然后是不是要三分???然而并不是orz
我们通过归纳证明这个下凸包的结论来总结出了一个算法
也就是对于每个子树是一个凸包,我们进行两种操作,一种操作是我给这个子树加了一个父亲边
另一种操作是对所有儿子合并这个凸包
容易发现这两种操作都不会影响答案的下凸性
我们考虑操作1
凸包最小值有一段区间是([L,R])
分类讨论加一条边的情况
(f(x) = f(x) + w (x <= L))也就是在(L)之前的要想达到必须全删掉新加的边才是最优的
(f(x) = f(L) + w - (x - L) (L <= x <= L + w))直接从L转移即可
(f(x) = f(L)(L + w <= x <= R + w)) 直接平移就是最优的了
(f(x) = f(R) + w + (x - R))从(R)转移即可
事实上,这是给凸包下面扯开,加了-1,0,1的三段
由此可一归纳出来,这个凸包斜率变化范围不是很大,自底向上是
-1,0,1和-1,0,1两个卷积
然后依次变大
事实上只要维护每次多的两个拐点就可以,给拐点排序后,删掉斜率为正的部分,最后从最靠上的部分开始往下减
代码
#include <bits/stdc++.h>
#define fi first
#define se second
#define pii pair<int,int>
#define pdi pair<db,int>
#define mp make_pair
#define pb push_back
#define enter putchar('
')
#define space putchar(' ')
#define eps 1e-8
#define mo 974711
#define MAXN 300005
//#define ivorysi
using namespace std;
typedef long long int64;
typedef double db;
template<class T>
void read(T &res) {
res = 0;char c = getchar();T f = 1;
while(c < '0' || c > '9') {
if(c == '-') f = -1;
c = getchar();
}
while(c >= '0' && c <= '9') {
res = res * 10 + c - '0';
c = getchar();
}
res *= f;
}
template<class T>
void out(T x) {
if(x < 0) {x = -x;putchar('-');}
if(x >= 10) {
out(x / 10);
}
putchar('0' + x % 10);
}
int N,M;
int fa[MAXN],C[MAXN],d[MAXN];
int64 sum,p[MAXN * 2];
struct node {
int lc,rc,dis;
int64 v;
}tr[MAXN * 2];
int tot,rt[MAXN],cnt;
int Merge(int x,int y) {
if(!x) return y;
if(!y) return x;
if(tr[x].v < tr[y].v) swap(x,y);
tr[x].rc = Merge(tr[x].rc,y);
if(tr[tr[x].rc].dis > tr[tr[x].lc].dis) swap(tr[x].lc,tr[x].rc);
tr[x].dis = tr[tr[x].lc].dis + 1;
return x;
}
int upt(int x) {return Merge(tr[x].lc,tr[x].rc);}
int main() {
#ifdef ivorysi
freopen("f1.in","r",stdin);
#endif
read(N);read(M);
for(int i = 2 ; i <= N + M ; ++i) {
read(fa[i]);read(C[i]);
d[fa[i]]++;
sum += C[i];
}
for(int i = N + M ; i > 1 ; --i) {
int64 l = 0,r = 0;
if(i <= N) {
for(int j = 1 ; j < d[i] ; ++j) rt[i] = upt(rt[i]);
l = tr[rt[i]].v;rt[i] = upt(rt[i]);
r = tr[rt[i]].v;rt[i] = upt(rt[i]);
}
l += C[i];r += C[i];
tr[++tot].v = l;tr[++tot].v = r;
rt[i] = Merge(rt[i],Merge(tot,tot - 1));
rt[fa[i]] = Merge(rt[fa[i]],rt[i]);
}
for(int j = 1 ; j <= d[1]; ++j) rt[1] = upt(rt[1]);
while(rt[1]) {
p[++cnt] = tr[rt[1]].v;
rt[1] = upt(rt[1]);
}
for(int i = cnt ; i >= 1 ; --i) {
sum -= p[i];
}
out(sum);enter;
}