试题来源
清华大学2011年百名信息学优秀高中学子夏令营
问题描述
有人打算送给你一条宝石项链,包含了N颗五颜六色(一共有M种颜色)的宝石。因为本问题中你只关心每个宝石的颜色,而且项链现在两头还没有接在一起,它可以被看成是一个数字串。
你希望在五颜六色的宝石中看到连续的一段同色宝石。因此,你定义一根宝石项链的幸运度是它最长的由同色宝石构成的连续子串的长度。 比如,项链112322211的幸运度是3,因为它包括了由同色宝石构成的子串222。而首尾的两个11并不构成连续1111,因为这个项链现在是串形的而不是环形的。
然而,你还没有见到这个项链。你只知道每个宝石是每种颜色的概率。并且,已知每个宝石的颜色分布是独立的。现在你希望在真的见到这条项链之前计算一下,这条项链的幸运度的期望是多少?
输入格式
输入的第一行有两个正整数N和M。
后面N行每行有M个非负实数。其中第i行第j列的数P_(i,j)含义是第i个宝石是颜色j的概率是P_(i,j)。每行的M个实数保证和为1。
输出格式
一个实数,即这条项链的幸运度的期望。四舍五入至小数点后6位。
样例输入
4 2
1.0 0.0
0.5 0.5
0.0 1.0
0.5 0.5
样例输出
2.250000
样例说明
我们用1和2来分别表示两种颜色的宝石,则这串项链有四种等概率的情形:1121,1122,1221和1222。它们的幸运度分别是2,2,2,3,因此期望的幸运度是2.25。
数据规模和约定
30%的数据满足N≤16,M≤3。
60%的数据满足N≤100。
100%的数据满足2≤N≤1000,2≤M≤10。
题解
首先,先想到60分的dp,设 (f[i][j][k][t]) 代表到第 (i) 位,最长连续段长为 (j) ,末尾连续段长为 (k) ,末尾一位颜色为 (t) 的概率,然后暴力转移
这是60分程序:
#include<bits/stdc++.h>
#define ui unsigned int
#define ll long long
#define db double
#define ld long double
#define ull unsigned long long
const int MAXN=100+10,MAXM=10+10;
int n,m;
db ans,f[MAXN][MAXN][MAXN][MAXM],P[MAXN][MAXM];
template<typename T> inline void read(T &x)
{
T data=0,w=1;
char ch=0;
while(ch!='-'&&(ch<'0'||ch>'9'))ch=getchar();
if(ch=='-')w=-1,ch=getchar();
while(ch>='0'&&ch<='9')data=((T)data<<3)+((T)data<<1)+(ch^'0'),ch=getchar();
x=data*w;
}
template<typename T> inline void write(T x,char ch=' ')
{
if(x<0)putchar('-'),x=-x;
if(x>9)write(x/10);
putchar(x%10+'0');
if(ch!=' ')putchar(ch);
}
template<typename T> inline void chkmin(T &x,T y){x=(y<x?y:x);}
template<typename T> inline void chkmax(T &x,T y){x=(y>x?y:x);}
template<typename T> inline T min(T x,T y){return x<y?x:y;}
template<typename T> inline T max(T x,T y){return x>y?x:y;}
int main()
{
read(n);read(m);
for(register int i=1;i<=n;++i)
for(register int j=1;j<=m;++j)scanf("%lf",&P[i][j]);
for(register int i=1;i<=m;++i)f[0][0][0][i]=1.0;
for(register int i=0;i<n;++i)
for(register int j=0;j<=i;++j)
for(register int k=0;k<=j;++k)
for(register int t=1;t<=m;++t)
{
if(k<j)f[i+1][j][k+1][t]+=f[i][j][k][t]*P[i+1][t];
else f[i+1][j+1][k+1][t]+=f[i][j][k][t]*P[i+1][t];
for(register int p=1;p<=m;++p)
if(t==p)continue;
else f[i+1][j][1][p]+=f[i][j][k][t]*P[i+1][p];
}
for(register int i=1;i<=n;++i)
for(register int j=1;j<=i;++j)
for(register int k=1;k<=m;++k)ans+=f[n][i][j][k]*i;
printf("%.6f
",ans);
return 0;
}
然后,压一维,不计某位连续段长,而在转移的时候把新的连续段长超过 (j) 的概率减去,保证已经考虑的位置中的连续段一定不大于 (j)
设 (f[i][j][k]) 表示到第 (i) 位,连续段长不超过 (j) ,末尾颜色为 (k) 的概率,(g[i][j][k]) 表示从第 (i) 位到第 (j) 位,期间颜色全部是 (k) 的概率,(s[i][j]) 为到第 (i) 位,连续段长不超过 (j)的概率
那么
(s[i][j]=sum_{k=1}^mf[i][j][k]) ,这个很好理解
(f[i][j][k]=s[i-1][j]*g[i][i][k]-(s[i-j-1][j]-f[i-j-1][j][k])*g[i-j][i][k])
这个东西有点杂
首先如果不去保证连续段长度一定不大于 (j) ,概率就是 (s[i-1][j]*g[i][i][j]) ,这个东西肯定是要减去不合法情况的,唯一的不合法情况就是加了这个新的数,使得末尾的连续段长度变成了 (j+1) ,也就是 (i-j) 到 (i) 期间全部都是 (k) 颜色。所以就是后面减去的东西。(s[i-j-1][j]) 减去 (f[i-j-1][j][k]) 是因为要保证 (i-j-1) 位上不能是 (k) 颜色,因为如果是 (k) 颜色,那么末尾连续段的长度就不止 (j+1) 了,这样的情况在之前转移的时候已经减过了
AC程序:
#include<bits/stdc++.h>
#define ui unsigned int
#define ll long long
#define db double
#define ld long double
#define ull unsigned long long
const int MAXN=1000+10,MAXM=10+10;
int n,m;
db ans,f[MAXN][MAXN][MAXM],P[MAXN][MAXM],g[MAXN][MAXN][MAXM],s[MAXN][MAXN];
template<typename T> inline void read(T &x)
{
T data=0,w=1;
char ch=0;
while(ch!='-'&&(ch<'0'||ch>'9'))ch=getchar();
if(ch=='-')w=-1,ch=getchar();
while(ch>='0'&&ch<='9')data=((T)data<<3)+((T)data<<1)+(ch^'0'),ch=getchar();
x=data*w;
}
template<typename T> inline void write(T x,char ch=' ')
{
if(x<0)putchar('-'),x=-x;
if(x>9)write(x/10);
putchar(x%10+'0');
if(ch!=' ')putchar(ch);
}
template<typename T> inline void chkmin(T &x,T y){x=(y<x?y:x);}
template<typename T> inline void chkmax(T &x,T y){x=(y>x?y:x);}
template<typename T> inline T min(T x,T y){return x<y?x:y;}
template<typename T> inline T max(T x,T y){return x>y?x:y;}
int main()
{
read(n);read(m);
for(register int i=1;i<=n;++i)
for(register int j=1;j<=m;++j)scanf("%lf",&P[i][j]);
for(register int k=1;k<=m;++k)
for(register int i=1;i<=n;++i)
{
g[i][i][k]=P[i][k];
for(register int j=i+1;j<=n;++j)g[i][j][k]=g[i][j-1][k]*P[j][k];
}
for(register int i=1;i<=n;++i)s[0][i]=1;
for(register int i=1;i<=n;++i)
for(register int j=1;j<=n;++j)
for(register int k=1;k<=m;++k)
{
f[i][j][k]=s[i-1][j]*g[i][i][k];
if(i-j-1>=0)f[i][j][k]-=(s[i-j-1][j]-f[i-j-1][j][k])*g[i-j][i][k];
s[i][j]+=f[i][j][k];
}
for(register int i=1;i<=n;++i)ans+=(s[n][i]-s[n][i-1])*i;
printf("%.6f
",ans);
return 0;
}