一句话题意:
给你一个手环,让你往里面填黑色或白色珠子,保证两个黑色不能相邻,问你有多少种填法对 (1e9+7) 取模的结果
题解:
一开始我以为这是一条链,写了写结果发现样例都过不去。
在回来仔细看看题,这 tm 是个环。
看到 (n) 的范围比较大,不用想肯定是矩阵快速幂。
设 (f[i][0/1]) 表示 (i) 这个位置填 白色/黑色 的方案数。
那么又转移:
-
这一位填白色珠子的时候,上一位填黑色白色都行: (f[i][0] = f[i-1][0] + f[i-1][1])
-
这一位填黑色珠子的时候,上一位只能填白色: (f[i][1] = f[i-1][0])
由于它是一个环,所以我们还要枚举一下第一个珠子的颜色。
若第一个珠子颜色为白色,则初始化为 (f[1][0] = 1) ,最后一位可以填黑色或者白色,对答案的贡献就是 (f[n][0]+f[n][1])
若第一个珠子为黑色,则初始化 (f[1][1] = 1) ,最后一位只能填白色,对答案的贡献就是 (f[n][0])
这样时间复杂度是 (O(n)) 的。
考虑矩阵快速幂加速一下。
构建转移矩阵
[left[
egin{matrix}
f[i-1] [0] \
f[i-1] [1] \
end{matrix}
ight]
ag{2} imes
left[
egin{matrix}
1 &1 \
1 & 0 \
end{matrix}
ight] =
left[
egin{matrix}
f[i][0] \
f[i][1] \
end{matrix}
ight]
]
直接莽上矩阵快速幂就完了。
代码
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
#define LL long long
LL T,n,ans;
const int p = 1e9+7;
inline LL read()
{
LL s = 0,w = 1; char ch = getchar();
while(ch < '0' || ch > '9'){if(ch == '-') w = -1; ch = getchar();}
while(ch >= '0' && ch <= '9'){s = s * 10 + ch - '0'; ch = getchar();}
return s * w;
}
struct node
{
LL a[2][2];
}st,tmp;
node operator * (node x,node y)
{
node c;
for(int i = 0; i < 2; i++) for(int j = 0; j < 2; j++) c.a[i][j] = 0;
for(int i = 0; i < 2; i++)
{
for(int j = 0; j < 2; j++)
{
for(int k = 0; k < 2; k++)
{
c.a[i][j] = (c.a[i][j] + x.a[i][k] * y.a[k][j] % p) % p;
}
}
}
return c;
}
node operator + (node x,node y)
{
node c;
for(int i = 0; i < 2; i++)
{
for(int j = 0; j < 2; j++)
{
c.a[i][j] = (x.a[i][j] + y.a[i][j]) % p;
}
}
return c;
}
node ksm(node a,LL b)
{
node res;
for(int i = 0; i < 2; i++)
{
for(int j = 0; j < 2; j++)
{
if(i == j) res.a[i][j] = 1;
else res.a[i][j] = 0;
}
}
for(; b; b >>= 1)
{
if(b & 1) res = res * a;
a = a * a;
}
return res;
}
int main()
{
T = read();
while(T--)
{
n = read(); ans = 0;
st.a[0][0] = 1; st.a[1][0] = st.a[1][1] = st.a[0][1] = 0;
tmp.a[0][0] = tmp.a[0][1] = tmp.a[1][0] = 1; tmp.a[1][1] = 0;
tmp = ksm(tmp,n-1);
st = st * tmp;
ans = (st.a[0][0] + st.a[0][1]) % p;
st.a[0][1] = 1; st.a[0][0] = st.a[1][0] = st.a[1][1] = 0;
st = st * tmp;
ans = (ans + st.a[0][0]) % p;
printf("%lld
",ans);
}
return 0;
}