题意:
q个询问,每一个询问给出2个数sum,n
1 <= q <= 10^5, 1 <= n <= sum <= 10^5
对于每一个询问,求满足下列条件的数组的方案数
1.数组有n个元素,ai >= 1
2.sigma(ai) = sum
3.gcd(ai) = 1
solution:
这道题的做法类似bzoj2005能量采集
f(d) 表示gcd(ai) = d 的方案数
h(d) 表示d|gcd(ai)的方案数
令ai = bi * d
则有sigma(bi) = sum / n
d | gcd(ai)
还要满足bi >= 1
则显然有h(d) = C(sum / d - 1,n - 1)
h(d) = f(d) + f(2d) + ... + f(d_max)
这里的d满足:
1.d是sum 的约数
2.sum / d >= n
则f(d) = h(d) - sigma(f(j)) ,2d <=j<=sum/n
倒序遍历d
ans = f(1)
由于询问的次数太多,每次询问后,可以把(sum,n)放入map中,记录下来
//File Name: cf439E.cpp //Author: long //Mail: 736726758@qq.com //Created Time: 2016年02月17日 星期三 14时58分16秒 #include <cstdio> #include <cstring> #include <iostream> #include <algorithm> #include <map> #include <cmath> #include <cstdlib> #include <vector> #define LL long long #define pb push_back using namespace std; const int MAXN = 1e5+5; const int MOD = 1e9+7; LL f[MAXN]; LL jie[MAXN]; bool is[MAXN]; vector<int> dive; map< pair<int,int>,int > rem; void init() { jie[0] = 1; for(int i=1;i<MAXN;i++){ jie[i] = jie[i-1] * i % MOD; } rem.clear(); } void get_dive(int sum,int n) { int e = (int)sqrt(sum + 0.0); dive.clear(); int j; for(int i=1;i<=e;i++){ if(sum % i == 0){ if(sum / i >= n) dive.pb(i); j = sum / i; if(j != i && sum / j >= n) dive.pb(j); } } sort(dive.begin(),dive.end()); for(int i=0;i<dive.size();i++){ is[dive[i]] = true; } } LL qp(LL x,LL y) { LL res = 1LL; while(y){ if(y & 1) res = res * x % MOD; x = x * x % MOD; y >>= 1; } return res; } LL comb(int x ,int y) { if(y < 0 || y > x) return 0; if(y == 0 || y == x) return 1; return jie[x] * qp(jie[y] * jie[x-y] % MOD,MOD - 2) % MOD; } void solve(int sum,int n) { map< pair<int,int>,int >::iterator it; it = rem.find(make_pair(sum,n)); if(it != rem.end()){ printf("%d ",(int)(it->second)); return ; } memset(f,0,sizeof f); memset(is,false,sizeof is); get_dive(sum,n); int ma = dive.size(); for(int i=ma-1;i>=0;i--){ int d = dive[i]; f[d] = comb(sum / d - 1,n - 1); for(int j=2*d;j<=dive[ma-1];j+=d){ if(is[j]){ f[d] = ((f[d] - f[j] + MOD) % MOD + MOD) % MOD; } } } printf("%d ",(int)f[1]); rem[make_pair(sum,n)] = f[1]; return ; } int main() { init(); int test; scanf("%d",&test); while(test--){ int sum,n; scanf("%d %d",&sum,&n); solve(sum,n); } return 0; }