zoukankan      html  css  js  c++  java
  • 树上问题



    // luogu-judger-enable-o2
    #include<iostream>
    #include<cstdio>
    #include<algorithm>
    #include<cstring>
    #include<vector>
    #include<cstdlib>
    #include<cmath>
    using namespace std;
    const int N  = 3e5+1e3;
    int read(){
        int q=0;char ch=' ';
        while(ch<'0'||ch>'9')ch=getchar();
        while(ch>='0'&&ch<='9')q=q*10+ch-'0',ch=getchar();
        return q;
    }
    struct Edge{int next,to;}e[N<<1];
    int n,k,last[N],edge_number;
    int size[N],fa[N],id[N],low[N],q[N],tot,d[N],top[N],kkk[N];
    vector<pair<int,int > >v1[N],v2[N];
    int cov[N<<2];
    pair<int,int> sum[N<<2];
    void add(int a,int b){
        e[++edge_number].next=last[a],last[a]=edge_number;e[edge_number].to=b;
    }
    void dfs1(int x,int f){
        size[x]=1;fa[x]=f;d[x]=d[f]+1;
        for(int i=last[x];i;i=e[i].next){
            if(e[i].to==f) continue;
            dfs1(e[i].to,x);
            size[x]+=size[e[i].to];
            if(size[e[i].to]>size[kkk[x]]) kkk[x]=e[i].to;
        }
    }
    void dfs2(int x,int topf){
        id[x]=++tot,q[tot]=x;
        top[x]=topf;
        if(kkk[x]) dfs2(kkk[x],topf);
        for(int i=last[x];i;i=e[i].next){
            if(e[i].to==fa[x]||e[i].to==kkk[x]) continue;
            dfs2(e[i].to,e[i].to);
        }
        low[x]=tot;
    }
    int lca(int a,int b){
        while(top[a]!=top[b]){
            if(d[top[a]]<d[top[b]]) swap(a,b);
            a=fa[top[a]];
        }
        return d[a]>d[b]?b:a; 
    }
    void insert(int a,int b,int c,int d){
        v1[a].push_back(make_pair(c,d));
        v2[b+1].push_back(make_pair(c,d));
    }
    int getanc(int x,int ddd){
        while(d[top[x]]>ddd+1)
            {x=fa[top[x]];}
        return q[id[top[x]]+ddd+1-d[top[x]]];
    }
    void deal(int a,int b){
        int z=lca(a,b);
        if(z!=a&&z!=b){
            insert(id[a],low[a],id[b],low[b]);
            insert(id[b],low[b],id[a],low[a]);
        }
        else{
            if(b==z) std::swap(a,b);
            a=getanc(b,d[a]);
            if(id[a]>1){
                insert(1,id[a]-1,id[b],low[b]);
                insert(id[b],low[b],1,id[a]-1); 
            }
            if(low[a]<n){
                insert(low[a]+1,n,id[b],low[b]);
                insert(id[b],low[b],low[a]+1,n);
            }
        }
    }
    void pushup(int cur){
        if(sum[cur<<1].first==sum[cur<<1|1].first){
            sum[cur].first=sum[cur<<1].first;
            sum[cur].second=sum[cur<<1].second+sum[cur<<1|1].second;
        }
        else{
            sum[cur]=max(sum[cur<<1],sum[cur<<1|1]);
        }
    }
    void build(int cur,int l,int r){
        if(l==r){
            sum[cur].first=1;
            sum[cur].second=1;
            return;
        }
        int mid=(l+r)>>1;
        build(cur<<1,l,mid);
        build(cur<<1|1,mid+1,r);
        pushup(cur);
    }
    void update(int cur,int v){
        cov[cur]+=v;
        sum[cur].first+=v;
    }
    void pushdown(int cur){
        if(cov[cur]!=0){
            update(cur<<1,cov[cur]);
            update(cur<<1|1,cov[cur]);
            cov[cur]=0;
        }
    }
    
    void add(int cur,int l,int r,int L,int R,int v){
        if(L<=l&&r<=R){
            sum[cur].first+=v;
            cov[cur]+=v;
            return ;
        }
        pushdown(cur);
        int mid=(l+r)>>1;
        if(L<=mid) add(cur<<1,l,mid,L,R,v);
        if(R>mid) add(cur<<1|1,mid+1,r,L,R,v);
        pushup(cur);
    }
    int main(){
        n=read(),k=read();
        for(int i=1;i<n;++i){
            int a=read(),b=read();
            add(a,b),add(b,a);
        }
        dfs1(1,0);
        dfs2(1,1);
        for(int i=1;i<=n;++i){
            for(int j=i+1;j<=std::min(n,i+k);++j){
                deal(i,j);
            }
        }
        build(1,1,n);
        long long ans=0;
        for(int i=1;i<=n;++i){
            int S=v1[i].size();
            for(int j=0;j<S;++j){
                add(1,1,n,v1[i][j].first,v1[i][j].second,-1);
            }
            S=v2[i].size();
            for(int j=0;j<S;++j){
                add(1,1,n,v2[i][j].first,v2[i][j].second,1);
            }
            if(sum[1].first==1){
                ans+=sum[1].second;
            }
        }
        printf("%lld",(ans+n)/2);
    }
    
  • 相关阅读:
    Linux 服务器连接远程数据库(Mysql、Pgsql)
    oracle主键自增
    全排列算法实现
    python动态导入包
    python发红包实现
    CentOS 6.8安装Oracle 11 g 解决xhost: unable to open display
    xargs的一个小坑
    利用ssh-copy-id复制公钥到多台服务器
    redhat 5 更换yum源
    【原创】Hadoop的IO模型(数据序列化,文件压缩)
  • 原文地址:https://www.cnblogs.com/fengxunling/p/9789440.html
Copyright © 2011-2022 走看看