【分析】
先套路地莫比乌斯反演一波:
思路见 对于部分莫比乌斯反演的套路优化
设 (Nleq M)
(displaystyle quadsum_{i=1}^Nsum_{j=1}^M oldsymbol mu^2(gcd(i, j)))
(displaystyle =sum_{i=1}^Nsum_{j=1}^M sum_{dmid iwedge dmid j}(oldsymbol mu^2*oldsymbol mu)(d))
(displaystyle =sum_{d=1}^N (oldsymbol mu^2*oldsymbol mu)(d)(N/d)(M/d))
然后就求解 ((oldsymbol mu^2*oldsymbol mu)(d)) 的前缀和,跑一个整除分块即可
稍微分析一下积性函数 ((oldsymbol mu^2*oldsymbol mu)(d)) 的性质:
(displaystyle (oldsymbol mu^2*oldsymbol mu)(p^k)=sum_{i=0}^k oldsymbol mu^2(p^i)oldsymbol mu(p^{k-i}))
当 (k=1) 时:((oldsymbol mu^2*oldsymbol mu)(p)=oldsymbol mu^2(1)oldsymbol mu(p)+oldsymbol mu^2(p)oldsymbol mu(1)=-1+1=0)
当 (k=2) 时:((oldsymbol mu^2*oldsymbol mu)(p^2)=oldsymbol mu^2(p)oldsymbol mu(p)=-1)
当 (kgeq 3) 时 ((oldsymbol mu^2*oldsymbol mu)(p^3)=0)
即 (displaystyle (oldsymbol mu^2*oldsymbol mu)(p^k)=-[k=2])
然后我就一通分析,推了一个非常复杂的递推式,成功 MLE+TLE
仔细分析一波,若某个数 (n) 是完全平方数,如果 (oldsymbol mu(sqrt n) eq 0) 则 (sqrt n) 中的每个质因子只出线一次
因此,(n) 中的每个质因子只出线两次,此时 (oldsymbol mu(n) eq 0),否则(包括 (n) 不为完全平方数)均为 (0)
同理进一步分析,可以得到 ((oldsymbol mu^2*oldsymbol mu)(n)=egin{cases}oldsymbol mu(sqrt n), n ext{是完全平方数}\\0, n ext{不是完全平方数}end{cases})
所以我们将 (4 imes 10^6) 内的 (mu(n)) 前缀和打出来,然后把它们映射到 ((oldsymbol mu^2*oldsymbol mu)(n^2)) 的前缀和上
每次查询前缀和的时候直接在映射中二分查找不大于自己的第一个 (n) 对应的函数值即可
【代码】
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<ll, ll> pii;
typedef double db;
#define fi first
#define se second
const int MOD=998244353, Lim=4e6, MAXN=Lim+10;
ll n, m;
ll p[MAXN], cntprime, Mu[MAXN];
bool vis[MAXN];
vector<pii> sumF;
inline void sieve(){
Mu[1]=1;
for(int i=2;i<=Lim;++i){
if(!vis[i]) vis[i]=1, p[++cntprime]=i, Mu[i]=-1;
for(int j=1;j<=cntprime;++j)
if(p[j]*i>Lim) break;
else{
vis[p[j]*i]=1;
if(i%p[j]==0){
Mu[i*p[j]]=0;
break;
}
Mu[i*p[j]]=-Mu[i];
}
}
for(int i=2;i<=Lim;++i) Mu[i]+=Mu[i-1];
sumF.push_back( pii(-1e18, 0) );
for(ll i=1;i<=Lim;++i) sumF.push_back( pii(i*i, Mu[i]) );
}
inline ll sumf(ll n){
return ( upper_bound(sumF.begin(), sumF.end(), pii(n, 1e18))-1 )->se;
}
int main(){
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
sieve();
cin>>n>>m; if(n>m) swap(n, m);
ll ans=0;
for(ll l=1, r;l<=n; l=r+1){
r=min( n/(n/l), m/(m/l) );
ans+=(sumf(r)-sumf(l-1))%MOD*(n/l%MOD)%MOD*(m/l%MOD)%MOD;
ans%=MOD;
}
ans=(ans+MOD)%MOD;
cout<<ans;
cout.flush();
return 0;
}