zoukankan      html  css  js  c++  java
  • 「ZJOI2019」Minimax 搜索(动态dp)

    Address

    loj3044

    Solution

    考虑对 \(k\in [l-1,r]\) 分别求出有多少个集合 \(S\) 满足 \(w(S)\le k\),记作 \(ans_k\)

    先求出 \(1\) 的初始权值 \(W\)

    \(val(x)\) 表示 \(x\) 的权值。枚举 \(k\),现在对于每个叶子 \(u\),如果 \(u\in S\),那么 \(val(u)\in [u-W,u+W]\),否则 \(val(u)=W\)

    我们发现,把叶子节点的权值改成 \(W\) 肯定是不优的。所以改动一些叶子后,如果 \(val(1)\) 还是 \(W\),那么肯定路径 \(1→W\) 上每个点的权值都是 \(W\),且其它的点的权值都不是 \(W\)

    因此,如果想要 \(val(1)\) 改变,那么路径 \(1→W\) 上肯定存在一个点 \(x\)\(val(x)\ne W\)。记 \(x\) 在路径 \(1→W\) 上的子节点为 \(y\)。如果 \(x\) 深度是奇数, 那么肯定存在一个 \(x\) 的子节点 \(z(z\ne y)\)\(val(z)>W\)\(x\) 深度是偶数时同理。

    我们把 \(1→W\) 上的边全部断掉,再求一遍每个点的权值。如果原路径 \(1→W\) 上存在某个深度为奇数的点的权值 \(>W\),或者某个深度为偶数的点的权值 \(<W\),那么 \(val(1)\) 肯定改变,否则肯定不变。

    \(f(u)\) 表示 \(u\) 子树中,使 \(val(u)>w\) 的合法叶子节点集合有几个。\(g(u)\) 表示 \(u\) 子树中,使 \(val(u)<w\) 的合法叶子节点集合有几个。

    如果 \(u\) 是叶子节点:\(f(u)=[u>W]+[u+k>W],g(u)=[u<W]+[u-k<W]\)。其中 \([u>W],[u<W]\) 表示 \(u\) 不在叶子节点集合内,\([u+k>W],[u-k<W]\) 表示在集合内。

    如果 \(u\) 是深度为奇数的非叶子节点,如果 \(val(u)>W\),那么 \(u\) 的子节点最大权值必须 \(>W\),也就是说不能全部 \(\le W\)。因此 \(f(u)=2^{cnt_u}\prod_{v\in son_u}(2^{cnt_v}-f(v))\)。其中 \(cnt_u\) 表示 \(u\) 的子树内有几个叶子节点。

    如果 \(u\) 是深度为偶数的非叶子节点,如果 \(val(u)>W\),那么 \(u\) 的子节点全部 \(<W\)。因此 \(f(u)=\prod_{v\in son_u}f(v)\)

    \(g\) 的转移和 \(f\) 类似。

    接下来求 \(ans_k\)。考虑补集转化,即用 \(2^{cnt_1}\) 减去不会让 \(val(1)\) 改变的集合数。不会让 \(val(1)\) 改变,就是要让原路径 \(1→W\) 上的每个点的权值都不变。那么把深度为奇数的 \(2^{cnt_x}-f_x\) 和深度为偶数的 \(2^{cnt_x}-g_x\) 全部相乘就是答案了。

    至此,我们得到了一个 \(O(n(R-L))\) 的做法。

    考虑优化,我们发现转移与 \(k\) 无关,只有叶子节点的 \(f,g\)\(k\) 有关。进一步地,我们发现随着 \(k\) 变大,每个叶子节点的 \(f,g\) 最多改变一次。因此可以看作是 \(O(n)\) 次修改的动态 \(dp\),时间复杂度 \(O(n\log^2 n)\)

    Code

    #include <bits/stdc++.h>
    
    using namespace std;
    
    #define ll long long
    #define p2 p << 1
    #define p3 p << 1 | 1
    
    template <class t>
    inline void read(t & res)
    {
        char ch;
        while (ch = getchar(), !isdigit(ch));
        res = ch ^ 48;
        while (ch = getchar(), isdigit(ch))
        res = res * 10 + (ch ^ 48);
    }
    
    template <class t>
    inline void print(t x)
    {
        if (x > 9) print(x / 10);
        putchar(x % 10 + 48);
    }
    
    const int e = 2e5 + 5, mod = 998244353;
    
    struct point
    {
        int x, y;
    }b[e], que[e];
    struct matrix
    {
        int a, b;
        
        matrix(){}
        matrix(int _a, int _b) :
            a(_a), b(_b) {}
    }tr[e << 2];
    vector<int>g[e], c[e], d[e];
    int f[e], dep[e], L, R, w, n, fa[e], a[e], m, nxt[e], go[e], adj[e], val[e], K, cnt[e], f2[e];
    int q[e], h[e], num, all, sum[e << 2], son[e], sze[e], dfnA[e], dfnB[e], timA, timB, idA[e], idB[e];
    int st[e], ed[e], bot[e], top[e], ans[e], rt[e], now_rt;
    bool is[e], op, bo[e];
    
    inline void add(int &x, int y)
    {
        (x += y) >= mod && (x -= mod);
    }
    
    inline void del(int &x, int y)
    {
        (x -= y) < 0 && (x += mod);
    }
    
    inline int plu(int x, int y)
    {
        add(x, y);
        return x;
    }
    
    inline int sub(int x, int y)
    {
        del(x, y);
        return x;
    }
    
    inline int mul(int x, int y)
    {
        return (ll)x * y % mod;
    }
    
    inline int ksm(int x, int y)
    {
        int res = 1;
        while (y)
        {
            if (y & 1) res = mul(res, x);
            y >>= 1;
            x = mul(x, x);
        }
        return res;
    }
    
    inline matrix operator + (matrix u, matrix v)
    {
        return matrix(mul(u.a, v.a), plu(mul(u.b, v.a), v.b));
    }
    
    inline void link1(int x, int y)
    {
        g[x].push_back(y);
        g[y].push_back(x);
    }
    
    inline void link2(int x, int y)
    {
        nxt[++num] = adj[x]; adj[x] = num; go[num] = y;
    }
    
    inline void dfs1(int u, int pa)
    {
        dep[u] = dep[pa] + 1;
        fa[u] = pa;
        if (dep[u] & 1) val[u] = 0;
        else val[u] = n + 1;
        int len = g[u].size(), i;
        bool pd = 0;
        for (i = 0; i < len; i++)
        {
            int v = g[u][i];
            if (v == pa) continue;
            pd = 1;
            dfs1(v, u);
            if (dep[u] & 1) val[u] = max(val[u], val[v]);
            else val[u] = min(val[u], val[v]);
        }
        if (!pd) val[u] = u, all++;   
    }
    
    inline void dfs2(int u)
    {
        if (val[u] == u)
        {
            if (op) 
            {
                f[u] = (u > w) + (u + K > w);
                if (L <= w + 1 - u && w + 1 - u <= R) c[w + 1 - u].push_back(u);
            }
            else 
            {
                f[u] = (u < w) + (u - K < w);
                if (L <= u + 1 - w && u + 1 - w <= R) d[u + 1 - w].push_back(u);
            }
            return;
        }
        f[u] = f2[u] = 1;
        bool fl = ((dep[u] & 1) && op) || ((~dep[u] & 1) && !op);
        bo[u] = fl;
        for (int i = adj[u]; i; i = nxt[i])
        {
            int v = go[i];
            dfs2(v);
            if (fl) f[u] = mul(f[u], sub(q[v], f[v]));
            else f[u] = mul(f[u], f[v]);
            if (v != son[u])
            {
                if (fl) f2[u] = mul(f2[u], sub(q[v], f[v]));
                else f2[u] = mul(f2[u], f[v]);
            }
        }
        if (fl) f[u] = sub(q[u], f[u]);
    }
    
    inline void dfs3(int u)
    {
        if (val[u] == u) cnt[u] = 1;
        sze[u] = 1;
        rt[u] = now_rt;
        for (int i = adj[u]; i; i = nxt[i])
        {
            int v = go[i];
            dfs3(v);
            cnt[u] += cnt[v];
            sze[u] += sze[v];
            if (sze[v] > sze[son[u]]) son[u] = v;
        }
    }
    
    inline void dfs4(int u, int fi)
    {
        top[u] = fi;
        dfnA[u] = ++timA;
        idA[timA] = u;
        if (son[u]) 
        {
            dfs4(son[u], fi);
            st[u] = timB + 1;
            for (int i = adj[u]; i; i = nxt[i])
            {
                int v = go[i];
                if (v == son[u]) continue;
                dfnB[v] = ++timB;
                idB[timB] = v;
            }
            ed[u] = timB;
        }
        for (int i = adj[u]; i; i = nxt[i])
        {
            int v = go[i];
            if (v == son[u]) continue;
            dfs4(v, v);
        }
        if (son[u]) bot[u] = bot[son[u]];
        else bot[u] = u;
    }
    
    inline void build(int l, int r, int p)
    {
        if (l == r)
        {
            int u = idA[l], v = idB[l];
            if (son[u])
            {
                if (bo[u])
                {
                    int v = son[u];
                    tr[p] = matrix(f2[u], sub(q[u], mul(f2[u], q[v])));
                }
                else tr[p] = matrix(f2[u], 0);
            }
            if (v)
            {
                int pa = fa[v];
                if (bo[pa]) sum[p] = sub(q[v], f[v]);
                else sum[p] = f[v];
            }
            return;
        }
        int mid = l + r >> 1;
        build(l, mid, p2);
        build(mid + 1, r, p3);
        tr[p] = tr[p3] + tr[p2];
        sum[p] = mul(sum[p2], sum[p3]);
    }
    
    inline void upt_tr(int l, int r, int s, matrix u, int p)
    {
        if (l == r)
        {
            tr[p] = u;
            return;
        }
        int mid = l + r >> 1;
        if (s <= mid) upt_tr(l, mid, s, u, p2);
        else upt_tr(mid + 1, r, s, u, p3);
        tr[p] = tr[p3] + tr[p2];
    }
    
    inline void upt_sum(int l, int r, int s, int v, int p)
    {
        if (l == r)
        {
            sum[p] = v;
            return;
        }
        int mid = l + r >> 1;
        if (s <= mid) upt_sum(l, mid, s, v, p2);
        else upt_sum(mid + 1, r, s, v, p3);
        sum[p] = mul(sum[p2], sum[p3]);
    }
    
    inline matrix ask_tr(int l, int r, int s, int t, int p)
    {
        if (l == s && r == t) return tr[p];
        int mid = l + r >> 1;
        if (t <= mid) return ask_tr(l, mid, s, t, p2);
        else if (s > mid) return ask_tr(mid + 1, r, s, t, p3);
        else return ask_tr(mid + 1, r, mid + 1, t, p3) + ask_tr(l, mid, s, mid, p2);
    }
    
    inline int ask_sum(int l, int r, int s, int t, int p)
    {
        if (l == s && r == t) return sum[p];
        int mid = l + r >> 1;
        if (t <= mid) return ask_sum(l, mid, s, t, p2);
        else if (s > mid) return ask_sum(mid + 1, r, s, t, p3);
        else return mul(ask_sum(l, mid, s, mid, p2), ask_sum(mid + 1, r, mid + 1, t, p3));
    }
    
    inline void pair_mul(point &u, int x)
    {
        if (!x) u.y++;
        else u.x = mul(u.x, x);
    }
    
    inline void pair_div(point &u, int x)
    {
        if (!x) u.y--;
        else u.x = mul(u.x, ksm(x, mod - 2));
    }
    
    inline void cover(int &x, point u)
    {
        int res = u.x;
        if (u.y) res = 0;
        x = sub(all, res);
    }
    
    inline int calc(int x, matrix c)
    {
        return plu(mul(x, c.a), c.b);
    }
    
    inline int ask(int x)
    {
        if (x == bot[x]) return f[x];
        int l = dfnA[x], r = dfnA[bot[x]] - 1;
        return calc(f[bot[x]], ask_tr(1, n, l, r, 1));
    }
    
    inline void change(int x)
    {
        pair_div(que[K], sub(q[rt[x]], f[rt[x]]));
        x = top[x];
        while (x)
        {
            f[x] = ask(x);
            if (!fa[x]) break;
            int y = fa[x];
            if (bo[y]) upt_sum(1, n, dfnB[x], sub(q[x], f[x]), 1);
            else upt_sum(1, n, dfnB[x], f[x], 1);
            f2[y] = ask_sum(1, n, st[y], ed[y], 1);
            
            matrix tmp;
            if (bo[y])
            {
                int v = son[y];
                tmp = matrix(f2[y], sub(q[y], mul(f2[y], q[v])));
            }
            else tmp = matrix(f2[y], 0);
            upt_tr(1, n, dfnA[y], tmp, 1);
            
            x = top[y];
        }
        pair_mul(que[K], sub(q[x], f[x]));
    }
    
    int main()
    {
        freopen("minimax.in", "r", stdin);
        freopen("minimax.out", "w", stdout);
        read(n); read(L); read(R);
        int i, x, y, j;
        for (i = 1; i < n; i++) read(x), read(y), link1(x, y), b[i].x = x, b[i].y = y;
        dfs1(1, 0);
        x = w = val[1];
        h[0] = 1;
        for (i = 1; i <= n; i++) h[i] = plu(h[i - 1], h[i - 1]);
        while (x != 1)
        {
            a[++m] = x;
            x = fa[x];
        }
        a[++m] = 1;
        reverse(a + 1, a + m + 1);
        for (i = 1; i <= m; i++) is[a[i]] = 1;
        for (i = 1; i < n; i++)
        {
            x = b[i].x; y = b[i].y;
            if (!is[x] || !is[y])
            {
                if (fa[x] == y) link2(y, x);
                else link2(x, y);
            }
        }
        for (i = 1; i <= m; i++) now_rt = a[i], dfs3(a[i]);
        all = h[all];
        for (i = 1; i <= n; i++) q[i] = h[cnt[i]];
        for (i = 1; i <= m; i++) dfs4(a[i], a[i]), fa[a[i]] = 0;
        
        bool flag = 0;
        if (L == 1) K = L, flag = 1, L++;
        else K = L - 1;
        que[K].x = 1;
        for (j = 1; j <= m; j++) 
        {
            int u = a[j];
            op = j & 1; dfs2(u);
            pair_mul(que[K], sub(q[u], f[u]));
        }
        cover(ans[K], que[K]);
        build(1, n, 1);
        for (i = L; i <= R; i++)
        {
            que[i] = que[i - 1];
            K = i;
            int lenc = c[i].size(), lend = d[i].size();
            for (j = 0; j < lenc; j++)
            {
                int u = c[i][j];
                f[u] = (u > w) + (u + K > w);
                change(u);
            }
            for (j = 0; j < lend; j++)
            {
                int u = d[i][j];
                f[u] = (u < w) + (u - K < w);
                change(u);
            }
            if (i == n)
            {
                ans[i] = sub(all, 1);
                continue;
            }
            cover(ans[i], que[i]);
        }
        if (flag) L--;
        for (i = L; i <= R; i++) 
        print(sub(ans[i], ans[i - 1])), putchar(i == R ? '\n' : ' ');
        return 0;
    }
    
  • 相关阅读:
    ubuntu安装jdk的两种方法
    LeetCode 606. Construct String from Binary Tree (建立一个二叉树的string)
    LeetCode 617. Merge Two Binary Tree (合并两个二叉树)
    LeetCode 476. Number Complement (数的补数)
    LeetCode 575. Distribute Candies (发糖果)
    LeetCode 461. Hamming Distance (汉明距离)
    LeetCode 405. Convert a Number to Hexadecimal (把一个数转化为16进制)
    LeetCode 594. Longest Harmonious Subsequence (最长的协调子序列)
    LeetCode 371. Sum of Two Integers (两数之和)
    LeetCode 342. Power of Four (4的次方)
  • 原文地址:https://www.cnblogs.com/cyf32768/p/12296954.html
Copyright © 2011-2022 走看看