zoukankan      html  css  js  c++  java
  • 点分治

    题意

    后两个求和符号代表的是有多少异或值为0的路径。
    前两个符号代表有多少路径包含异或值为0的路径。
    即每个权值为0的路径对答案的贡献为 有多少路径包含当前路径 ,所有的贡献加起来就是答案。


    思路

    点分治,权值太大,桶只能用map了。
    x与y之间的路径(x与y间权值异或为0)对答案的贡献是
    (x可以扩展的点数)**(y可以扩展的点数)。
    1
    如上图,设2与3之间的权值异或为0,那么这条边对答案的贡献为2*3(2可以扩展到2、4两个点,3可以扩展到3、5、6这三个点)。

    桶里面加上的不是个数了,而是对于 当前重心 到 x 的 这一路径 ,x可以扩展到几个点。如果知道对于每条路径上面的端点可以扩展到多少点这题就可做了。

    预处理:让a作为根,遍历一遍,统计每个点子树的大小siz[x],与每个点的上一个点init_pre[x]。
    现在要求rt->…..->pre->x这一路径x可以扩展多少:
    1、init_pre[x]==pre,那么x可以扩展的点就是siz[x].
    2、init_pre[x]!=pre,那么x可以扩展的点是n-siz[pre]
    可以o(1)求出来。

    解题思路:

    树上路径计数问题,想当然就是点分治了,但是这道题并没有那么简单。

    以这个图为例:

    首先是路径两端点数的计算问题:

    假设我们现在点分治的过程中以3为重心,3-4和2-3都为2,因此2-4这个路径的异或和为0,我们期望用2号节点及其上面的点的数量乘以4号节点以及其下面的点的数量来得到2-4这条路径的贡献。但是点分治里的sz数组是不固定且不正确的,会随着重心的不同发生变化,所以我们用最初第一次以1为根获得的siz进行处理,在第一次getrt的过程中记录每个节点的前驱节点,如果在这次遍历过程中前驱节点与第一次getrt过程中的前驱节点相同的话,代表当前方向与第一次getrt时的方向相同,就可以直接用第一次getrt处理到的siz,否则说明方向相反,就要用总的点数 n 减去当前节点前驱节点的siz获得结果,比如2号节点及其上方的点数就可以用总点数5减去2的前驱3及其3下方点的数量得到。

    第二个问题是当前重心到某一点的异或和直接为零的情况,需要特殊处理:

    假设图中4-5的权值为2,我们依然假设3为当前重心,这样3-5的异或和为0,我们要单独去判断一下重心另一侧的点的数量再相乘(因为这一部分再遍历map的时候是没有处理的)

    #include<bits/stdc++.h>
    #include<>
    using namespace std;
    #define ll long long
    #define maxn 100100
    #define inf 0x3f3f3f3f
    #define mod 1000000007
    struct node
    {
        ll v,w,to;
    } edge[maxn*2];
    struct data
    {
        ll w,sum;
    }temp[maxn];
    bool vis[maxn];
    int head[maxn],cnt,n,rt,a,tot;
    ll sum,sz[maxn],maxx[maxn],ans,b,sz_rt;
    ll dis[maxn],siz[maxn],init_pre[maxn];
    unordered_map<ll,ll>mp;
    void init()
    {
        memset(head,-1,sizeof(head));
        memset(vis,0,sizeof(vis));
        cnt=ans=0;
    }
    void add(int u,int v,ll w)
    {
        edge[cnt]={v,w,head[u]};
        head[u]=cnt++;
        edge[cnt]={u,w,head[v]};
        head[v]=cnt++;
    }
    void dfs(int x,int pre)
    {
        siz[x]=1;
        init_pre[x]=pre;
        for(int i=head[x];i!=-1;i=edge[i].to)
        {
            int v=edge[i].v;
            if(v!=pre)
            {
                dfs(v,x);
                siz[x]+=siz[v];
            }
        }
    }
    void getrt(int x,int pre)
    {
        sz[x]=1;
        maxx[x]=0;
        for(int i=head[x];i!=-1;i=edge[i].to)
        {
            int v=edge[i].v;
            if(v!=pre&&!vis[v])
            {
                getrt(v,x);
                sz[x]+=sz[v];
                maxx[x]=max(maxx[x],sz[v]);
            }
        }
        maxx[x]=max(maxx[x],sum-sz[x]);
        if(maxx[x]<maxx[rt])rt=x;
    }
    void getdis(int x,int pre)
    {
        if(pre==init_pre[x])
        {
            ans=(ans+siz[x]*mp[dis[x]])%mod;
            temp[++tot]= {dis[x],siz[x]};
            if(dis[x]==0)
            ans=(ans+siz[x]*sz_rt)%mod;
        }
        else
        {
            ans=(ans+1ll*(n-siz[pre])*mp[dis[x]])%mod;
            temp[++tot]= {dis[x],n-siz[pre]};
            if(dis[x]==0)
            ans=(ans+1ll*(n-siz[pre])*sz_rt)%mod;
        }
        for(int i=head[x];i!=-1;i=edge[i].to)
        {
            int v=edge[i].v;
            ll w=edge[i].w;
            if(v==pre||vis[v])continue;
            dis[v]=dis[x]^w;
            getdis(v,x);
        }
    }
    void cal(int x)
    {
        dis[x]=0;
        for(int i=head[x];i!=-1;i=edge[i].to)
        {
            int v=edge[i].v;
            ll w=edge[i].w;
            if(vis[v])continue;
            if(init_pre[v]==x)sz_rt=n-siz[v];
            else sz_rt=siz[rt];
            dis[v]=w;
            tot=0;
            getdis(v,x);
            for(int j=1;j<=tot;j++)
            {
                mp[temp[j].w]+=temp[j].sum;
                mp[temp[j].w]%=mod;
            }
        }
        mp.clear();
    }
    void solve(int x)
    {
        vis[x]=1;
        cal(x);
        for(int i=head[x];i!=-1;i=edge[i].to)
        {
            int v=edge[i].v;
            if(vis[v])continue;
            maxx[rt=0]=inf;
            sum=sz[v];
            getrt(v,0);
            solve(rt);
        }
    }
    int main()
    {
        init();
        scanf("%d",&n);
        for(int i=2;i<=n;i++)
        {
            scanf("%d%lld",&a,&b);
            add(i,a,b);
        }
        dfs(1,0);
        maxx[rt=0]=inf;
        sum=n;
        getrt(1,0);
        solve(rt);
        printf("%lld
    ",ans%mod);
        return 0;
    }
    

      

  • 相关阅读:
    表达式与运算符
    Python3 从零单排22_异常处理
    Python3 从零单排21_元类
    Python3 从零单排20_方法(绑定&内置)&反射
    Python3 从零单排19_组合&多态
    Python3 从零单排18_封装
    Python3 从零单排17_类的继承
    Python3 从零单排16_面向对象基础
    Python3 从零单排15_urllib和requests模块
    Python3 从零单排14_flask模块&mysql操作封装
  • 原文地址:https://www.cnblogs.com/SDUTNING/p/10936244.html
Copyright © 2011-2022 走看看