Description
话说moreD经过不懈努力,终于背完了循环整数,也终于完成了他的蛋糕大餐。
但是不幸的是,moreD得到了诅咒,受到诅咒的原因至今无人知晓。
moreD在发觉自己得到诅咒之后,决定去寻找闻名遐迩的术士CD帮忙。
话说CD最近在搞OI,遇到了一道有趣的题目:
给定两棵树,则总共有NM种方案把这两棵树通过加一条边连成一棵树,那这NM棵树的直径
(树的直径指的是树上的最长简单路径)
大小之和是多少呢?
CD为了考验moreD是否值得自己费心力为他除去诅咒,于是要他编程回答这个问题,但是这m
oreD早就被诅咒搞晕了头脑,就只好请你帮助他了。
Input
第一行两个正整数N,M,分别表示两棵树的大小。
接下来N-1行,每行两个正整数ai,bi,表示第一棵树上的边。
接下来M-1行,每行两个正整数ci,di,表示第二棵树上的边。
N<=105,M<=105,1<=ai,bi<=N,1<=ci,di<=M
Sample Input
4 3
1 2
2 3
2 4
1 3
2 3
Sample Output
53
Solution
这道题主流写法是tree dp, 但我队测时没想到dp, 就口胡了一个算法, 居然A了
首先O(n)求出两颗树各自的直径, 并求出直径的两个端点
因为对于树上任何一点, 离它最远的点一定是直径的两个端点之一
因此用lca求出x与两个端点的距离, 取max得出x与离其最远点的距离len[x]
考虑新树的直径是要么是原先两颗树中较大的直径maxd, 要么是两点相连所形成的新路径tr[0].len[i]+tr[1].len[j]+1
接着将len从小到大排序
那么对于第一颗树中的第 i 个 len, 二分求出第二颗树中的第一个 j,使得tr[0].len[i]+tr[1].len[j]+1>=maxd
那么len[k] (1<=k<j)都满足tr[0].len[i]+tr[1].len[k]+1<maxd
此时新树的直径是原先两颗树中较大的直径maxd
因此贡献为(j-1)*maxd
那么len[k] (j<=k<=m)都满足tr[0].len[i]+tr[1].len[k]+1>=maxd
此时新树的直径是两点相连所形成的新路径tr[0].len[i]+tr[1].len[j]+1
因此贡献为(m-j+1)*tr[0].len[i]+tr[1].sum[m]-tr[1].sum[j-1]+(m-j+1), 其中sum为len的前缀和
所以第一颗树中第 i 个len的贡献为(j-1)*maxd+(m-j+1)*tr[0].len[i]+tr[1].sum[m]-tr[1].sum[j-1]+(m-j+1);
总时间复杂度O(n log n)
#include<bits/stdc++.h>
#define int long long
using namespace std;
int read(){
int x=0,f=1;char ch=getchar();
for(;!isdigit(ch);ch=getchar())if(ch=='-')f=-1;
for(;isdigit(ch);ch=getchar())x=x*10+ch-'0';
return x*f;
}
const int N=1e5+28;
struct E{int to,nxt;};
struct Tree{
int head[N],cnt,fa[N][20],dep[N],len[N],dl,dr,sum[N];
E l[N<<1];
void Ins(int x,int y){
l[++cnt].nxt=head[x];
l[cnt].to=y;
head[x]=cnt;
}
void Dfs(int x,int f){
fa[x][0]=f;
dep[x]=dep[f]+1;
for(int i=1;i<20;i++)fa[x][i]=fa[fa[x][i-1]][i-1];
for(int i=head[x];i;i=l[i].nxt){
int y=l[i].to;
if(y==f)continue;
Dfs(y,x);
}
}
int Lca(int x,int y){
if(dep[x]<dep[y])swap(x,y);
for(int i=19;i>=0;i--){
int f=fa[x][i];
if(dep[f]>=dep[y])x=f;
}
if(x==y)return x;
for(int i=19;i>=0;i--){
int fx=fa[x][i],fy=fa[y][i];
if(fx!=fy)x=fx,y=fy;
}
return fa[x][0];
}
int Dis(int x,int y){return dep[x]+dep[y]-2*dep[Lca(x,y)];}
int apr[N],dis[N];
queue<int>q;
int Spfa(int s){
memset(apr,0,sizeof(apr));
memset(dis,0x3f,sizeof(dis));
q.push(s);
dis[s]=0;
apr[s]=1;
while(q.size()){
int x=q.front();
q.pop();
apr[x]=0;
for(int i=head[x];i;i=l[i].nxt){
int y=l[i].to;
if(dis[x]+1<dis[y]){
dis[y]=dis[x]+1;
if(!apr[y])q.push(y),apr[y]=1;
}
}
}
}
int Calc(int n){
Dfs(1,0);
dl=dr=1;
for(int i=1;i<=n;i++)if(dep[i]>dep[dl])dl=i;
Spfa(dl);
for(int i=1;i<=n;i++)if(dis[i]>dis[dr])dr=i;
for(int i=1;i<=n;i++)len[i]=max(Dis(i,dl),Dis(i,dr));
sort(len+1,len+n+1);
for(int i=1;i<=n;i++)sum[i]=sum[i-1]+len[i];
return dis[dr];
}
int Match(int n,int x){
int l=1,r=n,re=n+1;
while(l<=r){
int mid=(l+r)>>1;
if(len[mid]>=x)re=mid,r=mid-1;
else l=mid+1;
}
return re;
}
Tree(){
cnt=0;
memset(head,0,sizeof(head));
memset(dep,0,sizeof(dep));
memset(fa,0,sizeof(fa));
memset(len,0,sizeof(len));
memset(sum,0,sizeof(sum));
memset(l,0,sizeof(l));
};
}tr[2];
int n,m,mxd;
signed main(){
// freopen("connect.in","r",stdin);
// freopen("connect.out","w",stdout);
n=read(),m=read();
for(int i=1,x,y;i<n;i++)tr[0].Ins(x=read(),y=read()),tr[0].Ins(y,x);
for(int i=1,x,y;i<m;i++)tr[1].Ins(x=read(),y=read()),tr[1].Ins(y,x);
mxd=max(tr[0].Calc(n),tr[1].Calc(m));
int ans=0;
for(int i=1;i<=n;i++){
int p=tr[1].Match(m,mxd-tr[0].len[i]-1);
ans+=(p-1)*mxd;
ans+=(m-p+1)*tr[0].len[i]+tr[1].sum[m]-tr[1].sum[p-1]+(m-p+1);
}
printf("%lld",ans);
return 0;
}