题意
有一个长度为(n)的(01)串,你可以每次将相邻的(k)个字符合并,得到一个新的字符并获得一定分数。得到的新字符和分数由这(k)个字符确定。你需要求出你能获得的最大分数。
(n leq 300,k leq 8,1 leq w_i leq 10^9)
分析
参照xyz32768的题解。
区间合并让人想到区间 dp ,而 (k≤8) 又让人想到状压 dp 。
我们考虑合二为一。
(f[l][r][S]) 表示将区间 ([l,r]) 内的字符不断合并,最后变成串 (S) 的最大收益。
( (S) 是一个长度为 ((r−l)mod(k−1)+1) 的 (01) 串)
(由于每次合并会减少 (k−1) 个字符,故 (S) 的长度固定)
考虑 (S) 的每个字符,它们都是由原串的一个区间逐渐压缩成的。
故 (S) 的每个字符互相独立,互不影响。
我们就枚举一个 (mid∈[l,r)) ,表示 (S) 的最后一个字符是由原串的区间 ((mid,r]) 压缩成的。
这时候就有一个非常传统的区间 dp 转移了!
以下把 (mg(S,x)) 定义为 ((S<<1)|x) ,即在 (S) 的后面插入 (x) 。 (x∈{0,1}) 。
[f[l][r][mg(S,x)]=max(f[l][r][mg(S,x)],f[l][mid][S]+f[mid+1][r][x])
]
其中 (x∈{0,1}) 。
注意上面针对的是 (|S|=(r−l)mod(k−1)+1<k−1) 的情况。
如果 (|S|=k−1) ,那么 ([l,mid]) 会和 ((mid,r]) 组成一个长度为 (k) 的串,还可以再次合并。
故当 (|S|=k−1) 时:
[f[l][r][c[mg(S,x)]]=max(f[l][r][c[mg(S,x)]],f[l][mid][S]+f[mid+1][r][x]+w[mg(S,x)])
]
同样 (x∈{0,1})
理论复杂度(O(2^k cdot n^3)) ,但实际状态没有那么多。
代码
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cmath>
#include<set>
#include<map>
#include<queue>
#include<stack>
#include<algorithm>
#include<cstring>
#define rg register
#define il inline
#define co const
template<class T>T read()
{
T data=0;
int w=1;
char ch=getchar();
while(!isdigit(ch))
{
if(ch=='-')
w=-1;
ch=getchar();
}
while(isdigit(ch))
{
data=data*10+ch-'0';
ch=getchar();
}
return data*w;
}
template<class T>T read(T&x)
{
return x=read<T>();
}
using namespace std;
typedef long long ll;
co int MAXN=300,MAXK=8;
int n,k;
int a[MAXN];
int c[1<<MAXK],w[1<<MAXK];
ll f[MAXN][MAXN][1<<MAXK];
int main()
{
// freopen(".in","r",stdin);
// freopen(".out","w",stdout);
read(n);read(k);
for(int i=0;i<n;++i)
read(a[i]);
for(int i=0;i<(1<<k);++i)
{
read(c[i]);read(w[i]);
}
memset(f,-1,sizeof f);
for(int i=0;i<n;++i)
f[i][i][a[i]] = 0;
for(int i=n-1;i>=0;--i)
for(int j=i+1;j<n;++j)
{
int len = (j - i) % (k - 1) + 1;
if(len > 1)
{
for(int mid = i + len - 2;mid <= j - 1;mid += k - 1)
for(int s = 0;s < (1 << (len - 1));++s)
{
if(f[i][mid][s]==-1)
continue;
if(f[mid+1][j][0]!=-1)
f[i][j][s<<1] = max(f[i][j][s<<1],f[i][mid][s]+f[mid+1][j][0]);
if(f[mid+1][j][1]!=-1)
f[i][j][s<<1|1] = max(f[i][j][s<<1|1],f[i][mid][s]+f[mid+1][j][1]);
}
}
else
{
for(int s = 0;s < (1 << (k - 1));++s)
for(int mid = i + k - 2;mid <= j - 1;mid += k - 1)
{
if(f[i][mid][s]==-1)
continue;
if(f[mid+1][j][0]!=-1)
f[i][j][c[s<<1]]=max(f[i][j][c[s<<1]],f[i][mid][s]+f[mid+1][j][0]+w[s<<1]);
if(f[mid+1][j][1]!=-1)
f[i][j][c[s<<1|1]]=max(f[i][j][c[s<<1|1]],f[i][mid][s]+f[mid+1][j][1]+w[s<<1|1]);
}
}
}
ll ans=-1;
for(int i=0;i<(1<<k);++i)
{
// cerr<<i<<" f="<<f[0][n-1][i]<<endl;
ans=max(ans,f[0][n-1][i]);
}
printf("%lld
",ans);
return 0;
}