类型:树形 ( ext{DP})
这里主要补充一下 (O(n^3)) 的 ( ext{DP}) 优化的过程,基础转移方程推导可以参考其他巨佬的博客(题解)。
令 (f[x][p]) 表示在以 (x) 为根的子树中,(x) 在拓扑序排在第 (p) 个时的方案数。
转移中设 (x) 在已经合并的拓扑序中排名为 (p_1) ,将要合并的子树(以 (ver) 为根)中 (ver) 排名为 (p_2) ,合并后 (x) 排名为 (p_3) 。
列出转移方程:(为了表示方便略去了取模操作)
- 若 (x < ver) 即 (x) 应在 (ver) 之前,(p_1<p_2) 。
for(p1 = [1,siz[x]])
for(p2 = [1,siz[ver]])
for(p3 = [p1,p1+p2-1])
f[x][p3]+=f[x][p1]*f[ver][p2]*c[p3-1][p1-1]*c[siz[x]+siz[ver]-p3][siz[x]-p1];
- 若 (x > ver) 即 (x) 应在 (ver) 之后,(p_1>p_2) 。
for(p1 = [1,siz[x]])
for(p2 = [1,siz[ver]])
for(p3 = [p1+p2,siz[x]+siz[ver]])
f[x][p3]+=f[x][p1]*f[ver][p2]*c[p3-1][p1-1]*c[siz[x]+siz[ver]-p3][siz[x]-p1];
(更直观的转移方程:)
[f[x][p3]+=f[x][p1] imes f[ver][p2] imes C_{p3-1}^{p1-1} imes C_{siz[x]+siz[ver]-p3}^{siz[x]-p1}
]
到此为止,你可以得到 (40pts) 的好成绩。
交换内外循环
可以观察到 (p_2) 在转移方程中只出现了一次,因此我们将 (p_2) 和 (p_3) 的循环交换,列出转移方程:
- 若 (x < ver) 即 (x) 应在 (ver) 之前,(p_1<p_2) 。
for(p1 = [1,siz[x]])
for(p3 = [p1,p1+siz[ver]-1])
for(p2 = [p3-p1+1,siz[ver]])
f[x][p3]+=f[x][p1]*f[ver][p2]*c[p3-1][p1-1]*c[siz[x]+siz[ver]-p3][siz[x]-p1];
- 若 (x > ver) 即 (x) 应在 (ver) 之后,(p_1>p_2) 。
for(p1 = [1,siz[x]])
for(p3 = [p1+1,p1+siz[ver]])
for(p2 = [1,p3-p1])
f[x][p3]+=f[x][p1]*f[ver][p2]*c[p3-1][p1-1]*c[siz[x]+siz[ver]-p3][siz[x]-p1];
前缀和优化去掉一层循环
通过上一步的交换,发现 (f[x][p3]) 累加中的 (p2) 是一段连续的,并且满足乘法结合律,联想到前缀和优化。
令 (g[x][p]) 表示在以 (x) 为根的子树中,(x) 在拓扑序排在前 (p) 个时的方案数之和。
那么转移方程变为:
- 若 (x < ver) 即 (x) 应在 (ver) 之前,(p_1<p_2) 。
for(p1 = [1,siz[x]])
for(p3 = [p1,p1+siz[ver]-1])
g[x][p3]+=g[x][p1]*(g[ver][siz[ver]]-g[ver][p3-p1])*c[p3-1][p1-1]*c[siz[x]+siz[ver]-p3][siz[x]-p1];
- 若 (x > ver) 即 (x) 应在 (ver) 之后,(p_1>p_2) 。
for(p1 = [1,siz[x]])
for(p3 = [p1+1,p1+siz[ver]])
g[x][p3]+=g[x][p1]*g[ver][p3-p1]*c[p3-1][p1-1]*c[siz[x]+siz[ver]-p3][siz[x]-p1];
记得在最后加上:
for(int i=1;i<=siz[x];i++) g[x][i]+=g[x][i-1];
所以这道题 (AC) 了。
注:转移中为了维持方程的无后效性,应将 (f) 数组备份一遍。
(100pts) 完整代码:
#include<bits/stdc++.h>
using namespace std;
#define inf 0x3f3f3f3f
#define Maxn 1005
#define mod 1000000007
inline int rd()
{
int x=0;
char ch,t=0;
while(!isdigit(ch = getchar())) t|=ch=='-';
while(isdigit(ch)) x=x*10+(ch^48),ch=getchar();
return x=t?-x:x;
}
int t,n,tot;
int dp[Maxn][Maxn],tmp[Maxn],c[Maxn][Maxn],siz[Maxn];
int hea[Maxn],nex[Maxn<<1],ver[Maxn<<1],edg[Maxn<<1];
void add(int x,int y,int d)
{
ver[++tot]=y,nex[tot]=hea[x],hea[x]=tot,edg[tot]=d;
}
void dfs(int x,int fa)
{
siz[x]=1,dp[x][1]=1;
for(int i=hea[x];i;i=nex[i])
{
if(ver[i]==fa) continue;
dfs(ver[i],x);
memcpy(tmp,dp[x],sizeof(dp[x]));
memset(dp[x],0,sizeof(dp[x]));
if(edg[i])
for(int p1=1;p1<=siz[x];p1++)
for(int p3=p1;p3<=siz[ver[i]]+p1-1;p3++)
dp[x][p3]=(dp[x][p3]+1ll*tmp[p1]*(dp[ver[i]][siz[ver[i]]]-dp[ver[i]][p3-p1]+mod)%mod*c[p3-1][p1-1]%mod*c[siz[x]+siz[ver[i]]-p3][siz[x]-p1]%mod)%mod;
else
for(int p1=1;p1<=siz[x];p1++)
for(int p3=p1+1;p3<=siz[ver[i]]+p1;p3++)
dp[x][p3]=(dp[x][p3]+1ll*tmp[p1]*dp[ver[i]][p3-p1]%mod*c[p3-1][p1-1]%mod*c[siz[x]+siz[ver[i]]-p3][siz[x]-p1]%mod)%mod;
siz[x]+=siz[ver[i]];
}
for(int i=1;i<=siz[x];i++) dp[x][i]=(dp[x][i]+dp[x][i-1])%mod;
}
int main()
{
//freopen(".in","r",stdin);
//freopen(".out","w",stdout);
c[1][0]=c[0][0]=1;
for(int i=1;i<=1000;i++,c[i][0]=1)
for(int j=1;j<=i;j++)
c[i][j]=(c[i-1][j]+c[i-1][j-1])%mod;
t=rd();
char c;
while(t--)
{
n=rd();
memset(hea,0,sizeof(hea)),tot=0;
memset(siz,0,sizeof(siz));
for(int i=1,x,y;i<n;i++)
{
x=rd()+1,cin>>c,y=rd()+1;
add(x,y,c=='<'),add(y,x,c=='>');
}
dfs(1,0);
printf("%d
",dp[1][n]);
}
//fclose(stdin);
//fclose(stdout);
return 0;
}