从S出发跑dij,从T出发跑dij,顺便最短路计数。
令$F(x)$为$S$到$T$最短路经过$x$的方案数,显然这个是可以用$S$到$x$的方案数乘$T$到$x$的方案数来得到。
然后第一个条件就变成了满足$F(A)+F(B)=F(T)$,这个只要用map存一下点的状态,每次查$F(T)-F(A)$就可以得到$B$的状态了。
第二个条件实际上就是$A$无法到达$B$,怎么判断这个呢。
按最短路正反拓扑排序两次,分别按两种拓扑序做$O(n*m/32)$的传递闭包,然后一个点两种(按拓扑序得到的能到达的点的状态的补集)的交集就是不能到达的点了。
统计答案的时候找map里$F(T)-F(A)$的状态 & $A$两种(按拓扑序得到的能到达的点的状态的补集)的交集,用bitset::count求出有几个1就好了,记得判一下算重复的情况。
第一次学会在map里开bitset...
还有$S$不能到达$T$要输出$n*(n-1)/2$...= =

#include<iostream> #include<cstring> #include<cstdlib> #include<cstdio> #include<bitset> #include<queue> #include<map> #define ll long long using namespace std; const int maxn=50010; const ll inf=1e15; struct tjm{int too, dis, pre;}e[maxn<<1]; struct poi{int x; ll dis;}; priority_queue<poi>q; bool operator<(poi a, poi b){return a.dis>b.dis;} map<ll,bitset<maxn> >mp; int n, m, s, t, x, y, z, tot, cnt, top; int p[maxn], last[maxn], ru[maxn], pos[maxn], st[maxn]; ll ans, dist[2][maxn], f[2][maxn]; bitset<maxn>g[2][maxn]; bool v[maxn]; inline void read(int &k) { int f=1; k=0; char c=getchar(); while(c<'0' || c>'9') c=='-' && (f=-1), c=getchar(); while(c<='9' && c>='0') k=k*10+c-'0', c=getchar(); k*=f; } inline void add(int x, int y, int z){e[++tot]=(tjm){y, z, last[x]}; last[x]=tot;} inline void dijkstra(int x, int ty) { for(int i=1;i<=n;i++) dist[ty][i]=inf; dist[ty][x]=0; f[ty][x]=1; q.push((poi){x, 0}); while(!q.empty()) { poi now=q.top(); q.pop(); if(now.dis!=dist[ty][now.x]) continue; for(int i=last[now.x], too;i;i=e[i].pre) if(dist[ty][too=e[i].too]>dist[ty][now.x]+e[i].dis) { f[ty][too]=f[ty][x]; dist[ty][too]=min(inf, dist[ty][now.x]+e[i].dis); q.push((poi){too, dist[ty][too]}); } else if(dist[ty][too]==dist[ty][now.x]+e[i].dis) f[ty][too]+=f[ty][x]; } } inline bool check(int x, int y, int dis, int ty){return v[y] && dist[ty][x]+dis==dist[ty][y];} inline void topsort(int ty) { memset(ru, 0, sizeof(ru)); top=0; for(int i=1;i<=cnt;i++) for(int j=last[p[i]], too;j;j=e[j].pre) if(check(p[i], too=e[j].too, e[j].dis, ty)) ru[too]++; for(int i=1;i<=cnt;i++) if(!ru[p[i]]) st[++top]=p[i], pos[p[i]]=top; for(int i=1;i<=top;i++) for(int j=last[st[i]], too;j;j=e[j].pre) if(check(st[i], too=e[j].too, e[j].dis, ty)) { ru[too]--; if(!ru[too]) st[++top]=too, pos[too]=top; } for(int i=1;i<=cnt;i++) g[ty][p[i]][p[i]-1]=1; for(int i=top;i;i--) for(int j=last[st[i]], too;j;j=e[j].pre) if(check(st[i], too=e[j].too, e[j].dis, ty) && pos[st[i]]<pos[too]) g[ty][st[i]]|=g[ty][too]; } int main() { read(n); read(m); read(s); read(t); for(int i=1;i<=m;i++) read(x), read(y), read(z), add(x, y, z), add(y, x, z); dijkstra(s, 0); if(dist[0][t]==inf) return printf("%lld ", 1ll*n*(n-1)>>1), 0; dijkstra(t, 1); for(int i=1;i<=n;i++) if(dist[0][i]+dist[1][i]==dist[0][t]) p[++cnt]=i, v[i]=1; for(int i=1;i<=cnt;i++) mp[f[0][p[i]]*f[1][p[i]]]|=1<<(p[i]-1); topsort(0); topsort(1); for(int i=1;i<=cnt;i++) ans+=(((mp[f[0][t]-f[0][p[i]]*f[1][p[i]]])>>(i-1))&(~g[0][p[i]]>>(i-1))&(~g[1][p[i]]>>(i-1))).count(); ll tmp=0; for(int i=1;i<=cnt;i++) if(f[0][p[i]]*f[1][p[i]]==f[0][t]) tmp++; ans+=tmp*(n-cnt); printf("%lld ", ans); }