(Description)
有(n)个元素,对于每个元素(x_i)最多知道一个形如(x_j < x_i)或(x_j=x_i)的条件,问有多少合法的序列.合法的序列满足每个元素出现一次,任一相邻两元素之间有小于号或或等于号,并且所有条件全部满足,但是对于两个序列,如果只修改相等元素的位置能使得他们一样,则是他们为同一序列,答案队(10^9)取模
(n<=100)
(Solution)
认真读题后可以发现:
对于每个元素(x_i)最多知道一个形如(x_j<x_i)或(x_j=x_i)的条件
若(x_j=x_i)
将((i,j))合并,
若(x_j<x_i),
将(j)连向(i)
最后建立一个超级根节点,将这个点连向所有入读为(0)的点
这样就是一棵树了.
用(f[i][j])表示以(i)为根的子树中的元素能分(j)段,使(j)段每段相等的合法序列数
首先给出转移方程:
[f[x][i]=f'[x][j]*f[v][k]*C_{i-1}^{j-1}*C_{j-1}^{k-i+j}
]
(f'[x])表示在(v)之前的(f)值
v为x的子树
[C_{i-1}^{j-1}*C_{j-1}^{k-i+j}
]
上面式子表示将(j)段和(k)段合并为(i)段的方案数.
至于为什么?
就是(j-1)放在(x)的方案数(*v)剩下的段余B合并的方案数:
(Code)
#include<bits/stdc++.h>
#define int long long
using namespace std;
typedef long long ll;
const int mod=1e9+7;
int read(){
int x=0,f=1;char c=getchar();
while(c<'0'||c>'9') f=(c=='-')?-1:1,c=getchar();
while(c>='0'&&c<='9') x=x*10+c-'0',c=getchar();
return x*f;
}
int c[1011][1011],f[1001][1001],siz[10011],dp[10010],head[10010],cnt;
struct node {
int to,next;
}a[10011];
void add(int x,int y){
a[++cnt].to=y,a[cnt].next=head[x],head[x]=cnt;
a[++cnt].to=x,a[cnt].next=head[y],head[y]=cnt;
}
void init(){
for(int i=0;i<=200;i++)
c[i][i]=c[i][0]=1;
for(int i=2;i<=200;i++)
for(int j=1;j<i;j++)
c[i][j]=(c[i-1][j-1]+c[i-1][j])%mod;
}
void dfs(int x,int fa){
siz[x]=f[x][1]=1;
for(int e=head[x];e;e=a[e].next){
int v=a[e].to;
if(v==fa)
continue;
dfs(v,x);
memset(dp,0,sizeof(dp));
for(int i=1;i<=siz[x]+siz[v];i++)
for(int j=1;j<=siz[x];j++)
for(int k=1;k<=siz[v];k++){
int z=k-i+j;
if(z<0) continue;
(dp[i]+=f[x][j]*f[v][k]%mod*c[i-1][j-1]%mod*c[j-1][z]%mod)%=mod;
}
siz[x]+=siz[v];
for(int i=1;i<=siz[x];i++)
f[x][i]=dp[i];
}
}
int X[100001],Y[10001],pre[10001],flag[10010],bj[10010],vis[10010],fa[10010];
int find(int x){
return (x==pre[x])?x:pre[x]=find(pre[x]);
}
int find1(int x){
return (x==fa[x])?x:fa[x]=find1(fa[x]);
}
void join(int x,int y){
int fx=find(x),fy=find(y);
if(fx!=fy)
pre[fx]=fy,fa[fx]=fy;
}
main(){
init();
int n=read(),m=read(),x,y;
char s;
for(int i=1;i<=n;i++)
pre[i]=i,fa[i]=i;
for(int i=1;i<=m;i++){
X[i]=read(),scanf("%c",&s),Y[i]=read();
if(s=='=')
flag[i]=1,join(X[i],Y[i]);
}
for(int i=1;i<=m;i++)
if(!flag[i]){
int fx=find(X[i]),fy=find(Y[i]);
add(fx,fy),vis[fy]++;
if(find1(X[i])==find1(Y[i]))
puts("0"),exit(0);
fa[find1(X[i])]=find1(Y[i]);
}
for(int i=1;i<=n;i++)
bj[find(i)]=1;
for(int i=1;i<=n;i++)
if(bj[i]&&!vis[i])
add(n+1,i);
dfs(n+1,0);
int ans=0;
for(int i=1;i<=siz[n+1];i++)
ans+=f[n+1][i],ans%=mod;
printf("%lld ",ans);
}