Solution
可以发现对于每个因子,如果以它作为每块的大小合法,最多也只有一种方案。
如何判断每个因子是否合法??用到了一种非常巧妙的方法,统计每个节点子树的$siz$,如果当前节点$siz$是枚举的因子的倍数,那么意味着在这个节点到它父亲节点必须要割一刀分出联通块。
那么统计出割多少刀就是有多少个联通块。又知道联通块数就是$n/x$,x是我们枚举的块的大小。所以判断是否相同即可。
复杂度就是$n*因子个数$。
Code
#include<bits/stdc++.h> using namespace std; struct Node { int v, nex; } Edge[2000005]; int h[1000005], stot; void add(int u, int v) { Edge[++stot] = (Node) {v, h[u]}; h[u] = stot; } int siz[1000005]; void dfs(int u, int f) { siz[u] = 1; for(int i = h[u]; i; i = Edge[i].nex) { int v = Edge[i].v; if(v == f) continue; dfs(v, u); siz[u] += siz[v]; } } int main() { freopen("count.in", "r", stdin); freopen("count.out", "w", stdout); int n; scanf("%d", &n); for(int i = 1; i < n; i ++) { int u, v; scanf("%d%d", &u, &v); add(u, v); add(v, u); } dfs(1, 0); int ans = 0; for(int i = 1; i <= n; i ++) { if(n % i == 0) { int now = 0; for(int j = 1; j <= n; j ++) if(siz[j] % i == 0) now ++; if(now == n / i) ans ++; } } printf("%d", ans); return 0; }
Solution
暴力就是二分+$n^2$check以每个节点开始能不能跳过整个区间。
二分正确性显然,如何优化check?
可以想到开车旅行,从每个点出发快速跳$mid$步能跳到哪显然可以用倍增预处理出来!
然而这样还是$nmlog^2$的,所以考虑把$m$给优化下去。
再处理一个倍增数组,表示从某个点开始跳$j$个$mid$次能跳到哪,即跳多少轮。这个显然是每次枚举了mid后进入check时处理出来的。然后将$m$二进制拆位,跳即可。
成功优化到$nlog^2$叻!!
【注意!】跳倍增要先减再跳第二次错了555555!!!
Code
#include<bits/stdc++.h> using namespace std; int n, m; int jum[200005][21], val[200005][21], to[200005][21]; int t[200005], mi; void init() { for(int i = 1; i <= 2 * n; i ++) jum[i][0] = i + 1, val[i][0] = t[i]; for(int p = 0; p <= 20; p ++) jum[2 * n + 1][p] = 2 * n + 1; for(int p = 1; p <= 20; p ++) for(int i = 1; i <= 2 * n; i ++) jum[i][p] = jum[jum[i][p - 1]][p - 1], val[i][p] = val[i][p - 1] + val[jum[i][p - 1]][p - 1]; } void jump(int mid) { for(int j = 1; j <= 2 * n; j ++) { int QAQ = mid, u = j; for(int i = 20; i >= 0; i --) if(QAQ >= val[u][i]) QAQ -= val[u][i], u = jum[u][i]; ///////先减再跳!!!! to[j][0] = u; } for(int i = 0; i <= 20; i ++) to[2 * n + 1][i] = 2 * n + 1; for(int p = 1; p <= 20; p ++) for(int i = 1; i <= 2 * n; i ++) to[i][p] = to[to[i][p - 1]][p - 1]; } bool check(int mid) { jump(mid); for(int i = 1; i <= n; i ++) { int u = i, M = m, t = 0; while(M) { if(M & 1) u = to[u][t]; M >>= 1, t ++; } if(u >= i + n) return 1; } return 0; } int erfen() { int l = mi, r = 50000000, ans; while(l <= r) { int mid = (l + r) >> 1; if(check(mid)) ans = mid, r = mid - 1; else l = mid + 1; } return ans; } int main() { freopen("dinner.in", "r", stdin); freopen("dinner.out", "w", stdout); scanf("%d%d", &n, &m); for(int i = 1; i <= n; i ++) scanf("%d", &t[i]), t[i + n] = t[i], mi = max(mi, t[i]); init(); int ans = erfen(); printf("%d", ans); return 0; }