zoukankan      html  css  js  c++  java
  • 【数据结构】树上启发式合并

    https://codeforces.com/contest/375/problem/D

    这道题也有用树上莫队的解法,这里再给一个树上启发式合并的解法。

    树上启发式合并,类似轻重链剖分,先算出每个节点的重儿子,然后计算答案时先递归计算轻儿子的答案,标记clr为true,然后计算重儿子的答案,clr为false,然后把轻儿子的节点暴力插到重儿子里,最后把父节点自己加入。

    int n;
    vector<int> G[MAXN];
    int siz[MAXN], mch[MAXN];
    
    void dfs1(int u, int p) {
        siz[u] = 1, mch[u] = 0;
        for (int &v : G[u]) {
            if (v == p)
                continue;
            dfs1(v, u);
            siz[u] += siz[v];
            if (siz[mch[u]] < siz[v])
                mch[u] = v;
        }
    }
    

    由于树上启发式合并并不关心深度,所以没有必要维护深度。

    void calc(int u, int p, int skip, int d) {
        bit.Add(cnt[c[u]], -1);
        cnt[c[u]] += d;
        bit.Add(cnt[c[u]], 1);
        for (int v : G[u]) {
            if (v == p || v == skip)
                continue;
            calc(v, u, 0, d);
        }
    }
    
    void dfs2(int u, int p, bool keep) {
        for (int &v : G[u]) {
            if (v == p || v == mch[u])
                continue;
            dfs2(v, u, false);
        }
        if (mch[u])
            dfs2(mch[u], u, true);
        calc(u, p, mch[u], 1);
        for (pii &q : Q[u]) {
            int id = q.first, k = q.second;
            ans[q.first] = bit.Sum(k, n);
        }
        if (!keep)
            calc(u, p, 0, -1);
    }
    

    然后是主要的计算过程dfs2,dfs2优先进入所有的轻儿子,并且不keep轻儿子的答案,保持树状数组为空。然后进入重儿子计算并keep重儿子的结果。这里使用一个辅助函数calc,calc的修改值为1时表示向树状数组中添加,然后命令其在添加时skip掉重儿子。计算完毕后树状数组中存着这棵子树对应的状态,然后取出所有的询问进行回答。那之后,假如不keep树状数组,调用calc修改值为-1,并且不跳过重儿子,把整棵子树删除干净。

    时间复杂度为 (O(nlog^2n))

    https://codeforces.com/gym/102832/problem/F

    这里的查询要去重,所以要先计算再查询。而且要注意cache的命中。一次树遍历就统计出所有的信息,把常用的局部值放在数组的低维。

    int n, k;
    int a[MAXN];
    
    vector<int> G[MAXN];
    int siz[MAXN], mch[MAXN];
    
    void dfs1(int u, int p) {
        siz[u] = 1, mch[u] = 0;
        for (int &v : G[u]) {
            if (v == p)
                continue;
            dfs1(v, u);
            siz[u] += siz[v];
            if (siz[mch[u]] < siz[v])
                mch[u] = v;
        }
    }
    
    int cnt[1 << 20][17][2];
    ll ans;
    
    void calc1(int u, int p, int LCA) {
        int val = a[u] ^ a[LCA];
        for (int k = 16; k >= 0; --k) {
            int uk = (u >> k) & 1;
            ans += (1LL << k) * cnt[val][k][uk ^ 1];
        }
        for (int &v : G[u]) {
            if (v == p)
                continue;
            calc1(v, u, LCA);
        }
    }
    
    void calc2(int u, int p) {
        int val = a[u];
        for (int k = 16; k >= 0; --k) {
            int uk = (u >> k) & 1;
            ++cnt[val][k][uk];
        }
        for (int &v : G[u]) {
            if (v == p)
                continue;
            calc2(v, u);
        }
    }
    
    void calc3(int u, int p) {
        int val = a[u];
        memset(cnt[val], 0, sizeof(cnt[val]));
        for (int &v : G[u]) {
            if (v == p)
                continue;
            calc3(v, u);
        }
    }
    
    void dfs2(int u, int p, bool keep) {
        for (int &v : G[u]) {
            if (v == p || v == mch[u])
                continue;
            dfs2(v, u, false);
        }
        if (mch[u])
            dfs2(mch[u], u, true);
        int val = a[u];
        for (int k = 16; k >= 0; --k) {
            int uk = (u >> k) & 1;
            ans += (1LL << k) * cnt[0][k][uk ^ 1];
            ++cnt[val][k][uk];
        }
        for (int &v : G[u]) {
            if (v == p || v == mch[u])
                continue;
            calc1(v, u, u);
            calc2(v, u);
        }
        if (!keep) {
            memset(cnt[val], 0, sizeof(cnt[val]));
            for (int &v : G[u]) {
                if (v == p)
                    continue;
                calc3(v, u);
            }
        }
    }
    
    void solve() {
        scanf("%d", &n);
        for (int i = 1; i <= n; ++i)
            scanf("%d", &a[i]);
        for (int i = 1; i <= n; ++i)
            G[i].clear();
        for (int i = 1; i <= n - 1; ++i) {
            int u, v;
            scanf("%d%d", &u, &v);
            G[u].push_back(v);
            G[v].push_back(u);
        }
        dfs1(1, 0);
        dfs2(1, 0, true);
        printf("%lld
    ", ans);
    }
    
  • 相关阅读:
    构建之法阅读笔记07
    7-第一阶段SCRUM冲刺
    第一阶段个人冲刺博客第十天
    第一阶段个人冲刺博客第九天
    第九周学习进度博客
    java项目(学习和研究)
    让计算机干活
    os基础
    树和图的一些算法
    java代码理解
  • 原文地址:https://www.cnblogs.com/purinliang/p/14545169.html
Copyright © 2011-2022 走看看