若最简分数 \(\frac{a}{b}\)(\(a\perp b\))在 \(k\) 进制下为纯循环小数,设 \(l\) 为其循环节长度,得:
\[\large \frac{a}{b}\left(k^l-1\right) \in\mathbb Z \Rightarrow k^l\equiv 1 \pmod{b} \Rightarrow k \perp b
\]
所求即为:
\[\large\begin{aligned}
&f\left(n,m,k\right)\\
=&\sum_{i=1}^n\sum_{j=1}^m\left[ i \perp j\right]\left[ j \perp k\right]\\
=&\sum_{i=1}^n\sum_{j=1}^m\left[ i \perp j\right]\sum_{d\mid j,\ d\mid k}\mu(d)\\
=&\sum_{d\mid k}\mu(d)\sum_{i=1}^n\sum_{j=1}^{\left\lfloor \frac{m}{d}\right\rfloor}\left[ i \perp jd\right]\\
=&\sum_{d\mid k}\mu(d)\sum_{j=1}^{\left\lfloor \frac{m}{d}\right\rfloor}\sum_{i=1}^n\left[ i \perp j\right]\left[ i \perp d\right]\\
=&\sum_{d\mid k}\mu(d)f\left(\left\lfloor \frac{m}{d}\right\rfloor,n,d\right)\\
\end{aligned}
\]
边界为:
\[\large\begin{aligned}
f\left(0,m,k\right)&=f\left(n,0,k\right)=0\\
f\left(n,m,1\right)&=\sum_{i=1}^n\sum_{j=1}^m\left[ i \perp j\right]=\sum_{d=1}^{\min(n,m)}\mu(d)\left\lfloor \frac{n}{d}\right\rfloor\left\lfloor \frac{m}{d}\right\rfloor
\end{aligned}
\]
递归计算即可,计算边界时用杜教筛求 \(\mu\) 的前缀和。复杂度为 \(O(n^{\frac{2}{3}}+\sqrt n \log^2 n)\)。
#include<bits/stdc++.h>
// Capacity of the sieve tables (p, mu, s, tag); must exceed the sieve limit `all`.
constexpr int maxn = 10000010;
// Capacity of the divisor adjacency list: head[] is indexed by values up to k,
// e[] holds one entry per (multiple, squarefree divisor) pair.
constexpr int maxm = 600010;
using namespace std;
using ll = long long;
// Reads one integer from stdin into x.
// Any non-digit characters before the number are skipped; a '-' among them
// makes the value negative. If input ends before any digit is seen, x is left
// as 0 instead of spinning forever on EOF.
template<typename T> inline void read(T &x)
{
    x=0;
    // Keep getchar()'s result as int: narrowing it to char makes EOF look like
    // an ordinary byte (so the skip loop never ends) and can hand isdigit a
    // negative value, which is undefined behaviour.
    int c=getchar();
    bool flag=false;
    while(c!=EOF&&!isdigit(c))
    {
        if(c=='-') flag=true;
        c=getchar();
    }
    while(c!=EOF&&isdigit(c))
    {
        x=x*10+(c-'0');
        c=getchar();
    }
    if(flag) x=-x;
}
// Problem input (n, m, k), number of primes found so far by the sieve, and the
// sieve limit (starts at 10^7; init() clamps it to the range actually needed).
ll n,m,k,tot,all=10000000;
// p: primes in increasing order; mu: Moebius function values.
// Both always fit in int (primes <= 10^7, mu in {-1,0,1}), so storing them as
// int instead of long long saves roughly 80 MB per table.
int p[maxn],mu[maxn];
// s[i] = mu(1) + ... + mu(i); kept 64-bit to match S()'s return type.
ll s[maxn];
// tag[i] is set once the sieve has marked i as composite.
bool tag[maxn];
// Memo of Mertens-function values for arguments above the sieve limit.
map<int,ll> sum;
// Node of the singly linked divisor lists: `to` is the stored divisor and
// `nxt` is the index of the next node in the same list (0 ends the list).
struct edge
{
    int to,nxt;
    edge(int a=0,int b=0):to(a),nxt(b){}
}e[maxm];
// head[v]: index of the most recently added edge out of v (0 = empty list).
// edge_cnt: number of edges stored so far; edge indices therefore start at 1.
int head[maxm],edge_cnt;
// Prepends an edge from -> to onto from's list.
void add(int from,int to)
{
    ++edge_cnt;
    e[edge_cnt]=edge(to,head[from]);
    head[from]=edge_cnt;
}
void init()
{
all=min(all,max(n,m)),mu[1]=1;
for(int i=2;i<=all;++i)
{
if(!tag[i]) p[++tot]=i,mu[i]=-1;
for(int j=1;j<=tot;++j)
{
int k=i*p[j];
if(k>all) break;
tag[k]=true;
if(i%p[j]) mu[k]=-mu[i];
else
{
mu[k]=0;
break;
}
}
}
for(int i=1;i<=all;++i) s[i]=s[i-1]+mu[i];
for(int i=1;i<=k;++i)
if(mu[i])
for(int j=i;j<=k;j+=i)
add(j,i);
}
// Mertens function M(n) = mu(1) + ... + mu(n).
// For n <= all the sieved prefix table s is used directly. Larger n use
// M(n) = 1 - sum_{l=2..n} M(floor(n/l)), grouping l by equal quotient, and the
// result is memoised in `sum`.
ll S(int n)
{
    if(n<=all) return s[n];
    // One lookup on a memo hit (the old count() + operator[] did two).
    const auto it=sum.find(n);
    if(it!=sum.end()) return it->second;
    ll v=1;
    // 64-bit loop bounds so r+1 cannot overflow when n is near INT_MAX.
    for(ll l=2,r;l<=n;l=r+1)
    {
        const ll q=n/l;
        r=n/q;
        v-=S(static_cast<int>(q))*(r-l+1);
    }
    sum[n]=v;
    return v;
}
// Number of pairs (i, j), 1 <= i <= n, 1 <= j <= m, with gcd(i, j) = 1 and
// gcd(j, k) = 1, evaluated via
//   f(n, m, k) = sum_{d | k} mu(d) * f(floor(m/d), n, d),
// with f(n, m, 1) = sum_{d=1..min(n,m)} mu(d) * floor(n/d) * floor(m/d).
// head[k] is expected to list the divisors d of k with mu(d) != 0 (built in init).
ll f(int n,int m,int k)
{
    if(n==0||m==0) return 0;
    ll total=0;
    if(k!=1)
    {
        for(int id=head[k];id!=0;id=e[id].nxt)
        {
            const int d=e[id].to;
            total+=mu[d]*f(m/d,n,d);
        }
        return total;
    }
    // k == 1: block over ranges of d sharing the same quotients.
    const int lo=std::min(n,m),hi=std::max(n,m);
    for(int l=1;l<=lo;)
    {
        const int qa=lo/l,qb=hi/l;
        const int r=std::min(lo/qa,hi/qb);
        total+=(S(r)-S(l-1))*qa*qb;
        l=r+1;
    }
    return total;
}
int main()
{
read(n),read(m),read(k),init();
printf("%lld",f(n,m,k));
return 0;
}