zoukankan      html  css  js  c++  java
  • 洛谷P1600 天天爱跑步——树上差分

    题目:https://www.luogu.org/problemnew/show/P1600

    看博客:https://blog.csdn.net/clove_unique/article/details/53427248

    思路好神啊...

    树上差分是好东西。

    代码如下:

    #include<iostream>
    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    using namespace std;
    int const maxn=300005;
    int n,m,head[maxn],ct,w[maxn],tp,fp,ans[maxn],f[maxn][20],h[maxn],dfn[maxn],tim,a[maxn<<1];
    struct P{int t,pt,val;}tor[maxn<<2],fr[maxn<<2];
    struct N{
        int to,next;
        N(int t=0,int n=0):to(t),next(n) {}
    }edge[maxn<<1];
    void add(int x,int y){edge[++ct]=N(y,head[x]); head[x]=ct;}
    bool cmp(P x,P y){return dfn[x.pt]<dfn[y.pt];}
    void init(int x,int fa)
    {
        h[x]=h[fa]+1; f[x][0]=fa; dfn[x]=++tim;
        for(int i=1;i<=18;i++)
            f[x][i]=f[f[x][i-1]][i-1];
        for(int i=head[x],u;i;i=edge[i].next)
            if((u=edge[i].to)!=fa)init(u,x);
    }
    int lca(int x,int y)
    {
        if(h[x]<h[y])swap(x,y);
        int k=h[x]-h[y];
        for(int i=18;i>=0;i--)
            if(k&(1<<i))x=f[x][i];
        for(int i=18;i>=0;i--)
            if(f[x][i]!=f[y][i])x=f[x][i],y=f[y][i];
        if(x==y)return x;
        return f[x][0];
    }
    void dfs1(int x,int f)
    {
        int val=w[x]+h[x],dec=a[val];
        while(tp<=(m<<1) && tor[tp].pt==x) a[tor[tp].t + h[x]]+=tor[tp].val, tp++;
        for(int i=head[x],u;i;i=edge[i].next)
            if((u=edge[i].to)!=f)dfs1(u,x);
        ans[x]+=a[val]-dec;
    }
    void dfs2(int x,int f)
    {
        int val=w[x]-h[x]+1,dec=a[val];
        while(fp<=(m<<1) && fr[fp].pt==x) a[fr[fp].t]+=fr[fp].val, fp++;
        for(int i=head[x],u;i;i=edge[i].next)
            if((u=edge[i].to)!=f)dfs2(u,x);
        ans[x]+=a[val]-dec;
    }
    int main()
    {
        scanf("%d%d",&n,&m);
        for(int i=1,x,y;i<n;i++)
        {
            scanf("%d%d",&x,&y);
            add(x,y); add(y,x);
        }
        init(1,0);
        for(int i=1;i<=n;i++)scanf("%d",&w[i]);
        for(int i=1,s,t;i<=m;i++)
        {
            scanf("%d%d",&s,&t);
            int r=lca(s,t);
            tor[++tp].pt=s; tor[tp].t=0; tor[tp].val=1;
            if(r!=1)tor[++tp].pt=f[r][0], tor[tp].t=h[s]-h[r]+1, tor[tp].val=-1;
            fr[++fp].pt=t; fr[fp].t=h[s]-h[r]-h[r]+1; fr[fp].val=1;
            fr[++fp].pt=r; fr[fp].t=h[s]-h[r]-h[r]+1; fr[fp].val=-1;
        }
        
        sort(tor+1,tor+tp+1,cmp); tp=1;
        sort(fr+1,fr+fp+1,cmp); fp=1;
        dfs1(1,0);
        memset(a,0,sizeof a);
        dfs2(1,0);
        for(int i=1;i<=n;i++)printf("%d ",ans[i]);
        return 0;
    }
  • 相关阅读:
    Python中 sys.argv[]的用法简明解释
    Python--文件操作和集合
    Python--各种参数类型
    Python-- 字符串格式化 (%操作符)
    Python--绝对路径和相对路径
    Python3.x和Python2.x的区别
    关于ArrayList使用中的注意点
    Sun公司建议的Java类定义模板
    SWT组件之Table浅析
    mysql的相关命令整理(二)
  • 原文地址:https://www.cnblogs.com/Zinn/p/9216874.html
Copyright © 2011-2022 走看看