树链剖分做法：先预处理出轻重链。修改某个点时，只需更新与它相邻、且位于同一条重链上的边（重边，即原图中的红色边）的 gcd；轻边（原图中的黑色边）不需要维护，查询经过时直接现算其 gcd 即可。
也就是说，修改点 u 时只需维护两条重边：u 与父亲之间的边（当 u 不是链顶时）以及 u 与其重儿子之间的边。
参考代码（树链剖分）：
// Solution 1: heavy-light decomposition.
//   op == 1      : read u k, set a[u] = k
//   any other op : read u v k, print the number of edges (x, y) on the u-v path
//                  with gcd(a[x], a[y]) <= k
// Only edges inside a heavy chain have their gcd cached (in ans[], indexed by the
// dfs order of the lower endpoint). The edge from a chain top to its parent is
// recomputed whenever a query crosses it, so a point update only refreshes the
// edge to u's parent (when u is not a chain top) and the edge to u's heavy child.
// Each query visits every edge of the path, i.e. O(path length) time.
#include <iostream>
#include <cstdio>
#include <cctype>
#include <algorithm>
#include <unordered_map>
#include <vector>
#include <map>
#include <list>
#include <queue>
#include <cstring>
#include <cstdlib>
#include <ctime>
#include <cmath>
#include <stack>
#pragma GCC optimize(3 , "Ofast" , "inline")
using namespace std ;
typedef long long ll ;
const double esp = 1e-6 , pi = acos(-1) ;
typedef pair<int , int> PII ;
const int N = 2e4 + 10 , INF = 0x3f3f3f3f , mod = 1e9 + 7;

// Reads one (optionally negative) decimal integer from stdin.
// Stops at EOF instead of spinning forever; yields 0 if no digit is found.
int in()
{
    int x = 0 , f = 1 ;
    int ch = getchar() ;  // int, not char: keeps EOF distinguishable and isdigit() well-defined
    while(ch != EOF && !isdigit(ch))
    {
        if(ch == '-') f = -1 ;
        ch = getchar() ;
    }
    while(isdigit(ch))
    {
        x = x * 10 + (ch - '0') ;
        ch = getchar() ;
    }
    return x * f ;
}

// dfn: dfs order; son: heavy child (0 = none); fa: parent; deep: depth (root = 1);
// tp: top of the node's heavy chain; sz: subtree size;
// ans[dfn[v]]: cached gcd(a[v], a[fa[v]]) for every v that is not a chain top.
int dfn[N] , son[N] , fa[N] , deep[N] , tp[N] , sz[N] , ans[N] ;
vector<int> g[N] ;   // adjacency lists
int cnt , a[N] ;     // cnt: dfs-order counter; a: node values

// First pass: parent, depth, subtree size and heavy child of every node below u.
void dfs1(int u , int f)
{
    fa[u] = f , sz[u] = 1 , deep[u] = deep[f] + 1 ;
    for(auto v : g[u])
    {
        if(v == f) continue ;
        dfs1(v , u) ;
        sz[u] += sz[v] ;
        if(sz[son[u]] < sz[v]) son[u] = v ;  // son[u] starts at 0 and sz[0] == 0
    }
}

// Second pass: assigns dfs order visiting the heavy child first, so each heavy
// chain occupies a contiguous dfn range starting at its top, and caches the gcd
// of every heavy edge at the dfn of its lower endpoint.
void dfs2(int u , int top)
{
    tp[u] = top ;
    dfn[u] = ++ cnt ;
    if(u != top) ans[dfn[u]] = __gcd(a[u] , a[fa[u]]) ;
    if(son[u]) dfs2(son[u] , top) ;
    for(auto v : g[u])
        if(v != fa[u] && v != son[u]) dfs2(v , v) ;  // a light child starts a new chain
}

