题目大意
链接:CF712D
给定(a,b,k,t(1leq a,bleq 100,1leq kleq 1000,1leq tleq 100))。
取(2t)次数,每次取数的范围在([-k,k])之间,求满足最终取出的数之和严格大于(a-b)的方案数。
答案对(1e9+7)取模。
题目分析
首先,我们可以想到一个非常暴力的做法,直接DP,时间复杂度(O(kcdot t^2))。
嗯那我们怎么优化呢?
根本不用优化,这个时间复杂度非常优秀,可以AC。QWQ
代码实现:
#include<iostream>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<cstdio>
#include<iomanip>
#include<cstdlib>
#define MAXN 0x7fffffff
typedef long long LL;
const int N=1005,mod=1e9+7;
using namespace std;
inline int Getint(){register int x=0,f=1;register char ch=getchar();while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}return x*f;}
int f[205][N*405];
int main(){
int a=Getint(),b=Getint(),k=Getint(),t=Getint()<<1;
int lim=a-b;
f[0][k*t]=1;
for(int i=1;i<=t;i++){
int l=k*t-k*i,r=l,ret=0;
for(int j=l,lim=k*t+k*i;j<=lim;j++){
while(r<=j+k&&r<=lim)ret=(ret+f[i-1][r])%mod,r++;
while(l<j-k)ret=(ret-f[i-1][l]+mod)%mod,l++;
f[i][j]=ret;
}
}
int ans=0;
for(int i=k*t-lim+1;i<=k*t*2;i++)ans=(ans+f[t][i])%mod;
cout<<ans;
return 0;
}
现在,我想说说该怎么让这个时间复杂度变得更加优秀。
(以下(t)均等于输入的(t*2))
我们可以先进行偏移,向右偏移(kt),可以列出
[(1+x+x^2+cdots+x^{2k})^{2t}
]
而最终答案为指数(> kt-(a-b))的项的系数和。
我们只需化简该式即可。
[egin{split}
ans&=(1+x+x^2+cdots+x^{2k})^{t}\
&=(frac {1-x^{2k+1}}{1-x})^{t}\
&=(1+x^{2k+1})^tcdot(frac 1{1-x})^t
end{split}
]
由二项式定理得
[(1+x^{2k+1})^t=sum_{i=0}^{t}inom ti(-1)^ix^{(2k+1)(t-i)}
]
由广义二项式定理得
[(frac 1{1-x})^t=1+inom t{t-1}x+inom {t+1}{t-1}x^2+cdots
]
所以
[ans=(sum_{i=0}^tinom ti(-1)^ix^{(2k+1)(t-i)})cdot(1+inom t{t-1}x+inom{t+1}{t-1}x^2+cdots)
]
其中,只有系数(i)满足(kt-(a-b)<ileq 2kt)的项会对答案产生贡献。
可以预处理出右侧的前缀和(sum),枚举左边的(i),找到可以产生贡献的区间([l,r])。
[ans=sum_{i=0}^tinom ti(-1)^i(sum[r]-sum[l-1])
]
最终时间复杂度(O(kt))。
代码实现
#include<iostream>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<cstdio>
#include<iomanip>
#include<cstdlib>
#define MAXN 0x7fffffff
typedef long long LL;
const int N=800005,T=1005,mod=1e9+7;
using namespace std;
inline int Getint(){register int x=0,f=1;register char ch=getchar();while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}return x*f;}
int ksm(int x,int k){
int ret=1;
while(k){
if(k&1)ret=(LL)ret*x%mod;
x=(LL)x*x%mod,k>>=1;
}
return ret;
}
int fac[N],inv[N];
int C(int n,int m){if(n<m)return 0;return (LL)fac[n]*inv[m]%mod*inv[n-m]%mod;}
int sum[N];
int main(){
int a=Getint(),b=Getint(),k=Getint(),t=Getint()<<1;
int lim=k*t-a+b+1;
fac[0]=1;
for(int i=1;i<=410000;i++)fac[i]=(LL)fac[i-1]*i%mod;
inv[410000]=ksm(fac[410000],mod-2);
for(int i=410000-1;~i;i--)inv[i]=(LL)inv[i+1]*(i+1)%mod;
sum[0]=1;for(int i=1;i<=2*k*t;i++)sum[i]=(sum[i-1]+C(t+i-1,t-1))%mod;
int ans=0;
for(int i=0;i<=t;i++){
int nw=(2*k+1)*(t-i),l=max(lim-nw,0),r=2*k*t-nw;
if(l>r||r<0)continue;
if(l>r)swap(l,r);
ans=(ans+(LL)C(t,i)*((i&1)?-1:1)*((sum[r]-(l?sum[l-1]:0)+mod)%mod)%mod+mod)%mod;
}
cout<<ans;
return 0;
}