// The full derivation of this solution is given in the paper.
// Computes  sum_{i=1..n} sum_{j=1..m} lcm(i, j)  modulo 20101009.
//
// With S(x) = x(x+1)/2 and mu the Moebius function, the answer is
//   sum_{d=1..min(n,m)} d * F(n/d, m/d),
//   F(a, b) = sum_{k=1..min(a,b)} mu(k) * k^2 * S(a/k) * S(b/k),
// where all divisions are floor divisions. Both sums are evaluated over
// blocks [l, r] on which the floor quotients are constant.
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <iostream>
#include <utility>
#include <vector>

using ll = long long;

// Modulus fixed by the problem. It is not prime, so every "/ 2" below is
// applied to the exact product before reducing.
constexpr ll mod = 20101009;

// Largest index the former fixed-size tables covered; used only as the
// default bound so a bare primes() call keeps its old meaning.
constexpr ll kDefaultSieveLimit = 10000004;

// mu[i]: Moebius function of i for 1 <= i <= sieve limit (values -1, 0, 1).
std::vector<signed char> mu;
// sum[i] = sum_{k=1..i} mu(k) * k^2 (mod `mod`), kept in [0, mod).
std::vector<int> sum;

ll n, m;

// Linear (Euler) sieve filling mu[] and sum[] for indices 0..limit.
// The tables are sized from the requested limit instead of a fixed 10^7:
// less memory/time for small inputs, and no out-of-bounds reads for large ones.
void primes(ll limit = kDefaultSieveLimit) {
  if (limit < 1) limit = 1;
  const auto size = static_cast<std::size_t>(limit) + 1;
  mu.assign(size, 0);
  sum.assign(size, 0);
  std::vector<char> composite(size, 0);
  std::vector<ll> prime;  // primes found so far, in increasing order

  mu[1] = 1;
  for (ll i = 2; i <= limit; ++i) {
    if (!composite[i]) {
      prime.push_back(i);
      mu[i] = -1;
    }
    for (const ll p : prime) {
      if (p > limit / i) break;  // i * p > limit, tested without overflow
      composite[i * p] = 1;
      if (i % p == 0) {
        // p is the least prime factor of i, so p^2 divides i * p.
        mu[i * p] = 0;
        break;  // stop so each composite is marked once, via its least prime
      }
      mu[i * p] = static_cast<signed char>(-mu[i]);
    }
  }

  for (ll i = 1; i <= limit; ++i) {
    const ll r = i % mod;
    const ll term = mu[i] * (r * r % mod);  // lies in (-mod, mod)
    sum[i] = static_cast<int>(((sum[i - 1] + term) % mod + mod) % mod);
  }
}

// Returns S(a) * S(b) mod `mod`, where S(x) = 1 + 2 + ... + x.
ll Sum(ll a, ll b) {
  const ll sa = a * (a + 1) / 2 % mod;
  const ll sb = b * (b + 1) / 2 % mod;
  return sa * sb % mod;
}

// F(a, b) = sum_{k=1..min(a,b)} mu(k) * k^2 * S(a/k) * S(b/k)  (mod `mod`).
// Requires primes() to have run with a limit >= min(a, b).
ll F(ll a, ll b) {
  if (a > b) std::swap(a, b);
  ll res = 0;
  for (ll l = 1, r = 0; l <= a; l = r + 1) {
    r = std::min(a / (a / l), b / (b / l));
    const ll mu_k2 = (sum[r] - sum[l - 1] + mod) % mod;  // block of mu(k)k^2
    res = (res + mu_k2 * Sum(a / l, b / l)) % mod;
  }
  return res;
}

int main() {
  if (std::scanf("%lld%lld", &n, &m) != 2) return 1;  // malformed input
  if (n > m) std::swap(n, m);
  // Every F call below has its smaller argument <= n, so sieving to n suffices.
  primes(n);
  ll ans = 0;
  for (ll l = 1, r = 0; l <= n; l = r + 1) {
    r = std::min(n / (n / l), m / (m / l));
    const ll d_sum = (l + r) * (r - l + 1) / 2 % mod;  // l + (l+1) + ... + r
    ans = (ans + d_sum * F(n / l, m / l)) % mod;
  }
  std::cout << ans << '\n';
}