zoukankan      html  css  js  c++  java
  • 最小斯坦纳树

    $dp[i][state]$ 表示以$i$为根,指定集合中的点的连通状态为state的生成树的最小总权值

    有两种转移方向:

    1、先通过连通状态的子集进行转移。

    2、在当前枚举的连通状态下,对该连通状态进行松弛操作。

    P4294 [WC2008]游览计划

    注意景点的个数不超过10个。

    $dp[i][j][state]$ 表示在$[i, j]$这个点与state中对应点连通的最小代价。

    那么就可以用状压DP + spfa求解。

    由于要输出方案,可以记录每个状态的前一个状态,最后dfs跑一遍就行了。

    // #pragma GCC optimize(2)
    // #pragma GCC optimize(3)
    // #pragma GCC optimize(4)
    #include <algorithm>
    #include  <iterator>
    #include  <iostream>
    #include   <cstring>
    #include   <cstdlib>
    #include   <iomanip>
    #include    <bitset>
    #include    <cctype>
    #include    <cstdio>
    #include    <string>
    #include    <vector>
    #include     <stack>
    #include     <cmath>
    #include     <queue>
    #include      <list>
    #include       <map>
    #include       <set>
    #include   <cassert>
    #include <unordered_set>
    #include <unordered_map>
    // #include<bits/extc++.h>
    // using namespace __gnu_pbds;
    using namespace std;
    #define pb push_back
    #define fi first
    #define se second
    #define debug(x) cerr<<#x << " := " << x << endl;
    #define bug cerr<<"-----------------------"<<endl;
    #define FOR(a, b, c) for(int a = b; a <= c; ++ a)
    
    typedef long long ll;
    typedef long double ld;
    typedef pair<int, int> pii;
    typedef pair<ll, ll> pll;
    
    const int inf = 0x3f3f3f3f;
    const ll inff = 0x3f3f3f3f3f3f3f3f;
    const int mod = 1e9+7;
    
    // Fast integer reader on stdin: skips characters up to the next decimal digit,
    // remembering whether any '-' was skipped (then the value is negated), reads the
    // run of digits into x and also returns it.
    // If end of input is reached before any digit, x is set to 0 and 0 is returned
    // (previously the skip loop never terminated on EOF).
    template<typename T>
    inline T read(T&x){
        x=0;int f=0;int ch=getchar();   // int, not char: EOF must stay distinguishable
        while (ch!=EOF&&(ch<'0'||ch>'9')) f|=(ch=='-'),ch=getchar();
        while (ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar();
        return x=f?-x:x;
    }
    
    /**********showtime************/
                // Grid coordinates are 1-based, so n and m must stay below maxn.
                constexpr int maxn = 12;

                int n, m;                        // grid rows and columns
                int a[maxn][maxn];               // cost of each cell; 0 marks a scenic spot
                int vis[maxn][maxn];             // spfa in-queue flag, later reused by dfs as "on the tree"

                // Back-pointer for rebuilding the tree: the (x, y, state) that produced a dp entry.
                // Same cell => built by merging two submasks; different cell => built by an edge step.
                struct node {
                    int x, y, state;
                };

                // dp[i][j][s]: min total cell cost of a tree rooted at (i, j) connecting the spots in s.
                int dp[maxn][maxn][1055];
                node pre[maxn][maxn][1055];

                queue<pii> que;                  // spfa work queue of cells
                // Four-neighbourhood offsets; their order decides which of several
                // equal-cost predecessors gets recorded in pre.
                int nx[4][2] = {{0, 1}, {1, 0}, {-1, 0}, {0, -1}};
                void spfa(int now) {
                    while(!que.empty()) {
                        pii tmp = que.front(); que.pop();
                        for(int i=0; i<4; i++) {
                            int x = tmp.fi + nx[i][0];
                            int y = tmp.se + nx[i][1];
                            if(x < 1 || x > n || y < 1 || y > m) continue;
                            if(dp[x][y][now] > dp[tmp.fi][tmp.se][now] + a[x][y]) {
                                dp[x][y][now] = dp[tmp.fi][tmp.se][now] + a[x][y];
                                pre[x][y][now] = node{tmp.fi, tmp.se, now};
                                if(!vis[x][y]) {
                                    que.push(pii(x, y));
                                    vis[x][y] = 1;
                                }
                            }
                        }
                        vis[tmp.fi][tmp.se] = 0;
                    }
                }
                void dfs(int x, int y, int now) {
                    if(x == 0 || y == 0) return;
                    vis[x][y] = 1;
                    node tmp = pre[x][y][now];
                    dfs(tmp.x, tmp.y, tmp.state);
                    if(tmp.x == x && tmp.y == y)
                        dfs(tmp.x, tmp.y, now - tmp.state);
                }
    int main(){
                // Read the grid. Every zero-cost cell is a scenic spot and receives the next
                // bit of the terminal mask; its dp value for its own singleton mask is 0.
                scanf("%d%d", &n, &m);
                memset(dp, inf, sizeof(dp));    // every byte 0x3f, so every int becomes inf
                int num = 0;
                for(int i=1; i<=n; i++) {
                    for(int j=1; j<=m; j++) {
                        scanf("%d", &a[i][j]);
                        if(a[i][j] == 0) {
                            dp[i][j][(1<<num)] = 0;
                            num++;
                        }
                    }
                }
                int all = (1<<num) - 1;

                // Masks are handled in increasing order, so every proper submask of
                // `state` is already final when `state` is processed.
                for(int state = 0; state <= all; state ++) {
                    for(int i=1; i<=n; i++) {
                        for(int j=1; j<=m; j++) {
                            // Transition 1: join two trees rooted at (i, j) covering disjoint
                            // halves of `state`. Cell (i, j) is paid in both halves, so its cost
                            // is subtracted once. s0 walks every non-empty proper submask of state.
                            for(int s0 = (state-1) & state; s0; s0 = (s0-1) & state) {
                                if(dp[i][j][state] > dp[i][j][s0] + dp[i][j][state - s0] - a[i][j]) {
                                    dp[i][j][state] = dp[i][j][s0] + dp[i][j][state - s0] - a[i][j];
                                    pre[i][j][state] = node{i, j, s0};
                                }
                            }
                            // Every cell with a finite value seeds the relaxation below.
                            if(dp[i][j][state] < inf) que.push(pii(i, j)), vis[i][j] = 1;
                        }
                    }
                    // Transition 2: extend the trees of this mask across grid edges.
                    spfa(state);
                }

                // Pick the cheapest root for the full mask. With no scenic spots the empty tree
                // (cost 0) is optimal, and ax == ay == 0 turns the dfs below into a no-op.
                int ax = 0, ay = 0, mn = (num == 0) ? 0 : inf;
                for(int i=1; i<=n; i++) {
                    for(int j=1; j<=m; j++) {
                        if(dp[i][j][all] < mn) {
                            mn = dp[i][j][all];
                            ax = i;
                            ay = j;
                        }
                    }
                }
                printf("%d\n", mn);

                // Rebuild the chosen tree; dfs marks every cell on it in vis.
                memset(vis, 0, sizeof(vis));
                dfs(ax, ay, all);
                // 'x' = scenic spot, 'o' = other cell on the tree, '_' = unused cell.
                for(int i=1; i<=n; i++) {
                    for(int j=1; j<=m; j++) {
                        if(a[i][j] == 0) printf("x");
                        else if(vis[i][j]) printf("o");
                        else printf("_");
                    }
                    puts("");
                }
                return 0;
    }
    View Code

    HDU-4085 Peach Blossom Spring

     给定一个$n \le 50, m \le 1000$ 的无向图,让你用最小的修路总花费,使得1到k号点($k \le 5$)与最后k个点相连,也就是说1到k号点每个点都有一个匹配点,且匹配点两两不同。

    斯坦纳树,但是题目没有要求这$2 \times k$个点都连通,所以我们先用斯坦纳树dp求出各个子集的小斯坦纳树,再把它们组合起来,求出斯坦纳森林。

    // #pragma GCC optimize(2)
    // #pragma GCC optimize(3)
    // #pragma GCC optimize(4)
    #include <algorithm>
    #include  <iterator>
    #include  <iostream>
    #include   <cstring>
    #include   <cstdlib>
    #include   <iomanip>
    #include    <bitset>
    #include    <cctype>
    #include    <cstdio>
    #include    <string>
    #include    <vector>
    #include     <stack>
    #include     <cmath>
    #include     <queue>
    #include      <list>
    #include       <map>
    #include       <set>
    #include   <cassert>
    #include <unordered_set>
    #include <unordered_map>
    // #include<bits/extc++.h>
    // using namespace __gnu_pbds;
    using namespace std;
    #define pb push_back
    #define fi first
    #define se second
    #define debug(x) cerr<<#x << " := " << x << endl;
    #define bug cerr<<"-----------------------"<<endl;
    #define FOR(a, b, c) for(int a = b; a <= c; ++ a)
    
    typedef long long ll;
    typedef long double ld;
    typedef pair<int, int> pii;
    typedef pair<ll, ll> pll;
    
    const int inf = 0x3f3f3f3f;
    const ll inff = 0x3f3f3f3f3f3f3f3f;
    const int mod = 1e9+7;
    
    // Fast integer reader on stdin: skips characters up to the next decimal digit,
    // remembering whether any '-' was skipped (then the value is negated), reads the
    // run of digits into x and also returns it.
    // If end of input is reached before any digit, x is set to 0 and 0 is returned
    // (previously the skip loop never terminated on EOF).
    template<typename T>
    inline T read(T&x){
        x=0;int f=0;int ch=getchar();   // int, not char: EOF must stay distinguishable
        while (ch!=EOF&&(ch<'0'||ch>'9')) f|=(ch=='-'),ch=getchar();
        while (ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar();
        return x=f?-x:x;
    }
    
    /**********showtime************/
                constexpr int maxn = 55;
                // Adjacency list: mp[u] holds (neighbour, edge weight) pairs.
                vector<pii> mp[maxn];
                // spfa work queue and its in-queue flags.
                queue<int> que;
                int vis[maxn];
                // dp[v][s]: min cost of a tree containing v and every terminal in mask s.
                // g[s]: min cost of a forest covering mask s (filled in main).
                int dp[maxn][1055];
                int g[1055];
                void spfa(int now) {
                    while(!que.empty()) {
                        int u = que.front(); que.pop();
                        for(pii p : mp[u]) {
                            if(dp[p.fi][now] > dp[u][now] + p.se)
                            {
                                dp[p.fi][now] = dp[u][now] + p.se;
                                if(vis[p.fi] == 0) {
                                    vis[p.fi] = 1;
                                    que.push(p.fi);
                                }
                            }
                        }
                        vis[u] = 0;
                    }
                }
                int n,m,k;
                // check(state): true when bits 0..k-1 of state (terminals 1..k) and bits
                // k..2k-1 (the last k vertices) contain the same number of set bits, i.e.
                // the mask is balanced; bits above 2k-1 are ignored.
                bool check(int state ){
                    int low = 0, high = 0;
                    for (int i = 0; i < k; i++, state /= 2) low += (state % 2 == 1);
                    for (int i = 0; i < k; i++, state /= 2) high += (state % 2 == 1);
                    return low == high;
                }
    int main(){
                int T;  scanf("%d", &T);
                while(T--) {
                    scanf("%d%d%d", &n, &m, &k);
                    // Undirected weighted graph.
                    for(int i=1; i<=m; i++) {
                        int u,v,w;
                        scanf("%d%d%d", &u, &v, &w);
                        mp[u].pb(pii(v, w));
                        mp[v].pb(pii(u, w));
                    }
                    // Terminals: vertices 1..k own bits 0..k-1,
                    // vertices n-k+1..n own bits k..2k-1.
                    int num = 2 * k;
                    int all = (1 << num) - 1;
                    for(int i=1; i<=n; i++) for(int state = 0; state <= all; state ++ ) dp[i][state] = inf;
                    for(int i=1; i<=k; i++) dp[i][(1<<(i-1))] = 0;
                    for(int i=n, cur = 2*k; i>=n-k+1; i--, cur--) {
                        dp[i][1<<(cur-1)] = 0;
                    }
                    // Steiner-tree DP: merge submasks at each vertex, then relax along edges.
                    for(int state=0; state <= all; state ++) {
                        for(int i=1; i<=n; i++) {
                            for(int s0 = (state-1)&state; s0; s0 = (s0-1) & state) {
                                dp[i][state] = min(dp[i][state], dp[i][s0] + dp[i][state - s0]);
                            }
                            if(dp[i][state] < inf) que.push(i), vis[i] = 1;
                        }
                        spfa(state);
                    }

                    // The answer does not have to be a single Steiner tree but may be a
                    // Steiner forest, so first take the best single tree for every balanced
                    // mask (see check) ...
                    for(int state=0; state<=all; state++) {
                        g[state] = inf;
                        if(check(state)) {
                            for(int i=1; i<=n; i++)
                                g[state] = min(g[state], dp[i][state]);
                        }
                    }

                    // ... then combine disjoint balanced submasks into forests. If state and
                    // s0 are balanced, state - s0 is balanced too, so only s0 is checked.
                    for(int state = 0; state <= all; state++) {
                        if(check(state) == 0) continue;
                        for(int s0 = (state-1)&state; s0; s0 = (s0-1) & state) {
                            if(check(s0)) {
                                g[state] = min(g[state], g[s0] + g[state - s0]);
                            }
                        }
                    }

                    if(g[all] < inf) printf("%d\n", g[all]);
                    else printf("No solution\n");
                    for(int i=1; i<=n; i++) mp[i].clear();
                }
                return 0;
    }
    View Code

    ZOJ-3613 Wormhole Transport

    题意:

    有n个星球,其中最多四个星球是资源星球,最多四个星球有不同个数的工厂。

    一个资源星球只能供给一个工厂。

    有不同类型的路可以修,问在最多工厂被供给的前提下,最小的修路费用。

    思路:

    朴素的斯坦纳树转移。

    然后需要通过森林DP组合,转移条件是工厂个数 $\ge$ 资源个数

    // #pragma GCC optimize(2)
    // #pragma GCC optimize(3)
    // #pragma GCC optimize(4)
    #include <algorithm>
    #include  <iterator>
    #include  <iostream>
    #include   <cstring>
    #include   <cstdlib>
    #include   <iomanip>
    #include    <bitset>
    #include    <cctype>
    #include    <cstdio>
    #include    <string>
    #include    <vector>
    #include     <stack>
    #include     <cmath>
    #include     <queue>
    #include      <list>
    #include       <map>
    #include       <set>
    #include   <cassert>
    #include <unordered_set>
    #include <unordered_map>
    // #include<bits/extc++.h>
    // using namespace __gnu_pbds;
    using namespace std;
    #define pb push_back
    #define fi first
    #define se second
    #define debug(x) cerr<<#x << " := " << x << endl;
    #define bug cerr<<"-----------------------"<<endl;
    #define FOR(a, b, c) for(int a = b; a <= c; ++ a)
    
    typedef long long ll;
    typedef long double ld;
    typedef pair<int, int> pii;
    typedef pair<ll, ll> pll;
    
    const int inf = 0x3f3f3f3f;
    const ll inff = 0x3f3f3f3f3f3f3f3f;
    const int mod = 1e9+7;
    
    // Fast integer reader on stdin: skips characters up to the next decimal digit,
    // remembering whether any '-' was skipped (then the value is negated), reads the
    // run of digits into x and also returns it.
    // If end of input is reached before any digit, x is set to 0 and 0 is returned
    // (previously the skip loop never terminated on EOF).
    template<typename T>
    inline T read(T&x){
        x=0;int f=0;int ch=getchar();   // int, not char: EOF must stay distinguishable
        while (ch!=EOF&&(ch<'0'||ch>'9')) f|=(ch=='-'),ch=getchar();
        while (ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar();
        return x=f?-x:x;
    }
    
    /**********showtime************/
    
                constexpr int maxn = 205;
                // Per-terminal data indexed by terminal bit: flag = 1 for a resource planet,
                // 0 for a factory planet; pt = factory count read for that planet.
                int pt[10], flag[10];
                // Per-planet input pair as read in main (first -> pt, second -> resource).
                pii p[maxn];
                // Adjacency list: mp[u] holds (neighbour, road cost) pairs.
                vector<pii> mp[maxn];
                // spfa work queue and its in-queue flags.
                queue<int> que;
                int vis[maxn];
                // dp[v][s]: min cost of a tree containing v and every terminal in mask s.
                // g[s]: min cost of a forest covering mask s (filled in main).
                int dp[maxn][300];
                int g[300];
                void spfa(int now) {
                    while(!que.empty()) {
                        int u = que.front(); que.pop();
                        for(pii p : mp[u]) {
                            if(dp[p.fi][now] > dp[u][now] + p.se) {
                                dp[p.fi][now] = dp[u][now] + p.se;
                                if(vis[p.fi] == 0) {
                                    que.push(p.fi);
                                    vis[p.fi] = 1;
                                }
                            }
                        }
                        vis[u] = 0;
                    }
                }
                int cnt = 0;    // number of terminal bits in the current test case
                // check(now): true when, inside mask now, the number of resource terminals
                // does not exceed the total factory count of the factory terminals, so
                // each resource in the mask can be paired with a distinct factory.
                bool check(int now) {
                    int resources = 0, factories = 0;
                    for (int i = 0; i < cnt; i++, now /= 2) {
                        if (now % 2 != 1) continue;
                        if (flag[i]) ++resources;
                        else factories += pt[i];
                    }
                    return resources <= factories;
                }
                // cal(now): number of resource terminals (flag == 1) in mask now among
                // the first cnt bits.
                int cal(int now) {
                    int supplied = 0;
                    for (int i = 0; i < cnt; i++, now /= 2)
                        if (now % 2 == 1) supplied += flag[i];
                    return supplied;
                }
    int main(){
                int n;
                while(~scanf("%d", &n)) {
                    for(int i=1; i<=n; i++) {
                        for(int j=0; j<300; j++)
                            dp[i][j] = inf;
                    }
                    cnt = 0;
                    int res1 = 0;
                    for(int i=1; i<=n; i++) {
                        scanf("%d%d", &p[i].fi, &p[i].se);
                        if(p[i].fi && p[i].se)
                        {
                            res1++;
                            p[i].se = 0;
                            p[i].fi--;
                        }
                        if(p[i].se) {
                            flag[cnt] = 1; //
                            pt[cnt] = p[i].fi;
                            dp[i][1<<cnt] = 0;
                            cnt++;
                        }
                        else if(p[i].fi) {
                            pt[cnt] = p[i].fi;
                            flag[cnt] = 0;
                            dp[i][1<<cnt] = 0;
                            cnt++;
                        }
                    }
    
                    int m;  scanf("%d", &m);
                    for(int i=1; i<=m; i++) {
                        int u,v,w;
                        scanf("%d%d%d", &u, &v, &w);
                        mp[u].pb(pii(v, w));
                        mp[v].pb(pii(u, w));
                    }
    
                    int all = (1 << cnt) - 1;
                    for(int state = 0; state <= all; state++) {
                        for(int i=1; i<=n; i++) {
                            for(int s0 = (state-1) & state; s0; s0 = (s0-1) & state) {
                                dp[i][state] = min(dp[i][state], dp[i][s0] + dp[i][state - s0]);
                            }
                            if(dp[i][state] < inf) que.push(i);
                        }
                        spfa(state);
                    }
                    for(int state=0; state<=all; state++) {
                        g[state] = inf;
                        if(check(state) == 0) continue;
                        for(int i=1; i<=n; i++) {
                            g[state] = min(g[state], dp[i][state]);
                        }
                    }
    
                    int res2 = 0, ans = 0;
                    for(int state = 0; state <= all; state ++) {
                        if(check(state) == 0) continue;
                        for(int s0 = (state - 1) & state; s0; s0 = (s0-1) & state) {
                            if(check(s0) && check(state - s0)) {
                                g[state] = min(g[state], g[s0] + g[state - s0]);
                            }
                        }
    
                        if(cal(state) > res2) {
                            res2 = cal(state);
                            ans = g[state];
                        }
                        else if(cal(state) == res2) {
                            ans = min(ans, g[state]);
                        }
                    }
                    printf("%d %d
    ", res1 + res2, ans);
    
                    for(int i=1; i<=n; i++) mp[i].clear();
                }
                return 0;
    }
    View Code

    参考和学习:

    https://www.cnblogs.com/clno1/p/10990936.html

  • 相关阅读:
    WebSocket简单通信
    python必会内置函数
    python装饰器
    Python常用模块1
    python函数操作
    python字典操作
    python切片操作
    python列表操作
    python字符串格式化的几种方式
    Jmeter响应中中文乱码怎么解决?
  • 原文地址:https://www.cnblogs.com/ckxkexing/p/11505782.html
Copyright © 2011-2022 走看看