题目:https://www.luogu.org/problemnew/show/P4220
先写了一下 n^2 和三棵树一样的情况,n^2 还写了ST表O(1)求 lca,其实做 n 遍 dfs 就好了...
![](https://images.cnblogs.com/OutliningIndicators/ContractedBlock.gif)
#include<cstdio> #include<cstring> #include<algorithm> using namespace std; typedef long long ll; ll rd() { ll ret=0,f=1; char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=0; ch=getchar();} while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar(); return f?ret:-ret; } int const xn=1e5+5; int n,hd[3][xn],ct[3],to[3][xn<<1],nxt[3][xn<<1],in[3][xn],st[3][xn<<1][20],id[3][xn<<1][20],bin[20],bit[xn],tim[3],dep[3][xn]; ll w[3][xn<<1],dis[3][xn]; ll Min(ll x,ll y){return x<y?x:y;} ll Max(ll x,ll y){return x>y?x:y;} void add(int t,int x,int y,ll z){to[t][++ct[t]]=y; nxt[t][ct[t]]=hd[t][x]; hd[t][x]=ct[t]; w[t][ct[t]]=z;} void dfs(int t,int x,int fa) { in[t][x]=++tim[t]; id[t][tim[t]][0]=x; dep[t][x]=dep[t][fa]+1; st[t][tim[t]][0]=dep[t][x]; for(int i=hd[t][x],u;i;i=nxt[t][i]) { if((u=to[t][i])==fa)continue; dis[t][u]=dis[t][x]+w[t][i]; dfs(t,u,x); st[t][++tim[t]][0]=dep[t][x]; id[t][tim[t]][0]=x; } } void work(int t) { for(int i=1;i<20;i++) for(int j=1;j<=tim[t]&&j+bin[i]-1<=tim[t];j++) { if(st[t][j][i-1]<st[t][j+bin[i-1]][i-1]) st[t][j][i]=st[t][j][i-1],id[t][j][i]=id[t][j][i-1]; else st[t][j][i]=st[t][j+bin[i-1]][i-1],id[t][j][i]=id[t][j+bin[i-1]][i-1]; } } int lca(int t,int x,int y) { int l=in[t][x],r=in[t][y]; if(l>r)swap(l,r); int d=bit[r-l+1]; if(st[t][l][d]<st[t][r-bin[d]+1][d])return id[t][l][d]; return id[t][r-bin[d]+1][d]; } ll dist(int t,int x,int y){return dis[t][x]+dis[t][y]-2*dis[t][lca(t,x,y)];} void work1() { for(int t=0;t<3;t++)dfs(t,1,0),work(t); ll ans=0; for(int i=1;i<=n;i++) for(int j=i+1;j<=n;j++) ans=Max(ans,dist(0,i,j)+dist(1,i,j)+dist(2,i,j)); printf("%lld ",ans); } ll ans=0; ll dfsx(int x,int fa) { ll mx=0,nmx=0; for(int i=hd[0][x],u;i;i=nxt[0][i]) { if((u=to[0][i])==fa)continue; ll tmp=dfsx(u,x); if(tmp+w[0][i]>mx)nmx=mx,mx=tmp+w[0][i]; else if(tmp+w[0][i]>nmx)nmx=tmp+w[0][i]; } ans=Max(ans,mx+nmx); return mx; } int main() { n=rd(); ll z; bin[0]=1; for(int i=1;i<20;i++)bin[i]=bin[i-1]*2; bit[1]=0; for(int i=2;i<xn;i++)bit[i]=bit[i>>1]+1;//not n! for(int t=0;t<3;t++) for(int i=1,x,y;i<n;i++) x=rd(),y=rd(),z=rd(),add(t,x,y,z),add(t,y,x,z); if(n<=3000){work1(); return 0;} dfsx(1,0); printf("%lld ",ans*3); return 0; }
然后就用了随机化算法,用 clock() 和 CLOCKS_PER_SEC 卡时间,过了官方数据。
代码如下:
#include<cstdio> #include<cstring> #include<algorithm> #include<ctime> using namespace std; typedef long long ll; ll rd() { ll ret=0,f=1; char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=0; ch=getchar();} while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar(); return f?ret:-ret; } ll Min(ll x,ll y){return x<y?x:y;} ll Max(ll x,ll y){return x>y?x:y;} int const xn=1e5+5; int n,hd[3][xn],ct[3],to[3][xn<<1],nxt[3][xn<<1]; ll w[3][xn<<1],dis[3][xn]; void add(int t,int x,int y,ll z){to[t][++ct[t]]=y; nxt[t][ct[t]]=hd[t][x]; hd[t][x]=ct[t]; w[t][ct[t]]=z;} void dfs(int t,int x,int fa) { for(int i=hd[t][x],u;i;i=nxt[t][i]) if((u=to[t][i])!=fa)dis[t][u]=dis[t][x]+w[t][i],dfs(t,u,x); } void work1() { ll ans=0; for(int i=1;i<=n;i++) { dis[0][i]=dis[1][i]=dis[2][i]=0; for(int t=0;t<3;t++)dfs(t,i,0); for(int j=1;j<=n;j++)ans=Max(ans,dis[0][j]+dis[1][j]+dis[2][j]); } printf("%lld ",ans); } bool vis[xn]; int clk(){return (double)clock()/CLOCKS_PER_SEC*1000;} int main() { int st=clk(); n=rd(); ll z; for(int t=0;t<3;t++) for(int i=1,x,y;i<n;i++) x=rd(),y=rd(),z=rd(),add(t,x,y,z),add(t,y,x,z); if(n<=3000){work1(); return 0;} srand(time(0)); srand(rand()); ll ans=0; for(int rt;clk()-st<=3600;) { rt=rand()%n+1; if(vis[rt])continue; vis[rt]=1; int cnt=8; while(cnt--) { vis[rt]=1; for(int t=0;t<3;t++)dis[t][rt]=0,dfs(t,rt,0); ll mx=0; int id=rt; for(int i=1;i<=n;i++) { ll tmp=dis[0][i]+dis[1][i]+dis[2][i]; if(mx<tmp)mx=tmp; if(!vis[i]&&dis[0][id]+dis[1][id]+dis[2][id]<tmp)id=i; } ans=Max(ans,mx); if(id==rt)break; rt=id; } } printf("%lld ",ans); return 0; }