题目大意就是 给你一个DAG
然后添加一条边(x->y) ,询问以1为根的生成树的个数
QWQ
首先假设没有添加的边
答案就应该是
[ans=prod_{i=1}^{n} in[i]
]
QWQ就相当于每个点选择一个父亲。
那么加入一条边,我们会有一些不合法的情况,那就是包含一条(y->x)路径,剩下随便选的方案数。假设全集是(C),然后路径上的点的集合是(S),那我们实际上求的就是$$frac{F(C)}{F(S)}$$
其中(F(S))表示(S)集合中所有点的入度的乘积
然后对于这个东西,我们可以考虑拓扑图上dp的方式
来解决
//假设我们添加了一条x->y的边,要想不合法,就是求y->x的路径条数
//所以我们要将令起点,也就是y的初值f[y]=ans
void addedge(int x,int y)
{
nxt[++cnt]=point[x];
to[cnt]=y;
in[y]++;
point[x]=cnt;
}
int qsm(int i,int j)
{
int ans=1;
while (j)
{
if (j&1) ans=ans*i%mod;
i=i*i%mod;
j>>=1;
}
return ans;
}
void tpsort()
{
//cout<<ans<<endl;
for (int i=1;i<=n;i++)
{
if (!in[i]) q.push(i);
}
while (!q.empty())
{
int now = q.front();
q.pop();
//cout<<now<<endl;
//int ymh=0;
//if (now==y) ymh=1;
f[now]=f[now]*qsm(d[now],mod-2)%mod;
//cout<<now<<" "<<f[now]<<endl;
for (int i=point[now];i;i=nxt[i])
{
int p =to[i];
in[p]--;
f[p]=(f[p]+f[now])%mod;
if (!in[p]) q.push(p);
}
}
}
下面是整个的代码
// luogu-judger-enable-o2
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<queue>
#include<map>
#include<set>
#define mk makr_pair
#define ll long long
#define int long long
using namespace std;
inline int read()
{
int x=0,f=1;char ch=getchar();
while (!isdigit(ch)) {if (ch=='-') f=-1;ch=getchar();}
while (isdigit(ch)) {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
return x*f;
}
const int maxn = 2e5+1e2;
const int maxm = 2*maxn;
const int mod = 1e9+7;
int point[maxn],nxt[maxm],to[maxm];
int n,m;
int cnt,in[maxn];
queue<int> q;
int ans;
int f[maxn];
int x,y;
int d[maxn];
//假设我们添加了一条x->y的边,要想不合法,就是求y->x的路径条数
//所以我们要将令起点,也就是y的初值f[y]=ans
void addedge(int x,int y)
{
nxt[++cnt]=point[x];
to[cnt]=y;
in[y]++;
point[x]=cnt;
}
int qsm(int i,int j)
{
int ans=1;
while (j)
{
if (j&1) ans=ans*i%mod;
i=i*i%mod;
j>>=1;
}
return ans;
}
void tpsort()
{
//cout<<ans<<endl;
for (int i=1;i<=n;i++)
{
if (!in[i]) q.push(i);
}
while (!q.empty())
{
int now = q.front();
q.pop();
//cout<<now<<endl;
//int ymh=0;
//if (now==y) ymh=1;
f[now]=f[now]*qsm(d[now],mod-2)%mod;
//cout<<now<<" "<<f[now]<<endl;
for (int i=point[now];i;i=nxt[i])
{
int p =to[i];
in[p]--;
f[p]=(f[p]+f[now])%mod;
if (!in[p]) q.push(p);
}
}
}
signed main()
{
n=read(),m=read(),x=read(),y=read();
for (int i=1;i<=m;i++)
{
int u=read(),v=read();
addedge(u,v);
}
ans=1;
for (int i=2;i<=n;i++)
{
if (i==y) ans=ans*(in[i]+1)%mod,d[i]=in[i]+1;
else ans=ans*in[i]%mod,d[i]=in[i];
}
f[y]=ans;
if (x==1)
{
cout<<ans<<"
";
return 0;
}
tpsort();
cout<<(ans-f[x]+mod)%mod<<endl;
return 0;
}