BZOJ 1444:[JSOI2009]有趣的游戏
首先我们建出Trie图,然后高斯消元。
我们设(f_i)表示经过第(i)个点的期望次数:
[f_x=sum icdot p_x(i)
]
(p_x(i))表示经过第(x)个点(i)次的概率。我们设表示一个单词的节点为关键节点,则所有关键节点只会经过一次,也就是说(f_{关键}=p_{关键}(1)),也就是我们要求的答案。
[displaystyle f_x=sum_{y与x相连}rate_{yRightarrow x}f_y
]
特别地(displaystyle f_1=sum_{y与1相连}rate_{yRightarrow 1}f_y+1),因为初始点在(1)。
(rate_{yRightarrow x})就是能从(y)走到(x)的字母的出现概率。
根据这些等式列方程,再高斯消元就行了。
代码:
#include<bits/stdc++.h>
#define ll long long
#define N 12
#define eps 1e-7
using namespace std;
inline int Get() {int x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9') {if(ch=='-') f=-1;ch=getchar();}while('0'<=ch&&ch<='9') {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}return x*f;}
int n,l,m;
double rate[26];
double w[N*N][N*N];
char str[N];
namespace AC_automation {
int cnt=1;
int id[N];
struct trie {
int ch[26];
int w,fail;
}tr[N*N];
void Insert(char *s,int No) {
int len=strlen(s+1),now=1;
for(int i=1;i<=len;i++) {
int j=s[i]-'A';
if(!tr[now].ch[j]) tr[now].ch[j]=++cnt;
now=tr[now].ch[j];
}
id[No]=now;
tr[now].w=1;
}
queue<int>q;
void build_fail() {
q.push(1);
while(!q.empty()) {
int v=q.front();
q.pop();
for(int i=0;i<26;i++) {
if(!tr[v].ch[i]) continue ;
int sn=tr[v].ch[i],f=tr[v].fail;
while(f&&!tr[f].ch[i]) f=tr[f].fail;
if(!f) tr[sn].fail=1;
else tr[sn].fail=tr[f].ch[i];
q.push(sn);
}
}
}
int find_sn(int now,int j) {
while(now&&!tr[now].ch[j]) now=tr[now].fail;
return now?tr[now].ch[j]:1;
}
void build_matrix() {
for(int i=1;i<=cnt;i++) {
w[i][i]=-1;
if(tr[i].w) continue ;
else {
for(int j=0;j<m;j++) {
int sn=find_sn(i,j);
w[sn][i]+=rate[j];
}
}
}
w[1][cnt+1]=-1;
}
}
int sum;
double ans[N*N];
void Gauss(int n) {
for(int i=1;i<=n;i++) {
for(int j=i+1;j<=n;j++) {
if(fabs(w[i][i])<fabs(w[j][i])) swap(w[i],w[j]);
if(fabs(w[i][i])<eps) continue ;
for(int j=i+1;j<=n;j++) {
double tem=w[j][i]/w[i][i];
for(int k=i;k<=n+1;k++) w[j][k]-=tem*w[i][k];
}
}
}
for(int i=n;i>=1;i--) {
if(fabs(w[i][i])<eps) {ans[i]=0;continue ;}
for(int j=i+1;j<=n;j++) w[i][n+1]-=w[i][j]*ans[j];
ans[i]=w[i][n+1]/w[i][i];
}
}
int main() {
n=Get(),l=Get(),m=Get();
double a,b;
for(int i=0;i<m;i++) {
a=Get(),b=Get();
rate[i]=a/b;
}
for(int i=1;i<=n;i++) {
scanf("%s",str+1);
AC_automation::Insert(str,i);
}
AC_automation::build_fail();
AC_automation::build_matrix();
sum=AC_automation::cnt;
Gauss(sum);
for(int i=1;i<=n;i++) {
double a=ans[AC_automation::id[i]];
if(fabs(a)>0.005) cout<<fixed<<setprecision(2)<<a<<"
";
else cout<<"0.00"<<"
";
}
return 0;
}