【BZOJ3994】约数个数和(莫比乌斯反演)
题面
求$$sum_{i=1}^nsum_{j=1}^md(ij)$$
多组数据((<=50000组))
(n,m<=50000)
其中(d(x))是(x)的约数个数
题解
orz ZSY 巨佬
根据玄学(我也不知道为什么)的公式
[d(ij)=sum_{x|i}sum_{y|j}[gcd(x,y)==1]
]
所以,所求等于
[sum_{i=1}^nsum_{j=1}^msum_{u|i}sum_{v|j}[gcd(u,v)==1]
]
把枚举因数丢到前面去
[sum_{u=1}^nsum_{v=1}^m[frac{n}{u}][frac{m}{v}][gcd(u,v)==1]
]
(u,v)看起来很不爽
[sum_{i=1}^nsum_{j=1}^m[frac{n}{i}][frac{m}{j}][gcd(i,j)==1]
]
看起来可以莫比乌斯反演一波了
设
[f(x)=sum_{i=1}^nsum_{j=1}^m[frac{n}{i}][frac{m}{j}][gcd(i,j)==x]
]
[g(x)=sum_{x|d}f(d)
]
所以
[g(x)=sum_{i=1}^nsum_{j=1}^m[frac{n}{i}][frac{m}{j}][x|gcd(i,j)]
]
把(x)提出去,忽略(gcd)的影响
[g(x)=sum_{i=1}^{frac{n}{x}}sum_{j=1}^{frac{m}{x}}[frac{n}{ix}][frac{m}{jx}]
]
预处理出(sum_{i=1}^n[frac{n}{i}])的值(g(x))就可以(O(1))算
预处理的方式,请参考一道水题约数研究
你就会知道这个玩意的值就是每个数约数个数的前缀和
因为一个数的约数个数是积性函数,可以线性筛
所以这个可以(O(n))预处理
接下来的东西就比较好算了
所求就是(f(1))
[f(1)=sum_{d=1}^nmu(d)g(d)
]
把(g(i))展开
[f(1)=sum_{d=1}^nmu(d)sum_{i=1}^{frac{n}{d}}sum_{j=1}^{frac{m}{d}}[frac{n}{i}][frac{m}{j}]
]
很明显可以数论分块
所以再预处理一下(mu(i))的前缀和就行了
单词询问的复杂度就是(O(sqrt n))
总体复杂度(O(Tsqrt n))
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<set>
#include<map>
#include<vector>
#include<queue>
using namespace std;
#define MAX 50000
inline int read()
{
int x=0,t=1;char ch=getchar();
while((ch<'0'||ch>'9')&&ch!='-')ch=getchar();
if(ch=='-')t=-1,ch=getchar();
while(ch<='9'&&ch>='0')x=x*10+ch-48,ch=getchar();
return x*t;
}
int n,m;
bool zs[MAX+1000];
int pri[MAX+1000],tot,mu[MAX+1000],ys[MAX+1000],dd[MAX+1000];
int smu[MAX+1000],sd[MAX+1000];
void pre()
{
zs[1]=true;mu[1]=ys[1]=1;
for(int i=2;i<=MAX;++i)
{
if(!zs[i])pri[++tot]=i,mu[i]=-1,ys[i]=2,dd[i]=1;
for(int j=1;j<=tot&&i*pri[j]<=MAX;++j)
{
zs[i*pri[j]]=true;
if(i%pri[j]==0)
{
mu[i*pri[j]]=0;
ys[i*pri[j]]=ys[i]/(dd[i]+1)*(dd[i]+2);
dd[i*pri[j]]=dd[i]+1;
break;
}
else mu[i*pri[j]]=-mu[i],ys[i*pri[j]]=ys[i]*2,dd[i*pri[j]]=1;
}
}
for(int i=1;i<=MAX;++i)smu[i]=smu[i-1]+mu[i];
for(int i=1;i<=MAX;++i)sd[i]=sd[i-1]+ys[i];
}
int main()
{
pre();
int T=read();
while(T--)
{
n=read();m=read();
if(n>m)swap(n,m);
long long ans=0;
int i=1,j;
while(i<=n)
{
j=min(n/(n/i),m/(m/i));
ans+=1ll*(smu[j]-smu[i-1])*sd[n/i]*sd[m/i];
i=j+1;
}
printf("%lld
",ans);
}
return 0;
}