题意
JYY有两棵树A和B:树A有N个点,编号为1到N;树B有N+1个点,编号为1到N+1。JYY知道树B恰好是由树A加上一个叶节点,然后将节点的编号打乱后得到的。他想知道,这个多余的叶子到底是树B中的哪一个叶节点呢?
(1≤N≤10^5)
分析
参照ccz181078的题解。
树形dp求出两棵树每个点的子树和上方子树的hash值,进而可以求出每个点作为根的整棵树(或删去一个叶子后)的hash值
有根树的hash的一个实现方式是递归定义为Hash,例如子树的hash值排序后的字符串hash,因为树的形态和子树顺序无关,所以要排序。
他写的是双哈希,我不想用,结果被卡WA了。最后我发现把cal函数写成这个样子:
ull cal(ull v)
{
return v*20030506+8946;
}
就不会WA。并且那个只要式子不是v*P+P,大概率可以AC,都不用是质数。玄学操作。
时间复杂度(O(N log_2 N))。
代码
出题人卡哈希有点水平呀。
#include<bits/stdc++.h>
#define rg register
#define il inline
#define co const
template<class T>il T read()
{
rg T data=0;
rg int w=1;
rg char ch=getchar();
while(!isdigit(ch))
{
if(ch=='-')
w=-1;
ch=getchar();
}
while(isdigit(ch))
{
data=data*10+ch-'0';
ch=getchar();
}
return data*w;
}
template<class T>il T read(rg T&x)
{
return x=read<T>();
}
typedef unsigned long long ull;
using namespace std;
co int N=1e5+10,P=31;
int n,ans=N;
ull bin[N],hash[N],val[N];
bool isin(ull v)
{
return binary_search(val+1,val+n+1,v);
}
ull cal(ull v)
{
return v*20030506+P;
}
struct Graph
{
int n,fa[N];
vector<int>g[N];
ull down[N],up[N];
void getdown(int x,int fa)
{
this->fa[x]=fa;
vector<ull>v;
for(int i=0;i<g[x].size();++i)
{
int y=g[x][i];
if(y==fa) continue;
getdown(y,x);
v.push_back(down[y]);
}
sort(v.begin(),v.end());
for(int i=0;i<v.size();++i)
down[x]=down[x]*P+v[i];
down[x]=cal(down[x]);
}
typedef pair<ull,int> pui;
vector<pui> getv(int x)
{
vector<pui>v;
for(int i=0;i<g[x].size();++i)
{
int y=g[x][i];
if(y==fa[x]) continue;
v.push_back(pui(down[y],y));
}
v.push_back(pui(up[x],fa[x]));
sort(v.begin(),v.end());
return v;
}
void getup(int x)
{
vector<pui>v=getv(x);
for(int i=0;i<v.size();++i)
hash[i+1]=hash[i]*P+v[i].first;
for(int i=0;i<v.size();++i)
{
int y=v[i].second;
if(y==fa[x]) continue;
up[y]=cal(hash[v.size()]+(hash[i]-hash[i+1])*bin[v.size()-1-i]);
}
for(int i=0;i<g[x].size();++i)
{
int y=g[x][i];
if(y==fa[x]) continue;
getup(y);
}
}
void init(int n)
{
this->n=n;
for(int i=1;i<n;++i)
{
int x=read<int>(),y=read<int>();
g[x].push_back(y),g[y].push_back(x);
}
getdown(1,0);
getup(1);
}
void sum()
{
for(int x=1;x<=n;++x)
{
vector<pui>v=getv(x);
for(int i=0;i<v.size();++i)
val[x]=val[x]*P+v[i].first;
val[x]=cal(val[x]);
}
sort(val+1,val+n+1);
}
int query()
{
if(g[1].size()==1&&isin(down[g[1][0]]))
return 1;
for(int x=2;x<=n;++x)
if(g[x].size()==1&&isin(up[x]))
return x;
return -1;
}
}A,B;
int main()
{
// freopen(".in","r",stdin);
// freopen(".out","w",stdout);
read(n);
bin[0]=1;
for(int i=1;i<=n;++i)
bin[i]=bin[i-1]*P;
A.init(n),B.init(n+1);
A.sum();
printf("%d
",B.query());
return 0;
}