https://acm.hdu.edu.cn/showproblem.php?pid=6960
题意:
用3种颜色(红绿蓝)的珠子构成项链,旋转相同看作相同,绿色珠子使用不超过k个。用n颗珠子能组成多少种颜色的项链。
根据Burnside引理,旋转相同即有n种置换,设(H(i))表示旋转i个珠子的不动点个数,则 $$ans=frac{1}{n}sum_{i=0}^{n-1}H(i)$$
称旋转i个珠子的置换方式为置换 (i)
设(G(n,m))表示不考虑旋转相同时,在n颗珠子的圆环上有m颗绿色珠子的方案数,置换 (i) 会构成(gcd(i,n))个长为(frac{n}{gcd(i,n)})的循环,这里证明可以看 https://www.cnblogs.com/TheRoadToTheGold/p/15080694.html
每个循环颜色都是一样的,即相当于求在(gcd(i,n))颗珠子的圆环上,放(lfloor{frac{k*gcd(i,n)}{n}}
floor)颗绿色珠子,即$$H(i)=sum_{j=0}^{lfloor{frac{k*gcd(i,n)}{n}}
floor}G(gcd(i,n),j)$$
那么
令$$F(d)=sum_{j=0}^{lfloorfrac{k * d}{n}
floor}G(d,j)$$
则$$ans=frac{1}{n}sum_{d|n}Phi(lfloorfrac{n}{d}
floor)F(d)$$
如何计算(G(n,m))?
(G(n,m))是不考虑旋转相同时,在n颗珠子的圆环上有m颗绿色珠子的方案数。每两颗绿色珠子之间要么是红蓝红蓝排列,要么是蓝红蓝红排列。即若m颗绿色珠子形成了m段间隙,假设绿色珠子的排列方式有(D(m))种,则$$G(n,m)=D(m)*2^m$$
(D(m))即在n个珠子构成的圆环上选m个不相邻的珠子的方案数
根据组合数学插空法可以以计算。
我们先计算n个珠子排成一排,选m个不相邻的珠子的方案数,即有n-m个珠子,用m个珠子插空,方案数为(C(n-m+1,m))
现在是n个珠子构成一个圆环,如果定住第一颗珠子不选,方案数为(C(n-m,m))。如果定住第一颗珠子选,方案数为(C(n-m-1,m-1)),这个可以先把n-m个珠子构成一个环,把第一颗珠子插进去,这样就只剩下n-m-1个空隙,还要插m-1个珠子
所以$$D(m)=C(n-m,m)+C(n-m-1,m-1)$$
#include<bits/stdc++.h>
using namespace std;
#define N 1000001
const int mod=998244353;
int p[N],tot;
int phi[N];
bool vis[N];
int n,k;
int pw[N],fac[N],invf[N],inv[N];
void pre()
{
phi[1]=1;
for(int i=2;i<N;++i)
{
if(!vis[i])
{
p[++tot]=i;
phi[i]=i-1;
}
for(int j=1;j<=tot && i*p[j]<N;++j)
{
vis[i*p[j]]=true;
if(!(i%p[j]))
{
phi[i*p[j]]=phi[i]*p[j];
break;
}
else phi[i*p[j]]=phi[i]*(p[j]-1);
}
}
inv[1]=1;
for(int i=2;i<N;i++) inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
pw[0]=1;
fac[0]=1;
invf[0]=1;
for(int i=1;i<N;++i)
{
pw[i]=1ll*pw[i-1]*2%mod;
fac[i]=1ll*fac[i-1]*i%mod;
invf[i]=1ll*invf[i-1]*inv[i]%mod;
}
}
int getC(int d,int m)
{
if(d<m) return 0;
return 1ll*fac[d]*invf[m]%mod*invf[d-m]%mod;
}
int getG(int d,int m)
{
if(!m)
{
if(d&1) return 0;
return 2;
}
return 1ll*pw[m]*((getC(d-m,m)+getC(d-m-1,m-1))%mod)%mod;
}
int getF(int d)
{
int m=1ll*k*d/n,f=0;
for(int j=0;j<=m;++j)
f=(f+getG(d,j))%mod;
return f;
}
int main()
{
pre();
int T,ans;
scanf("%d",&T);
while(T--)
{
scanf("%d%d",&n,&k);
ans=0;
for(int i=1;i*i<=n;++i)
{
if(n%i) continue;
ans=(ans+1ll*phi[n/i]*getF(i)%mod)%mod;
if(n/i!=i) ans=(ans+1ll*phi[i]*getF(n/i)%mod)%mod;
}
ans=1ll*ans*inv[n]%mod;
printf("%d
",ans);
}
}