算法一:点分治+线段树
分析
说是线段树,但是其实要写树状数组卡常。
代码
#include <bits/stdc++.h>
#define rin(i,a,b) for(register int i=(a);i<=(b);++i)
#define irin(i,a,b) for(register int i=(a);i>=(b);--i)
#define trav(i,a) for(register int i=head[a];i;i=e[i].nxt)
#define lowbit(x) ((x)&(-(x)))
typedef long long LL;
using std::cin;
using std::cout;
using std::endl;
inline int read(){
int x=0,f=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
const int MAXN=50005;
int n,ecnt,cog,totsiz,maxv,now,top,head[MAXN],siz[MAXN],sta[MAXN];
int loc,ql,qr,val[MAXN],maxn[66005],tag[66005],kk;
LL ans;
bool vis[MAXN];
struct Edge{
int to,nxt;
}e[MAXN<<1];
inline void add_edge(int bg,int ed){
++ecnt;
e[ecnt].to=ed;
e[ecnt].nxt=head[bg];
head[bg]=ecnt;
}
//#define mid ((l+r)>>1)
//#define lc (o<<1)
//#define rc ((o<<1)|1)
//
//inline void pushdown(int o){
// if(!tag[o]) return;
// maxn[lc]=maxn[rc]=0;
// tag[lc]=tag[rc]=true;
// tag[o]=false;
//}
//
//void upd(int o,int l,int r){
// if(l==r){
// maxn[o]=std::max(maxn[o],kk);
// return;
// }
// pushdown(o);
// if(loc<=mid) upd(lc,l,mid);
// else upd(rc,mid+1,r);
// maxn[o]=std::max(maxn[lc],maxn[rc]);
//}
//
//int query(int o,int l,int r){
// if(ql<=l&&r<=qr) return maxn[o];
// pushdown(o);
// LL ret=0;
// if(mid>=ql) ret=std::max(ret,query(lc,l,mid));
// if(mid<qr) ret=std::max(ret,query(rc,mid+1,r));
// return ret;
//}
//
//#undef mid
//#undef lc
//#undef rc
void upd(int x,int k){
for(register int i=x;i<=maxv+1;i+=lowbit(i)){
if(tag[i]==now) maxn[i]=std::max(maxn[i],k);
else maxn[i]=k,tag[i]=now;
}
}
void upd(int o,int l,int r){
upd(maxv-loc+1,kk);
}
int query(int x){
int ret=0;
for(register int i=x;i;i-=lowbit(i)) if(tag[i]==now) ret=std::max(ret,maxn[i]);
return ret;
}
int query(int o,int l,int r){
return query(maxv-ql+1);
}
void dfs1(int x,int pre,int len,int minv){
siz[x]=1;
ql=minv,qr=maxv,ans=std::max(ans,1ll*(len+query(1,0,maxv)-1)*minv);
trav(i,x){
int ver=e[i].to;
if(ver==pre||vis[ver]) continue;
dfs1(ver,x,len+1,std::min(minv,val[ver]));
siz[x]+=siz[ver];
}
}
void dfs2(int x,int pre,int len,int minv){
loc=minv,kk=len,upd(1,0,maxv);
trav(i,x){
int ver=e[i].to;
if(ver==pre||vis[ver]) continue;
dfs2(ver,x,len+1,std::min(minv,val[ver]));
}
}
void getcog(int x,int pre){
siz[x]=1;bool flag=true;
trav(i,x){
int ver=e[i].to;
if(ver==pre||vis[ver]) continue;
getcog(ver,x);
siz[x]+=siz[ver];
if(siz[ver]>totsiz/2) flag=false;
}
if(totsiz-siz[x]>totsiz/2) flag=false;
if(flag) cog=x;
}
void solve(int x){
vis[x]=true;top=0;
ans=std::max(ans,1ll*val[x]);
++now;
loc=val[x],kk=1,upd(1,0,maxv);
trav(i,x){
int ver=e[i].to;
if(vis[ver]) continue;
sta[++top]=ver;
dfs1(ver,x,2,std::min(val[x],val[ver]));
dfs2(ver,x,2,std::min(val[x],val[ver]));
}
// maxn[1]=0,tag[1]=true;
++now;
loc=val[x],kk=1,upd(1,0,maxv);
while(top){
int ver=sta[top];
dfs1(ver,x,2,std::min(val[x],val[ver]));
dfs2(ver,x,2,std::min(val[x],val[ver]));
top--;
}
// maxn[1]=0,tag[1]=true;
trav(i,x){
int ver=e[i].to;
if(vis[ver]) continue;
totsiz=siz[ver];
getcog(ver,x);
solve(cog);
}
}
int main(){
n=read();
rin(i,1,n) val[i]=read(),maxv=std::max(maxv,val[i]);
rin(i,2,n){
int u=read(),v=read();
add_edge(u,v);
add_edge(v,u);
}
totsiz=n;
getcog(1,0);
solve(cog);
printf("%lld
",ans);
return 0;
}
算法二:点分治
分析
边分治转点分治(不用学边分治了)(大雾)。
代码
#include <bits/stdc++.h>
#define rin(i,a,b) for(register int i=(a);i<=(b);++i)
#define irin(i,a,b) for(register int i=(a);i>=(b);--i)
#define trav(i,a) for(register int i=head[a];i;i=e[i].nxt)
typedef long long LL;
using std::cin;
using std::cout;
using std::endl;
inline int read(){
int x=0,f=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
const int MAXN=50005;
int n,ecnt,cog,totsiz,now,head[MAXN],val[MAXN],siz[MAXN],vis[MAXN];
int len1,len2;LL ans;
struct Edge{
int to,nxt;
}e[MAXN<<1];
struct info{
int dis,minv;
inline friend bool operator < (info x,info y){
return x.minv==y.minv?x.dis<y.dis:x.minv>y.minv;
}
}arr1[MAXN],arr2[MAXN];
inline void add_edge(int bg,int ed){
++ecnt;
e[ecnt].to=ed;
e[ecnt].nxt=head[bg];
head[bg]=ecnt;
}
void getcog(int x,int pre){
siz[x]=1;bool flag=true;
trav(i,x){
int ver=e[i].to;
if(ver==pre||vis[ver]) continue;
getcog(ver,x);
siz[x]+=siz[ver];
if(siz[ver]>totsiz/2) flag=false;
}
if(totsiz-siz[x]>totsiz/2) flag=false;
if(flag) cog=x;
}
void dfs(int x,int pre,int dis,int minv,info *arr,int &len){
arr[++len]=(info){dis,minv};
trav(i,x){
int ver=e[i].to;
if(ver==pre||vis[ver]) continue;
dfs(ver,x,dis+1,std::min(minv,val[ver]),arr,len);
}
}
void solve(int x){
getcog(x,0);
if(totsiz==2){
int y=0;
trav(i,x){
int ver=e[i].to;
if(vis[ver]) continue;
y=ver;break;
}
ans=std::max(std::max(ans,std::min(val[x],val[y])*2ll)
,std::max(1ll*val[x],1ll*val[y]));
return;
}
int one=0,thisnow=++now;
trav(i,x){
int ver=e[i].to;
if(vis[ver]) continue;
if(one+siz[ver]>totsiz/2) vis[ver]=now;
else one+=siz[ver];
}
int ano=totsiz-one;++one,len1=len2=0;
dfs(x,0,1,val[x],arr1,len1);
trav(i,x){
int ver=e[i].to;
if(vis[ver]&&vis[ver]!=thisnow) continue;
vis[ver]^=thisnow;
}
dfs(x,0,1,val[x],arr2,len2);
std::sort(arr1+1,arr1+len1+1);
std::sort(arr2+1,arr2+len2+1);
int ptr=0,premax=-1;
rin(i,1,len1){
while(ptr<len2&&arr2[ptr+1].minv>=arr1[i].minv) premax=std::max(premax,arr2[++ptr].dis);
ans=std::max(ans,1ll*(arr1[i].dis+premax-1)*arr1[i].minv);
}
ptr=0,premax=-1;
rin(i,1,len2){
while(ptr<len1&&arr1[ptr+1].minv>=arr2[i].minv) premax=std::max(premax,arr1[++ptr].dis);
ans=std::max(ans,1ll*(arr2[i].dis+premax-1)*arr2[i].minv);
}
totsiz=ano;
getcog(x,0);
solve(cog);
vis[x]=0;
trav(i,x){
int ver=e[i].to;
if(vis[ver]&&vis[ver]!=thisnow) continue;
vis[ver]^=thisnow;
}
totsiz=one;
getcog(x,0);
solve(cog);
}
int main(){
n=read();
rin(i,1,n) val[i]=read();
rin(i,2,n){
int u=read(),v=read();
add_edge(u,v);
add_edge(v,u);
}
totsiz=n;
getcog(1,0);
solve(cog);
printf("%lld
",ans);
}