Description
平面上摆放着一个(n*m)的点阵(下图所示是一个3*4的点阵)。Curimit想知道有多少三点组(a,b,c)满足以a,b,c三点共线。这里a,b,c是不同的3个点,其顺序无关紧要。(即(a,b,c)和
(b,c,a)被认为是相同的)。由于答案很大,故你只需要输出答案对1,000,000,007的余数就可以了。
(1le n,mle50000).
Solution
首先是横着的和竖着的点组,总共有(m{n choose 3}+n{mchoose 3})种方案。
接下来是斜着的点组,怎么统计?
我想复杂了,想枚举A点和B点,然后作一条射线,统计射线上有多少个点......
更加简洁的做法是:枚举A点和C点,然后看两点的连线经过多少个点。现在只考虑A在C的左上角的情况,A在C右上角的情况同理,只需要对答案乘2。
我们枚举(C)相对(A)偏移值(x)和(y),若(A(i,j)),则(C(i+x,j+y))。这样的(AC)组合会出现((n-x)(m-y))次。我们发现,(AC)连线经过的点数恰好是(gcd(x,y)-1),因此左上往右下斜的答案就是:
[egin{aligned}
S&=sum_{x=1}^{n-1}sum_{y=1}^{m-1}(n-x)(m-y)(gcd(x,y)-1)\
S+sum_{x=1}^{n-1}sum_{y=1}^{m-1}(n-x)(m-y)&=sum_{x=1}^{n-1}sum_{y=1}^{m-1}(n-x)(m-y)gcd(x,y)\
end{aligned}
]
左边的和式可以直接算,我们关心右边:这里用到了欧拉函数的性质(sum_{d|n}varphi(d)=n):
[egin{aligned}
右边&=sum_{x=1}^{n-1}sum_{y=1}^{m-1}(n-x)(m-y)sum_{d|gcd(x,y)}varphi(d)\
&=sum_{d=1}^{min(n-1,m-1)}varphi(d)sum_{i=1}^{lfloorfrac{n-1}d
floor}(n-i*d)sum_{j=1}^{lfloorfrac{m-1}d
floor}(m-j*d)
end{aligned}
]
枚举(d),用等差数列求和(O(1))算出后面的部分,于是这一段部分也就计算完了!那么总答案就是:
[Ans=m{n choose 3}+n{mchoose 3}+S
]
Code
#include <cstdio>
using namespace std;
const int N=50005,MOD=1e9+7,INV2=500000004;
int n,m;
bool vis[N];
int p[N],pcnt,phi[N];
int fact[N],invfact[N];
inline int min(int x,int y){return x<y?x:y;}
inline int max(int x,int y){return x>y?x:y;}
inline int ksm(int x,int y){
int res=1;
for(;y;x=1LL*x*x%MOD,y>>=1)
if(y&1) res=1LL*res*x%MOD;
return res;
}
void sieve(){
phi[1]=1;
for(int i=2;i<N;i++){
if(!vis[i]){
p[++pcnt]=i;
phi[i]=i-1;
}
for(int j=1;j<=pcnt&&i*p[j]<N;j++){
int x=i*p[j];
vis[x]=true;
if(i%p[j]==0){
phi[x]=phi[i]*p[j];
break;
}
phi[x]=phi[i]*(p[j]-1);
}
}
}
inline int C(int n,int m){
if(m<0||m>n) return 0;
return 1LL*fact[n]*invfact[m]%MOD*invfact[n-m]%MOD;
}
inline int calc(int a,int d){
return (1LL*((a-1)/d)*a%MOD-1LL*INV2*d%MOD*((a-1)/d+1)%MOD*((a-1)/d)%MOD)%MOD;
}
int main(){
sieve();
scanf("%d%d",&n,&m);
int maxs=max(n,m),mins=min(n-1,m-1);
fact[0]=1;
for(int i=1;i<=maxs;i++) fact[i]=1LL*fact[i-1]*i%MOD;
invfact[maxs]=ksm(fact[maxs],MOD-2);
for(int i=maxs-1;i>=0;i--) invfact[i]=1LL*invfact[i+1]*(i+1)%MOD;
int ans=0;
for(int d=1;d<=mins;d++)
(ans+=1LL*phi[d]*calc(n,d)%MOD*calc(m,d)%MOD)%=MOD;
for(int i=1;i<n;i++)
(ans-=1LL*(n-i)*(1LL*m*(m-1)%MOD*INV2%MOD)%MOD)%=MOD;
ans=2LL*ans%MOD;
(ans+=(1LL*n*C(m,3)%MOD+1LL*m*C(n,3)%MOD)%MOD)%MOD;
printf("%d
",ans<0?ans+MOD:ans);
return 0;
}