正题
题目链接:http://www.ybtoj.com.cn/contest/122/problem/3
题目大意
\(S(i)\)表示\(i\)的约数个数,\(Q\)次询问给出\(n,m\)求
\[\sum_{a=1}^n\sum_{b=1}^mS(a^2)\times S(b^2)\times S(a\times b)
\]
\(1\leq Q\leq 10^4,1\leq n,m\leq 2\times 10^5\)
解题思路
前面的推式子挺套路的
首先我们要搞定\(S(n^2)\)这个东西,一个经典的结论就是\(S(n\times m)=\sum_{i|n}\sum_{j|m}[gcd(i,j)=1]\)。莫反一下就有
\[S(a\times b)=\sum_{d|(a\times b)}\mu(d)\sum_{i\times d|a}\sum_{j\times d|b}1
\]
所以就有
\[S(n^2)=\sum_{d|n}\mu(d)S(\frac{n}{d})^2
\]
用线性筛筛出前面的\(S\),然后\(O(n\log n)\)求出\(h(n)=S(n^2)\)
然后化一下式子
\[\sum_{a=1}^n\sum_{b=1}^mh(a)\times h(b)\sum_{i|a}\sum_{j|b}[gcd(i,j)=1]
\]
\[\sum_{d=1}\mu(d)(\sum_{d|i}\sum_{i|a}h(a))(\sum_{d|j}\sum_{j|b}h(b))
\]
\[\sum_{d=1}\mu(d)(\sum_{d|a}S(\frac{a}{d})h(a))(\sum_{d|b}S(\frac{b}{d})h(b))
\]
然后就好像没得化简了,先处理出\(F(d,n)=\sum_{i=1}^nh(i\times d)S(i)\)
发现\(d\)很大的时候后面那个东西的取值就很小,但是\(d\)很多,需要快速处理。
设定一个分界值\(T\),每次小于\(T\)的部分我们就暴力用\(F\)数组计算,大于\(T\)的部分我们预处理出一个
\[G(d,i,j)=\sum_{x=T+1}^dF(i)F(j)\mu(d)
\]
然后整除分块计算。
这里的\(k\)取\(N^{\frac{2}{3}}\)会平均一些,时间复杂度\(O(n^{\frac{4}{3}}+Qn^{\frac{2}{3}})\)
code
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<vector>
#define ll long long
using namespace std;
const ll N=2e5+10,P=1<<30;
ll q,n,m,cnt,pri[N],mu[N],S[N],sg[N],g[N],o[N];
vector<int>f[N],d[N];
bool v[N];
void prime(){
mu[1]=sg[1]=1;
for(ll i=2;i<N;i++){
if(!v[i])pri[++cnt]=i,mu[i]=-1,g[i]=2,sg[i]=2;
for(ll j=1;j<=cnt&&i*pri[j]<N;j++){
v[i*pri[j]]=1;
if(i%pri[j]==0){
g[i*pri[j]]=g[i]+1;
sg[i*pri[j]]=sg[i]/g[i]*g[i*pri[j]];
break;
}
mu[i*pri[j]]=-mu[i];g[i*pri[j]]=2;
sg[i*pri[j]]=sg[i]*sg[pri[j]];
}
}
for(ll i=1;i<N;i++)
for(ll j=i;j<N;j+=i)
(S[j]+=sg[j/i]*sg[j/i]*mu[i]%P)%=P;
return;
}
signed main()
{
freopen("math.in","r",stdin);
freopen("math.out","w",stdout);
prime();
scanf("%lld",&q);ll lim=2e5;
ll T=(ll)pow(lim,2.0/3.0)+1;
f[0].resize(lim+1);
for(ll i=1;i<=lim;i++){
f[i].push_back(0);
for(ll j=1;j<=lim/i;j++){
ll tmp=f[i][j-1];
f[i].push_back((tmp+S[i*j]*sg[j])%P);
}
}
d[T].resize((lim/T)*(lim/T)+1);
for(ll i=T+1;i<=lim;i++){
ll p=lim/i;
d[i].resize(p*p+1);
for(ll j=1,sum=0;j<=lim/i;j++)
for(ll k=j;k<=lim/i;k++)
d[i][(j-1)*p+k]=(d[i-1][(j-1)*o[i-1]+k]+f[i][j]*f[i][k]*mu[i])%P;
o[i]=p;
}
while(q--){
scanf("%lld%lld",&n,&m);
if(n>m)swap(n,m);ll ans=0;
for(ll i=1;i<=min(T,n);i++)
(ans+=1ll*f[i][n/i]*f[i][m/i]*mu[i]%P)%=P;
for(ll l=T+1,r;l<=n;l=r+1){
r=min(n/(n/l),m/(m/l));
(ans+=d[r][(n/l-1)*o[r]+m/l]-d[l-1][(n/l-1)*o[l-1]+m/l])%=P;
}
printf("%lld\n",(ans+P)%P);
}
return 0;
}