问题描述
热情好客的小猴子请森林中的朋友们吃饭,他的朋友被编号为 (1sim N) 每个到来的朋友都会带给他一些礼物:大香蕉。其中,第一个朋友会带给他 1 个大香蕉,之后,每一个朋友到来以后,都会带给他之前所有人带来的礼物个数再加他的编号的 K 次方那么多个。所以,假设 K=2,前几位朋友带来的礼物个数分别是:
(1,5,15,37,83,ldots)
假设 K=3,前几位朋友带来的礼物个数分别是:
(1,9,37,111,ldots)
现在,小猴子好奇自己到底能收到第 N 个朋友多少礼物,因此拜托于你了。
已知 N,K,请输出第 N 个朋友送的礼物个数对 (10^9+7) 取模的结果。
输入格式
第一行,两个整数 N,K。
输出格式
一个整数,表示第 N 个朋友送的礼物个数对 (10^9+7) 取模的结果。
样例输入
1234567890000 3
样例输出
891659731
数据范围
100% 的数据:(N le 10^{18}),(K le 10)。
解析
设 (f_i)表示第 (i) 个朋友送的礼物数量,(sum_i) 表示前 (i) 个朋友送的礼物个数之和。不难得到 (f_i=sum_{i-1}+i^k)。由此,我们可以得到递推式:
[sum_i=2sum_{i-1}+i^k
]
由于 (nle 10^{18}),我们显然要用矩阵快速幂解决这个问题。但 (i^k) 这一项与当前位置有关,不能直接加入转移矩阵中。利用二项式定理稍作转化,我们有 ((i+1)^k=sum_{j=0}^k C_{k}^{j}i^j) 。因此,我们可以得到如下转移矩阵:
[left[
egin{matrix}
2 & C_k^0 & C_k^1 & C_k^2 & ... &C_k^k\
0 & C_k^0 & C_k^1 & C_k^2 & ... &C_k^k \
0 & 0 & C_{k-1}^0 & c_{k-1}^1 & ... &C_{k-1}^{k-1} \
.&.&.&.&.&.\
0&0&0&0&...&C_0^0
end{matrix}
ight] imes
left[
egin{matrix}
s_i\i^k\i^{k-1}\...\i^0
end{matrix}
ight]=
left[
egin{matrix}
s_{i+1}\{(i+1)}^k\{(i+1)}^{k-1}\...\{(i+1)}^0
end{matrix}
ight]
]
矩阵快速幂即可,利用前缀和的性质得到答案。注意n=1或2时需特判。
代码
#include <iostream>
#include <cstdio>
#include <cstring>
#define int long long
#define N 20
using namespace std;
const int mod=1000000007;
struct Matrix{
int n,m,a[N][N];
}ans,s;
Matrix operator * (Matrix a,Matrix b)
{
Matrix c;
c.n=a.n;c.m=b.m;
memset(c.a,0,sizeof(c.a));
for(int i=1;i<=c.n;i++){
for(int j=1;j<=c.m;j++){
for(int k=1;k<=a.m;k++) c.a[i][j]=(c.a[i][j]+a.a[i][k]*b.a[k][j]%mod)%mod;
}
}
return c;
}
int n,k,i,j,c[N][N],sum1,sum2;
Matrix poww(Matrix a,int b)
{
Matrix ans=a,base=a;
b--;
while(b){
if(b&1) ans=ans*base;
base=base*base;
b>>=1;
}
return ans;
}
signed main()
{
scanf("%lld%lld",&n,&k);
if(n==1){
puts("1");
return 0;
}
if(n==2){
int ans=1;
for(i=1;i<=k;i++) ans=ans*2%mod;
printf("%lld
",ans+1);
return 0;
}
c[0][0]=1;
for(i=1;i<=k;i++){
c[i][0]=1;
for(j=1;j<=i;j++) c[i][j]=(c[i-1][j-1]+c[i-1][j])%mod;
}
s.a[1][1]=2;s.n=s.m=k+2;
for(i=2;i<=k+2;i++) s.a[1][i]=c[k][i-2];
for(i=2;i<=k+2;i++){
for(j=i;j<=k+2;j++) s.a[i][j]=c[k-i+2][j-i];
}
ans.n=k+2;ans.m=1;
for(i=1;i<=k+2;i++) ans.a[i][1]=1;
ans=poww(s,n-2)*ans;
sum1=ans.a[1][1];
ans=s*ans;
sum2=ans.a[1][1];
printf("%lld
",(sum2-sum1+mod)%mod);
return 0;
}