zoukankan      html  css  js  c++  java
  • 洛谷 P2664 树上游戏

    lrb有一棵树,树的每个节点有个颜色。给一个长度为n的颜色序列,定义(s(i,j))(i)(j)的颜色数量。以及

    [sum_i = sum_{j=1}^{n}s(i,j) ]

    现在他想让你求出所有的(sum[i])

    这题真是难,点分治神题

    我们考虑一个性质,对于一个点(i),如果它的颜色在到根的路径中是第一次出现,那么对于和(i)不在一个子树的点(j),对(j)都有(i)的子树大小(size_i)的贡献

    然后有了这个性质,就好做了

    找完重心后预处理出来实际的(size),用(sum)来记录所有点的贡献,(s)是这个颜色的贡献

    而我们不是用点去更新答案,是用颜色来更新答案,所以要枚举子树,去掉这个子树的贡献来统计答案

    于是再有(X)表示除了这个子树的点数和,(co)表示这个点到根的颜色数

    然后记录下这个点到根的所有颜色的(s)的和,(s)是要被减去的

    那么(ans+=sum-s+co imes X),然后单独更新一下根就是(ans+=sum-s_{c_{rt}}+size_{rt})

    Code

    #include <iostream>
    #include <cstdio>
    #include <algorithm>
    #include <vector>
    const int N = 1e5;
    using namespace std;
    int n,c[N + 5],size[N + 5],maxp[N + 5],rt,su,vis[N + 5],cnt[N + 5];
    long long sum,s[N + 5],ros,X,ans[N + 5];
    vector <int> d[N + 5];
    void get_rt(int u,int fa)
    {
        size[u] = 1;
        maxp[u] = 0;
        vector <int>::iterator it;
        for (it = d[u].begin();it != d[u].end();it++)
        {
            int v = (*it);
            if (v == fa || vis[v])
                continue;
            get_rt(v,u);
            size[u] += size[v];
            maxp[u] = max(maxp[u],size[v]);
        }
        maxp[u] = max(maxp[u],su - size[u]);
        if (maxp[u] < maxp[rt])
            rt = u;
    }
    void get_size(int u,int fa)
    {
        size[u] = 1;
        vector <int>::iterator it;
        for (it = d[u].begin();it != d[u].end();it++)
        {
            int v = (*it);
            if (v == fa || vis[v])  
                continue;
            get_size(v,u);
            size[u] += size[v];
        }
    }
    void dfs(int u,int fa,int w)
    {
        cnt[c[u]]++;
        if (cnt[c[u]] == 1)
        {
            s[c[u]] += w * size[u];
            sum += w * size[u];
        }
        if (!cnt[c[rt]])
            ros += w;
        vector <int>::iterator it;
        for (it = d[u].begin();it != d[u].end();it++)
        {
            int v = (*it);
            if (v == fa || vis[v])
                continue;
            dfs(v,u,w);
        }
        cnt[c[u]]--;
    }
    void upd(int u,int fa,int co,int su)
    {
        cnt[c[u]]++;
        if (cnt[c[u]] == 1)
        {
            co++;
            su += s[c[u]];
        }
        ans[u] += sum - su + co * X;
        if (!cnt[c[rt]])
            ans[u] += ros;
        vector <int>::iterator it;
        for (it = d[u].begin();it != d[u].end();it++)
        {
            int v = (*it);
            if (v == fa || vis[v])  
                continue;
            upd(v,u,co,su);
        }
        cnt[c[u]]--;
    }
    void calc(int u)
    {
        vector <int>::iterator it;
        for (it = d[u].begin();it != d[u].end();it++)
        {
            int v = (*it);
            if (vis[v])
                continue;
            dfs(v,u,1);
        }
        for (it = d[u].begin();it != d[u].end();it++)
        {
            int v = (*it);
            if (vis[v])
                continue;
            dfs(v,u,-1);
            X = size[u] - size[v];
            upd(v,0,0,0);
            dfs(v,u,1);
        }
        ans[u] += sum - s[c[u]] + size[u];
        for (it = d[u].begin();it != d[u].end();it++)
        {
            int v = (*it);
            if (vis[v])
                continue;
            dfs(v,u,-1);
        }
    }
    void solve(int u)
    {
        vis[u] = 1;
        ros = 1;
        get_size(u,0);
        calc(u);
        vector <int>::iterator it;
        for (it = d[u].begin();it != d[u].end();it++)
        {
            int v = (*it);
            if (vis[v])
                continue;
            maxp[0] = N;
            su = size[v];
            rt = 0;
            get_rt(v,0);
            solve(rt);
        }
    }
    int main()
    {
        scanf("%d",&n);
        for (int i = 1;i <= n;i++)
            scanf("%d",&c[i]);
        int u,v;
        for (int i = 1;i < n;i++)
        {
            scanf("%d%d",&u,&v);
            d[u].push_back(v);
            d[v].push_back(u);
        }
        su = n;
        maxp[0] = N;
        get_rt(1,0);
        get_size(rt,0);
        solve(rt);
        for (int i = 1;i <= n;i++)
            printf("%lld
    ",ans[i]);
        return 0;
    }
    
  • 相关阅读:
    Docker跨平台架构的新特性buildx的启用方式
    Linux 如何安装rvm和ruby
    Linux
    ubuntu安装 vmware workstation pro 15.1.1
    docker-compose搭建golang本地开发环境
    linux 常用命令
    leetcode 1046 最后一块石头的重量
    leetcode 330 按要求补齐数组
    MySQL 字符集与比较规则
    Python 是什么语言
  • 原文地址:https://www.cnblogs.com/sdlang/p/13068202.html
Copyright © 2011-2022 走看看