Description
给定一个n个点m条边的有向图,有k个标记点,要求从规定的起点按任意顺序经过所有标记点到达规定的终点,问最短的距离是多少。
Input
第一行5个整数n、m、k、s、t,表示点个数、边条数、标记点个数、起点编号、终点编号。
接下来m行每行3个整数x、y、z,表示有一条从x到y的长为z的有向边。
接下来k行每行一个整数表示标记点编号。
Output
输出一个整数,表示最短距离,若没有方案可行输出-1。
Sample Input
3 3 2 1 1
1 2 1
2 3 1
3 1 1
2
3
Sample Output
3
【样例解释】
路径为1->2->3->1。
Data Constraint
20%的数据n<=10。
50%的数据n<=1000。
另有20%的数据k=0。
100%的数据n<=50000,m<=100000,0<=k<=10,1<=z<=5000。
从起点,必经点作最短路
之后用这些 dst 数组跑 dp
f[i][j] 表示经过必经点的集合为 i ,最后一个点为 j 时,最短的经过路程
更新一个状态之后 dfs 一路更新下去,直到不能更新为止,否则会漏状态(不 dfs 也能 A,不过经打印测试,会有状态没被更新)
代码:
#include<algorithm> #include<iostream> #include<cstring> #include<cstdlib> #include<cctype> #include<cstdio> #include<bitset> #include<queue> #define INF 2000000000000000ll using namespace std; typedef long long ll; const int MAXN = 50005, MAXM = 100005; struct EDGE{ int nxt, to; ll val; EDGE(int NXT = 0, int TO = 0, ll VAL = 0ll) {nxt = NXT; to = TO; val = VAL;} }edge[MAXM]; int n, m, k, s, t, totedge, maxs; int head[MAXN], dest[15]; ll dst[15][MAXN], f[1 << 10][15]; bitset<MAXN> vis; inline int rd() { register int x = 0; register char c = getchar(); while(!isdigit(c)) c = getchar(); while(isdigit(c)) { x = x * 10 + (c ^ 48); c = getchar(); } return x; } inline void add(int x, int y, ll v) { edge[++totedge] = EDGE(head[x], y, v); head[x] = totedge; return; } inline void dij(int bgn, ll d[MAXN]) { vis.reset(); for(int i = 1; i <= n; ++i) d[i] = INF; priority_queue<pair<ll, int> > q; d[bgn] = 0ll; q.push(make_pair(0ll, bgn)); while(!q.empty()) { int x = q.top().second; q.pop(); if(vis.test(x)) continue; vis.set(x); for(int i = head[x]; i; i = edge[i].nxt) { int y = edge[i].to; if(d[x] + edge[i].val < d[y]) { d[y] = d[x] + edge[i].val; q.push(make_pair(-d[y], y)); } } } return; } void dfs(int s, int cur) { if(s == maxs - 1) return; for(int i = 1; i <= k; ++i) if(i != cur && (s & (1 << (i - 1))) == 0) { if(f[s | (1 << (i - 1))][i] > f[s][cur] + dst[cur][dest[i]]) { f[s | (1 << (i - 1))][i] = f[s][cur] + dst[cur][dest[i]]; dfs((s | (1 << (i - 1))), i); } } return; } int main() { n = rd(); m = rd(); k = rd(); s = rd(); t = rd(); register int xx, yy, vv; for(int i = 1; i <= m; ++i) { xx = rd(); yy = rd(); vv = (ll)rd(); if(xx != yy) add(xx, yy, vv); } dij(s, dst[k + 1]); if(!k) { if(dst[k + 1][t] == INF) puts("-1"); else printf("%lld ", dst[k + 1][t]); return 0; } for(int i = 1; i <= k; ++i) { dest[i] = rd(); dij(dest[i], dst[i]); } maxs = (1 << k); for(int i = 1; i <= k; ++i) for(int j = 0; j < maxs; ++j) f[j][i] = INF; for(int i = 1; i <= k; ++i) f[1 << (i - 1)][i] = dst[k + 1][dest[i]]; for(int s = 0; s < maxs; ++s) { for(int i = 1; i <= k; ++i) if(f[s][i] != INF) dfs(s, i); } ll ans = INF; for(int i = 1; i <= k; ++i) { f[maxs - 1][i] += dst[i][t]; ans = min(ans, f[maxs - 1][i]); } if(ans == INF) puts("-1"); else printf("%lld ", ans); return 0; }