// Counts the edges on the u-v path whose endpoint gcd is <= k.
int solve(int u , int v , int k)
{
    int res = 0 ;
    // Lift the endpoint whose chain top is deeper until both lie on one chain.
    while(tp[u] != tp[v])
    {
        if(deep[tp[u]] < deep[tp[v]]) swap(u , v) ;
        // Heavy edges from tp[u] down to u occupy dfn range (dfn[tp[u]], dfn[u]].
        for(int i = dfn[tp[u]] + 1 ; i <= dfn[u] ; i ++ )
            res += ans[i] <= k ;
        // The edge from the chain top to its parent is not cached: compute it now.
        res += __gcd(a[tp[u]] , a[fa[tp[u]]]) <= k ;
        u = fa[tp[u]] ;
    }
    // Same chain: count the heavy edges between the shallower and the deeper node.
    if(deep[u] < deep[v]) swap(u , v) ;
    for(int i = dfn[v] + 1 ; i <= dfn[u] ; i ++ )
        res += ans[i] <= k ;
    return res ;
}

int main()
{
    int n = in() , q = in() ;
    for(int i = 1 ; i <= n ; i ++ ) a[i] = in() ;
    for(int i = 1 ; i < n ; i ++ )
    {
        int u = in() , v = in() ;
        g[u].push_back(v) , g[v].push_back(u) ;
    }
    dfs1(1 , 0) , dfs2(1 , 1) ;
    while(q --)
    {
        int op = in() ;
        if(op == 1)
        {
            int u = in() , k = in() ;
            a[u] = k ;
            // Refresh only the cached heavy edges that touch u.
            if(tp[u] != u) ans[dfn[u]] = __gcd(a[u] , a[fa[u]]) ;    // edge to parent is on u's chain
            if(son[u]) ans[dfn[son[u]]] = __gcd(a[u] , a[son[u]]) ;  // edge to heavy child
        }
        else
        {
            int u = in() , v = in() , k = in() ;
            cout << solve(u , v , k) << '\n' ;  // '\n' rather than endl: no flush per query
        }
    }
    return 0 ;
}
还有一种按度数分治的暴力做法：设定一个度数阈值（代码中为 150）。修改时，度数超过阈值的点不维护其相邻边；度数不超过阈值的点则暴力更新它与所有小度数邻点之间的边的 gcd。查询时沿路径向上走到 LCA，遇到含大度数端点的边就现算 gcd，其余边直接使用维护好的值。
参考代码（按度数分治）：
// Solution 2: degree threshold.
//   op == 1      : read u x, set a[u] = x
//   any other op : read u v k, print the number of edges (x, y) on the u-v path
//                  with gcd(a[x], a[y]) <= k
// A node with more than m neighbours is "big". fe[v] caches the gcd of the edge
// (v, parent of v); the cache is kept up to date only for edges whose endpoints
// are both small, and edges with a big endpoint are recomputed during queries.
// An update therefore refreshes at most m edges; a query walks the path to the LCA.
#include <iostream>
#include <cstdio>
#include <cctype>
#include <algorithm>
#include <unordered_map>
#include <vector>
#include <map>
#include <list>
#include <queue>
#include <cstring>
#include <cstdlib>
#include <ctime>
#include <cmath>
#include <stack>
#pragma GCC optimize(3 , "Ofast" , "inline")
using namespace std ;
typedef long long ll ;
const double esp = 1e-6 , pi = acos(-1) ;
typedef pair<int , int> PII ;
const int N = 1e5 + 10 , INF = 0x3f3f3f3f , mod = 1e9 + 7;

// Reads one (optionally negative) decimal integer from stdin.
// Stops at EOF instead of spinning forever; yields 0 if no digit is found.
int in()
{
    int x = 0 , f = 1 ;
    int ch = getchar() ;  // int, not char: keeps EOF distinguishable and isdigit() well-defined
    while(ch != EOF && !isdigit(ch))
    {
        if(ch == '-') f = -1 ;
        ch = getchar() ;
    }
    while(isdigit(ch))
    {
        x = x * 10 + (ch - '0') ;
        ch = getchar() ;
    }
    return x * f ;
}

