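题库链接（原文中的链接已缺失）
思路：莫比乌斯反演。设 $g(k)$ 为"路径上所有点权都是 $k$ 的倍数"且恰好含 $d$ 条边的路径数，则所有点权 $\gcd > 1$ 的路径数为 $\sum_{k \ge 2} -\mu(k)\, g(k)$（只有无平方因子的 $k$ 有贡献）。于是问题转化为：对每个 $k$，在点权为 $k$ 的倍数的点构成的子森林里统计长度为 $d$ 的路径数，这一步直接用长链剖分即可。
该题在计蒜客上非常卡常数，优化了很久才通过；比赛现场似乎还有人用点分治通过，实在出乎意料。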
#include<bits/stdc++.h>
#define LL long long
using namespace std;
constexpr int N = static_cast<int>(5e5) + 7;  // vertex capacity (vertices are indexed from 1)
constexpr int M = 30000;                      // largest value covered by the divisor tables

// Per-test input.
int n;            // number of vertices
int d;            // required path length, counted in edges
int a[N];         // a[v]: value stored on vertex v
int vis[N];       // vis[v] == k once v's component for divisor k has been walked
int miu[M + 1];   // Mobius function on 1..M

// State describing the divisor currently being processed (assigned in main).
int now_val;      // current divisor k: only vertices with a[v] % k == 0 take part
int now_op;       // sign added per counted path: -miu[k]
int now_cnt;      // vertices seen in the current component
int now_col;      // mark written into vis[] for the current divisor

// Long-chain decomposition buffers.
int len[N];       // len[v]: vertices on the longest downward path from v, v included
int son[N];       // son[v]: child with the largest len[] (0 when there is none)
int dp[N];        // storage pool carved into the rows f[v]
int *id;          // next unused slot of dp
int *f[N];        // f[v][k]: number of vertices at depth k below v

bool ok[N];       // ok[x]: x is square-free (filled for 1..M)
long long ans;    // signed path count accumulated for the current test

std::vector<int> V[M + 1];  // V[x]: square-free divisors of x
std::vector<int> P[M + 1];  // P[k]: vertices whose value has square-free divisor k

// Adjacency lists stored as chained edge records.
int edge_tot;     // number of edge records used so far
int head[N];      // head[u]: index of u's most recent edge, -1 if none
struct Edge {
    int to;       // endpoint of this edge
    int nex;      // next edge index in the same list, -1 ends the list
};
Edge e[N << 1];   // room for both directions of every tree edge
inline void addEdge(int u, int v) {
e[edge_tot].to = v;
e[edge_tot].nex = head[u];
head[u] = edge_tot++;
}
// Returns true when no k >= 2 has k * k dividing x, i.e. x is square-free.
// Any x < 4 (including 0 and 1) trivially passes.
bool checkNum(int x) {
    int k = 2;
    while (k * k <= x) {
        if (x % (k * k) == 0) {
            return false;
        }
        ++k;
    }
    return true;
}
// One-time precomputation over the values 1..M:
//   ok[x]  - whether x is square-free (checkNum),
//   V[x]   - the square-free divisors of x; for each divisor s <= sqrt(x) the
//            list gets s and then its cofactor x / s (when distinct),
//   miu[x] - the Mobius function, built from sum_{k | x} miu[k] == [x == 1].
void prepare() {
    // checkNum(1) is true, so this also marks 1 as square-free.
    for (int x = 1; x <= M; ++x) {
        if (checkNum(x)) ok[x] = true;
    }

    for (int x = 1; x <= M; ++x) {
        for (int small = 1; small * small <= x; ++small) {
            if (x % small != 0) continue;
            const int large = x / small;
            if (ok[small]) V[x].push_back(small);
            if (large != small && ok[large]) V[x].push_back(large);
        }
    }

    // Every proper divisor of x is handled before x, so miu[x] is final when
    // it is propagated to the multiples of x.
    miu[1] = 1;
    for (int x = 1; x <= M; ++x) {
        for (int mult = x + x; mult <= M; mult += x) {
            miu[mult] -= miu[x];
        }
    }
}
// Walks the component reachable from u (not going back through fa) in which
// every visited neighbour's value is divisible by now_val. The root's own value
// is not checked here; main only starts from vertices listed in P[now_val].
// For each vertex reached: marks vis[] with now_col, bumps now_cnt, and sets
//   len[u] - vertices on the longest downward path from u (u included),
//   son[u] - the first child (in adjacency order) achieving that maximum, or 0.
// NOTE(review): recursion depth equals the component's height (up to n).
void gao(int u, int fa) {
    vis[u] = now_col;
    ++now_cnt;
    int best = 0;
    int heavy = 0;
    for (int j = head[u]; j != -1; j = e[j].nex) {
        const int v = e[j].to;
        if (v == fa || a[v] % now_val != 0) continue;
        gao(v, u);
        if (len[v] > best) {   // strict '>' keeps the earliest child on ties
            best = len[v];
            heavy = v;
        }
    }
    son[u] = heavy;
    len[u] = best + 1;
}
// Long-chain (long-path decomposition) DP for the current divisor now_val,
// rooted at u with parent fa. Only neighbours whose value is divisible by
// now_val are visited.
//
// Preconditions (set up by the caller): gao() has filled len[]/son[] for this
// component, and f[u] points at a zeroed slice of dp with len[u] entries.
//
// Effect: for every path with exactly d edges whose topmost vertex is u, adds
// now_op to ans. Afterwards f[u][k] (for k < len[u] and k <= d) is the number
// of component vertices at depth k below u (u itself is depth 0).
void dfs(int u, int fa) {
    f[u][0] = 1;
    if(son[u]) {
        // The heavy child reuses u's row shifted by one slot, so its depth k is
        // u's depth k + 1 without any copying.
        f[son[u]] = f[u] + 1;
        dfs(son[u], u);
    }
    // Vertical paths from u down into the heavy subtree end at depth d.
    if(d < len[u]) {
        ans += now_op * f[u][d];
    }
    for(int j = head[u], v; ~j; j = e[j].nex) {
        v = e[j].to;
        if(v == fa || v == son[u] || a[v] % now_val) continue;
        // Each light child gets a fresh len[v]-sized row from the dp pool.
        f[v] = id; id += len[v];
        dfs(v, u);
        // Join depth i below u inside v (f[v][i - 1]) with depth d - i in the
        // part of u's subtree merged so far; i == d pairs with u itself via
        // f[u][0] == 1. 1LL keeps the product in 64-bit arithmetic.
        for(int i = 1; i <= len[v] && i <= d; i++) {
            if(d - i < len[u]) ans += 1LL * now_op * f[v][i - 1] * f[u][d - i];
        }
        // Merge v's depth counts into u's row; only depths up to d are kept.
        for(int i = 1; i <= len[v] && i <= d; i++) {
            f[u][i] += f[v][i - 1];
        }
    }
}
// Resets the per-test state before a new case is read: the answer, the edge
// storage, the divisor buckets P[1..M], and head[]/vis[] for vertices 1..n
// (n must already hold the new case's vertex count).
void init() {
    ans = 0;
    edge_tot = 0;
    for (int k = 1; k <= M; ++k) P[k].clear();
    for (int v = 1; v <= n; ++v) {
        head[v] = -1;   // empty adjacency chain
        vis[v] = 0;     // no divisor has visited v yet
    }
}
int main() {
prepare();
int cas = 0;
int T;
scanf("%d", &T);
while(T--) {
scanf("%d%d", &n, &d);
init();
for(int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
for(auto &t : V[a[i]]) {
P[t].push_back(i);
}
}
for(int i = 1; i < n; i++) {
int u, v;
scanf("%d%d", &u, &v);
addEdge(u, v);
addEdge(v, u);
}
for(int i = 2; i <= M; i++) {
if(!P[i].size()) continue;
now_val = i; now_op = -miu[i]; now_col = i;
for(auto &Rt : P[i]) {
if(vis[Rt] == i) continue;
now_cnt = 0;
gao(Rt, 0);
for(int i = 0; i <= now_cnt; i++) dp[i] = 0;
id = dp;
f[Rt] = id; id += len[Rt];
dfs(Rt, 0);
}
}
printf("Case #%d: %lld
", ++cas, 2 * ans);
}
return 0;
}
/*
*/