D. GCD Counting 因数分解+换根DP
题目大意:
给你一棵树,每一个节点有一个权值,(g(x,y)) 表示 (x) 到 (y) 的一条简单路径的所有点的权值的gcd,(dist(x,y)) 表示 (x) 到 (y) 的一条简单路径上的点的数量。
求最大的 (dist(x,y)) 并且要求 (g(x,y)>1)
题解:
首先观察权值的数据范围 (1<=a_i<=2e5) ,非常小的一个数据范围,然后根据题目是计算路径上所有点的权值的 (gcd) ,(gcd>1) 表示只要判断至少存在一个质数是这条路径上所有的点的约数即可。
从 (gcd) 推到质数,这个是一个很常规的想法。
对于每一个数,首先进行因数分解,求出这个点的所有的质因子,然后按照普通树形dp的方法,求对于质数 (p) 来说最长的一条路径是多少,但是 (2e5) 也有很多的质数,直接枚举肯定 (TLE) ,但是如果我只枚举当前节点的质因子也是可以的,其他都初始化为 0 即可,注意要换根。
我用 (mp) 写的,注意 (mp[x][y]) 和 (mp[x].count(y)) 的区别,前者如果为0,也会把 (y) 放进去,后者则不会。
#include <bits/stdc++.h>
#define inf 0x3f3f3f3f
#define inf64 0x3f3f3f3f3f3f3f3f
#define lson (id<<1)
#define rson (id<<1|1)
using namespace std;
const int maxn = 2e5+10;
const int maxm = 1e3 + 10;
typedef long long ll;
int head[maxn],to[maxn<<1],nxt[maxn<<1],tot;
void add(int u,int v){
++tot,to[tot] = v,nxt[tot] = head[u],head[u] = tot;
++tot,to[tot] = u,nxt[tot] = head[v],head[v] = tot;
}
map<int,int>mp[maxn];//mp[u][p] 表示对于u这个节点,质数p的最长距离
map<int,int>dp[maxn];
int a[maxn],isp[maxm],v[maxm],cnt;
void init(){
cnt = 0;
memset(v,0,sizeof(v));
for(int i=2;i<maxm;i++){
if(!v[i]) v[i] = i,isp[cnt++] = i;
for(int j=0;j<cnt;j++){
if(1ll*i*isp[j]>=maxm) break;
v[i*isp[j]] = isp[j];
if(i%isp[j]==0) break;
}
}
}
void solve(int x){
int now = a[x];
for(int i=0;i<cnt;i++){
if(now%isp[i]==0){
mp[x][isp[i]] = 1;
while (now%isp[i]==0) now/=isp[i];
}
if(x==1) break;
}
if(now!=1) mp[x][now] = 1;
}
int ans;
void dfs1(int u,int pre){
for(int i=head[u];i;i=nxt[i]){
int v = to[i];
if(v == pre) continue;
dfs1(v,u);
for(auto x:mp[u]){
if(mp[v].count(x.first)) {
// printf("dfs1:u = %d v = %d
",u,v);
if(mp[v][x.first]+1>mp[u][x.first]){
dp[u][x.first] = mp[u][x.first];
mp[u][x.first] = mp[v][x.first] + 1;
}
else dp[u][x.first] = max(dp[u][x.first],mp[v][x.first]+1);
}
ans = max(ans,mp[u][x.first]);
// printf("dfs1:u = %d v = %d mp[%d][%d]=%d %d
",u,v,u,x.first,mp[u][x.first],dp[u][x.first]);
}
}
}
void dfs2(int u,int pre){
for(int i=head[u];i;i=nxt[i]){
int v = to[i];
if(v == pre) continue;
for(auto x:mp[v]){
if(mp[u].count(x.first)){
if(mp[u][x.first]==mp[v][x.first]+1){
if(dp[u][x.first]+1>mp[v][x.first]){
dp[v][x.first] = mp[v][x.first];
mp[v][x.first] = dp[u][x.first] + 1;
}
else dp[v][x.first] = max(dp[v][x.first],dp[u][x.first]+1);
}
else if(mp[u][x.first]+1>mp[v][x.first]){
dp[v][x.first] = mp[v][x.first];
mp[v][x.first] = mp[u][x.first]+1;
}
else dp[v][x.first] = max(dp[v][x.first],mp[u][x.first]+1);
}
ans = max(ans,mp[v][x.first]);
}
dfs2(v,u);
}
}
int main(){
init();
int n;
scanf("%d",&n);
ans = 0;
for(int i=1;i<=n;i++) {
scanf("%d",&a[i]);
solve(i);
if(a[i]>1) ans = 1;
}
for(int i=1;i<n;i++){
int u,v;
scanf("%d%d",&u,&v);
add(u,v);
}
dfs1(1,0),dfs2(1,0);
printf("%d
",ans);
return 0;
}