G. Sum of Prefix Sums(点分治+李超树)
题意:
一棵树上每个点有权值,对于树上一条路径u->v,将经过的点依次写下成一个序列(a_1,a_2,...a_r),而(sum_i=sum_{j=1}^i a_j),(f(u,v)=sum_{i=1}^r sum_i),求所有路径中最大的f(u,v)。
解法:
考虑所有树上路径,很容易可以想到点分治。但对于每条插入的链,如果作为后半段的话,(f=presum*len+extraval)。对于每条插入的链,我们选择之前哪条链作为前端最大值其实相当于f=k*x+b求x=len时的最大值,而k和b对于不同的前半段都是不同的。这个形式我们可以想到李超树。对于每个新插入的链相当于插入一条线段,对于每个询问线段查找x=len时的最大值即可。
#include <bits/stdc++.h>
#define lson rt << 1
#define rson rt << 1 | 1
#define ll long long
using namespace std;
const ll inf = 1e9;
const int maxn = 150000;
ll ans = 0;
vector <int> edge[maxn + 11];
ll a[maxn + 11];
int dep[maxn + 11],f[maxn + 11],siz[maxn + 11];
bool vis[maxn + 11] = {false};
int root,all,minn;
int n;
struct Line{
ll k,b;
Line() { k = b = -inf; }
Line(ll k,ll b) { this -> k = k; this -> b = b; }
ll at(int x) { return k * x + b; }
};
struct LiChao{
Line tree[4 * maxn + 11];
int n;
void build(int rt,int l,int r) {
tree[rt] = Line();
if (l == r) return;
int mid = (l + r) >> 1;
build(lson , l , mid);
build(rson , mid + 1 , r);
}
void clear(int _) { n = _; build(1 , 1 , _); }
void update(int rt,int l,int r,Line line) {
int mid = (l + r) >> 1;
if (tree[rt].at(mid) < line.at(mid)) swap(tree[rt] , line);
if (l == r) return;
if (tree[rt].at(l) < line.at(l)) update(lson , l , mid , line);
else update(rson , mid + 1 , r , line);
}
void modify(int rt,int l,int r,int al,int ar,Line line) {
if (l > ar || r < al) return;
if (l >= al && r <= ar) {
update(rt , l , r , line);
return;
}
int mid = (l + r) >> 1;
modify(lson , l , mid , al , ar , line);
modify(rson , mid + 1 , r , al , ar , line);
}
void add(int l,int r,ll k,ll b) { // 在[l,r]插入一条y=kx+b线段
modify(1 , 1 , n , l , r , Line{k , b});
}
ll query(int rt,int l,int r,int x) {
if (l == r) return tree[rt].at(x);
int mid = (l + r) >> 1;
if (mid >= x) return max(tree[rt].at(x) , query(lson , l , mid , x));
return max(tree[rt].at(x) , query(rson , mid + 1 , r , x));
}
ll query(int x) {//查找x处的最大y
return query(1 , 1 , n , x);
}
}Lc;
void getroot(int x,int fa) {
siz[x] = 1; f[x] = 0;
for (auto v : edge[x]) {
if (v == fa || vis[v]) continue;
getroot(v , x);
siz[x] += siz[v];
if (siz[v] > f[x]) f[x] = siz[v];
}
f[x] = max(f[x] , all - siz[x]);
if (f[x] < minn) { minn = f[x]; root = x; }
}
void setdep(int x,int fa) {
for (auto v : edge[x]) {
if (vis[v] || v == fa) continue;
dep[v] = dep[x] + 1;
setdep(v , x);
}
}
ll query(int x,int fa,ll sum,ll val) {
ll ans = Lc.query(dep[x]) + val;
for (auto v : edge[x]) {
if (v == fa || vis[v]) continue;
ans = max(ans , query(v , x , sum + a[v] , val + sum + a[v]));
}
return ans;
}
void add(int x,int fa,ll sum,ll val) {
ans = max(ans , val); Lc.add(1 , n , sum , val);
for (auto v : edge[x]) {
if (vis[v] || v == fa) continue;
add(v , x , sum + a[v] , val + a[v] * (dep[v] + 1));
}
}
void divide(int x) {
vis[x] = true;
if (minn == 0) { ans = max(ans , a[x]); return; }
Lc.clear(minn);
Lc.add(1 , minn , a[x] , a[x]);
for (auto v : edge[x]) {
if (vis[v]) continue;
ans = max(ans , query(v , x , a[v] , a[v]));
add(v , x , a[x] + a[v] , a[x] + a[v] * 2);
}
reverse(edge[x].begin() , edge[x].end());
Lc.clear(minn);
Lc.add(1 , minn , a[x] , a[x]);
for (auto v : edge[x]) {
if (vis[v]) continue;
ans = max(ans , query(v , x , a[v] , a[v]));
add(v , x , a[x] + a[v] , a[x] + a[v] * 2);
}
for (auto v : edge[x]) {
if (vis[v]) continue;
all = minn = siz[v]; root = 0;
getroot(v , x); dep[root] = 0; setdep(root , x);
divide(root);
}
}
int main(){
scanf("%d" , &n);
for (int i = 1; i < n; i++) {
int u,v;
scanf("%d %d",&u,&v);
edge[u].push_back(v);
edge[v].push_back(u);
}
for (int i = 1; i <= n; i++) scanf("%lld" , &a[i]);
all = minn = n; root = 0;
getroot(1 , 0); dep[root] = 0; setdep(root , 0);
divide(root);
printf("%lld
" , ans);
}