// Linked-array adjacency lists. Each undirected edge is stored in both
// directions, so e[] / ne[] need room for 2 * (n - 1) entries, not n.
int e[2 * N] , ne[2 * N] , h[N] , idx ;
int a[N] , n , q ;
// fa[u][i]: 2^i-th ancestor of u (0 above the root); deep: depth (root = 1);
// fe[v]: cached gcd(a[v], a[fa[v][0]]).
int fa[N][25] , deep[N] , fe[N] ;

// Appends the directed edge x -> y to x's list.
void add(int x , int y)
{
    e[idx] = y , ne[idx] = h[x] , h[x] = idx ++ ;
}

// Fills depth, the binary-lifting table and the initial edge gcds below u.
void dfs(int u , int f)
{
    deep[u] = deep[f] + 1 ;
    fa[u][0] = f ;
    for(int i = 1 ; i <= 20 ; i ++ ) fa[u][i] = fa[fa[u][i - 1]][i - 1] ;
    for(int i = h[u] ; ~i ; i = ne[i])
    {
        int v = e[i] ;
        if(v == f) continue ;
        fe[v] = __gcd(a[u] , a[v]) ;
        dfs(v , u) ;
    }
}

// Lowest common ancestor by binary lifting (node 0, depth 0, sits above the root).
int lca(int x , int y)
{
    if(deep[x] < deep[y]) swap(x , y) ;
    for(int i = 20 ; i >= 0 ; i -- )
        if(deep[fa[x][i]] >= deep[y]) x = fa[x][i] ;
    if(x == y) return x ;
    for(int i = 20 ; i >= 0 ; i -- )
        if(fa[x][i] != fa[y][i]) x = fa[x][i] , y = fa[y][i] ;
    return fa[x][0] ;
}

int out[N] , m = 150 ;  // out: degree of each node; m: degree above which a node is "big"

// Counts qualifying edges on the path from x up to its ancestor top (exclusive of top's parent edge).
int climb(int x , int top , int k)
{
    int hits = 0 ;
    while(x != top)
    {
        int p = fa[x][0] ;
        // fe[] is stale for edges with a big endpoint, so recompute those directly.
        int val = (out[x] > m || out[p] > m) ? __gcd(a[x] , a[p]) : fe[x] ;
        hits += val <= k ;
        x = p ;
    }
    return hits ;
}

// Counts the edges on the u-v path whose endpoint gcd is <= k.
int solve(int u , int v , int k)
{
    int pos = lca(u , v) ;
    return climb(u , pos , k) + climb(v , pos , k) ;
}

int main()
{
    memset(h , -1 , sizeof h) ;
    n = in() , q = in() ;
    for(int i = 1 ; i <= n ; i ++ ) a[i] = in() ;
    for(int i = 1 ; i < n ; i ++ )
    {
        int u = in() , v = in() ;
        add(u , v) , add(v , u) ;
        out[u] ++ , out[v] ++ ;
    }
    dfs(1 , 0) ;
    while(q --)
    {
        int op = in() ;
        if(op == 1)
        {
            int u = in() , x = in() ;
            a[u] = x ;
            if(out[u] > m) continue ;  // big node: its edges are always recomputed in queries
            for(int i = h[u] ; ~i ; i = ne[i])
            {
                int v = e[i] ;
                if(out[v] > m) continue ;  // edge with a big endpoint is not cached
                if(fa[u][0] == v) fe[u] = __gcd(a[u] , a[v]) ;  // v is u's parent
                else fe[v] = __gcd(a[u] , a[v]) ;              // v is a child of u
            }
        }
        else
        {
            int u = in() , v = in() , k = in() ;
            cout << solve(u , v , k) << '\n' ;  // '\n' rather than endl: no flush per query
        }
    }
    return 0 ;
}