Solution [JSOI2016]独特的树叶
题目大意:给出两棵树(A,B),(A)增加一个叶节点与(B)同构,求增加节点在(B)中编号
树(hash),树形(dp)
分析:
我们可以这样考虑,假设现在有一个虚拟点,那么我们枚举(A)的每个点,将以它为根的(A)接在这个虚拟点上(等价于增加一个子节点),判断这棵新的树是否与(B)同构,如果同构,那么与新的(A)同构的(B)的根节点就是答案
可以发现,算法瓶颈在于判断树是否重构,我们可以用树(hash)来干这个事情
我选用的(hash)方法是这样子的
(f[u]=1+sum f[v] imes pri[siz[v]])
(f[u])表示(u)这棵子树的(hash)值,(siz[u])表示(u)这棵子树大小,(pri[i])表示第(i)个质数
(pri)欧拉筛可在(O(n))时间内求出
我们要的是以所有节点各自为根时树的(hash)值,如果每次枚举根复杂度(O(n^2))无法承受,因此可以采用树形(dp)的方式求解
设(down[u])为(u)这棵子树的(hash)值
(down[u] = 1 + sum down[v] imes pri[siz[v]])
设(up[u])为整棵树除去(u)这棵子树的(hash)值,那么
(up[u] = ans[faz]-down[u] imes pri[siz[u]])
(ans[u])表示以(u)为根时树的(hash)值
然后(ans[u]=down[u]+up[u] imes pri[siz[1]-siz[u]])
这里用(siz[1])是因为我是以(1)为根进行树形(dp)的,按情况调整
(hash)之后我们将(hash)值和对应的根丢进(set),然后我们以(u)为根将(A)接到虚拟点后的(hash)值:(1 + ans[u] imes pri[n])((n)为(A)节点数),在(set)里面查找即可
复杂度(O(nlogn))
#include <cstdio>
#include <cctype>
#include <queue>
#include <set>
using namespace std;
const int maxn = 1e5 + 100,maxm = 10000000;
typedef unsigned long long ull;
inline int read(){
int x = 0;char c = getchar();
while(!isdigit(c))c = getchar();
while(isdigit(c))x = x * 10 + c - '0',c = getchar();
return x;
}
int vis[maxm],pri[maxm],pri_tot;
inline void init(){
for(int i = 2;i < maxm;i++){
if(!vis[i])pri[++pri_tot] = i;
for(int j = 1;j <= pri_tot;j++){
if(i * pri[j] > maxm)break;
vis[i * pri[j]] = 1;
if(!(i % pri[j]))break;
}
}
}
struct Tree{
const int root = 1;
vector<int> G[maxn];
int n,siz[maxn],vis[maxn];
inline void addedge(int from,int to){G[from].push_back(to);}
ull down[maxn],up[maxn],f[maxn];//这里f指上文ans
inline void dfs_down(int u,int faz = 0){
down[u] = siz[u] = 1;
for(int v : G[u]){
if(v == faz)continue;
dfs_down(v,u);
siz[u] += siz[v];
down[u] += down[v] * pri[siz[v]];
}
}
inline void dfs_up(int u,int faz = 0){
up[u] = faz ? (f[faz] - down[u] * pri[siz[u]]) : 0;
f[u] = down[u] + (faz ? up[u] * pri[siz[root] - siz[u]] : 0);
for(int v : G[u]){
if(v == faz)continue;
dfs_up(v,u);
}
}
inline void init(){
dfs_down(root);
dfs_up(root);
}
}A,B;
struct Node{
int u;
ull h;
bool operator < (const Node &rhs)const{
return h < rhs.h;
}
};
set<Node> s;
int main(){
init();
A.n = read();
for(int u,v,i = 1;i < A.n;i++)
u = read(),v = read(),A.addedge(u,v),A.addedge(v,u);
B.n = A.n + 1;
for(int u,v,i = 1;i < B.n;i++)
u = read(),v = read(),B.addedge(u,v),B.addedge(v,u);
A.init();
B.init();
for(int i = 1;i <= B.n;i++)
s.insert(Node{i,B.f[i]});
for(int i = 1;i <= A.n;i++){
auto it = s.find(Node{i,1 + A.f[i] * pri[A.n]});
if(it != s.end()){
printf("%d
",it->u);
break;
}
}
return 0;
}