设 (f_i(s)) 表示 (s) 是否有长度为 (i) 的 ( ext{border}),其取值为 (0) 或 (1),不难得答案为:
[large Eleft( (f_1(s)+f_2(s)+cdots+f_{n-1}(s))^2
ight)=sum_{i=1}^{n-1}sum_{j=1}^{n-1}E(f_i(s)f_j(s))
]
若 (s) 有长度为 (i) 的 ( ext{border}),则其有长度为 (n-i) 的周期,因此将 (f_i(s)) 的定义改为周期,答案的式子不变。当 (s) 有周期 (i,j) 时,考虑将一定相同的位置连边,若得到 (cnt) 个连通块,则有:
[large E(f_i(s)f_j(s))=k^{cnt-n}
]
式子的含义就是连通块第一个点的字符随便选,其他点要和该连通块第一个点的字符相同。
当 (i+j leqslant n) 时,因为有周期 (i),形成了模 (i) 意义下的剩余系,于是只需考虑前 (i) 个点,因为前 (i) 个点都能向后连一条长度为 (j) 的边,因此得到 (gcd(i,j)) 个连通块。当 (i+j>n) 时,只有前 (n-j) 个点能向后连边,每连一条边都有可能使连通块个数减一,但当连边形成环时,连通块个数就不变,得连通块个数为 (i+j-n) 加上环的个数,不难得到环的个数为 (max(n-j-(i-gcd(i,j)),0))。
整理后代入答案的式子得:
[largesum_{i=1}^{n-1}sum_{j=1}^{n-1}k^{max(i+j-n,gcd(i,j))-n}
]
枚举 (i+j) 和 (gcd(i,j)) 得:
[largeegin{aligned}
&sum_{s=2}^{2n-2}sum_{g=1}^{n-1}sum_{i=max(1,s-n+1)}^{min(n-1,s-1)}[gcd(i,s-i)=g]\
=&sum_{s=2}^{2n-2}sum_{gmid s}sum_{i=max(1,frac{s}{g}-leftlfloor frac{n-1}{g}
ight
floor)}^{min(leftlfloorfrac{n-1}{g}
ight
floor,frac{s}{g}-1)}left[gcdleft(i,frac{s}{g}-i
ight)=1
ight]\
end{aligned}
]
设 (l=max(1,frac{s}{g}-leftlfloor frac{n-1}{g} ight floor),r=min(leftlfloorfrac{n-1}{g} ight floor,frac{s}{g}-1)),反演得:
[largeegin{aligned}
&sum_{s=2}^{2n-2}sum_{gmid s}sum_{i=l}^{r}left[gcdleft(i,frac{s}{g}-i
ight)=1
ight]\
=&sum_{s=2}^{2n-2}sum_{gmid s}sum_{d}mu(d)sum_{i=l}^{r}left[dmid i and dmidleft(frac{s}{g}-i
ight)
ight]\
=&sum_{s=2}^{2n-2}sum_{gmid s}sum_{dmid frac{s}{g}}mu(d)sum_{i=l}^{r}left[dmid i
ight]\
=&sum_{s=2}^{2n-2}sum_{gmid s}sum_{dmid frac{s}{g}}mu(d)left(leftlfloorfrac{r}{d}
ight
floor-leftlfloorfrac{l-1}{d}
ight
floor
ight)\
end{aligned}
]
注意到:
[largesum_{imid jmid k leqslant n}1=sum_{i leqslant n}sum_{j mid k leqslant leftlfloorfrac{n}{i}
ight
floor}1=sum_{i leqslant n}leftlfloorfrac{n}{i}
ight
floorlogleft(leftlfloorfrac{n}{i}
ight
floor
ight)
]
因此直接用推得的式子计算的复杂度为 (O(nlog^2n))。
#include<bits/stdc++.h>
#define maxn 200010
#define p 1000000007
using namespace std;
typedef long long ll;
template<typename T> inline void read(T &x)
{
x=0;char c=getchar();bool flag=false;
while(!isdigit(c)){if(c=='-')flag=true;c=getchar();}
while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
if(flag)x=-x;
}
int n,k,tot;
ll ans;
int pri[maxn];
ll mu[maxn],pw[maxn];
bool tag[maxn];
vector<int> ve[maxn];
ll inv(ll x)
{
ll v=1,y=p-2;
while(y)
{
if(y&1) v=v*x%p;
x=x*x%p,y>>=1;
}
return v;
}
void init(int n)
{
mu[1]=pw[0]=1;
for(int i=1;i<=n;++i) pw[i]=pw[i-1]*k%p;
for(int i=2;i<=n;++i)
{
if(!tag[i]) mu[pri[++tot]=i]=p-1;
for(int j=1;j<=tot;++j)
{
int k=i*pri[j];
if(k>n) break;
tag[k]=true;
if(i%pri[j]) mu[k]=p-mu[i];
else break;
}
}
for(int i=1;i<=n;++i)
for(int j=i;j<=n;j+=i)
ve[j].push_back(i);
}
ll calc(int lim,int sum)
{
if(lim<=0||sum<=1) return 0;
ll l=max(1,sum-lim),r=min(lim,sum-1),v=0;
if(l>r) return 0;
for(int i=0;i<ve[sum].size();++i)
{
int d=ve[sum][i];
v=(v+mu[d]*(r/d-(l-1)/d)%p)%p;
}
return v;
}
int main()
{
read(n),read(k),init(2*n);
for(int s=2;s<=2*n-2;++s)
{
for(int i=0;i<ve[s].size();++i)
{
int g=ve[s][i];
ans=(ans+calc((n-1)/g,s/g)*pw[max(s-n,g)]%p)%p;
}
}
printf("%lld",ans*inv(pw[n])%p);
return 0;
}