由polya定理$$Ans=frac{sum_{d|n}f(d)*phi(frac{n}{d})}{n}$$
$f(i)$表示不考虑旋转同构下长度为$i$的环的合法方案数。
$g(i)$表示第$i$位为男的链的方案数。
$h(i)$表示第$1$位和第$i$位都是男的链的方案数。
对于$g$数组可以$O(n)$算出来。
显然$h(i)=g(i-1)$。对于$i<=m,f(i)=2^i$,否则
$f(i)=sum_{j=0}^m (j+1)*h(i-j)$。写出$f(i+1)$的式子,两式做差可以得到$f$数组的递推公式。
$f(i+1)=f(i)+sum_{j=-1}^{m-1}h(i-j)-(m+1)*h(i-m)$。
当循环节$ileq m$的时候,该循环节全都是女会导致$f(i*j)$可能不会包含$f(i)$中的这个循环节。
此时把下标大于$m$的$f$都$+1$,最后答案$-1$
1 #include <bits/stdc++.h> 2 using namespace std; 3 #define MOD 100000007 4 int phi[1010]; 5 inline int Power(int x, int y) { 6 int ret = 1; 7 while(y) { 8 if(y & 1) ret = 1ll * ret * x % MOD; 9 x = 1ll * x * x % MOD; y >>= 1; 10 } 11 return ret; 12 } 13 inline int get_phi(int x) { 14 int ret = 1; 15 for(int i = 2; i * i <= x; ++ i) { 16 if(x % i == 0) { 17 ret *= (i - 1); x /= i; 18 while(x % i == 0) { 19 x /= i; 20 ret *= i; 21 } 22 } 23 } 24 if(x != 1) ret *= (x - 1); 25 return ret; 26 } 27 inline void init() { 28 for(int i = 1; i <= 1000; ++ i) { 29 phi[i] = get_phi(i); 30 } 31 } 32 int f[1010], g[1010], h[1010]; 33 int sum_g[1010], sum_h[1010]; 34 inline void solve() { 35 int n, m; 36 scanf("%d%d", &n, &m); 37 g[0] = 1, sum_g[0] = 1; 38 for(int i = 1; i <= n; ++ i) { 39 g[i] = sum_g[i - 1] - (i - m - 1 > 0 ? sum_g[i - m - 2] : 0); 40 if(g[i] < 0) g[i] += MOD; 41 sum_g[i] = sum_g[i - 1] + g[i]; 42 if(sum_g[i] >= MOD) sum_g[i] -= MOD; 43 } 44 h[0] = 1; sum_h[0] = 1; 45 for(int i = 1; i <= n; ++ i) { 46 h[i] = g[i - 1]; 47 sum_h[i] = sum_h[i - 1] + h[i]; 48 if(sum_h[i] >= MOD) sum_h[i] -= MOD; 49 } 50 f[0] = 1; 51 for(int i = 1; i <= n; ++ i) { 52 if(i <= m) f[i] = f[i - 1] * 2 % MOD; 53 else if(i == m + 1) { 54 f[i] = 2 * f[i - 1] - 1; 55 f[i] %= MOD; 56 } 57 if(i >= m + 1) { 58 f[i + 1] = f[i] + sum_h[i + 1] - sum_h[i - m] - 1ll * (m + 1) * h[i - m] % MOD + MOD + MOD; 59 f[i + 1] %= MOD; 60 } 61 } 62 for(int i = m + 1; i <= n; ++ i) { 63 f[i] ++; 64 if(f[i] >= MOD) f[i] -= MOD; 65 } 66 int ans = 0; 67 for(int i = 1; i * i <= n; ++ i) if(n % i == 0) { 68 ans += 1ll * f[i] * phi[n / i] % MOD; 69 if(ans >= MOD) ans -= MOD; 70 if(i != n / i) { 71 ans += 1ll * f[n / i] * phi[i] % MOD; 72 if(ans >= MOD) ans -= MOD; 73 } 74 } 75 printf("%d ", 1ll * ans * Power(n, MOD - 2) % MOD - (m >= n ? 0 : 1)); 76 } 77 int main() { 78 init(); 79 int T; 80 scanf("%d", &T); 81 while(T --) { 82 solve(); 83 } 84 }