zoukankan      html  css  js  c++  java
  • 【BZOJ】3626 [LNOI2014]LCA

    【算法】树链剖分+线段树(区间加值,区间求和)

    【题解】http://hzwer.com/3891.html

    中间不要取模不然相减会出错。

    血的教训:线段树修改时标记下传+上传,查询时下传。如果修改时标记不下传,下面的结果就会覆盖上面的标记上传造成的影响。

    读入后全部排序(离线处理)

    链剖之后按顺序每个solve_insert(1,j),对于每次的z询问solve_sum(1,z)。

    LCA其实就是两点到达根节点的路径的最近交点。

    差分思想的运用:将区间差转为r-(l-1)。

    #include<cstdio>
    #include<algorithm>
    #include<cstring>
    using namespace std;
    const int maxn=50010,MOD=201314;
    int n,first[maxn],tot=0,top[maxn],deep[maxn],pos[maxn],q,ansz[maxn],size[maxn],f[maxn],dfsnum=0;
    long long anss[maxn*3];
    struct edge{int from,v;}e[maxn*3];
    struct node{int l,r,delta,sum;}t[maxn*3];
    struct numbers{int num,ord;bool flag;}num[maxn*3];
    void insert(int u,int v)
    {tot++;e[tot].v=v;e[tot].from=first[u];first[u]=tot;}
    bool cmp(numbers a,numbers b)
    {return a.num<b.num;}
    void dfs1(int x,int fa)
    {
        size[x]=1;
        for(int i=first[x];i;i=e[i].from)
         if(e[i].v!=fa)
          {
              int y=e[i].v;
              f[y]=x;
              deep[y]=deep[x]+1;
              dfs1(y,x);
              size[x]+=size[y];
          }
    }
    void dfs2(int x,int tp,int fa)
    {
        pos[x]=++dfsnum;
        top[x]=tp;
        int k=0;
        for(int i=first[x];i;i=e[i].from)
         if(e[i].v!=fa&&size[e[i].v]>size[k])k=e[i].v;
        if(k==0)return;
        dfs2(k,tp,x);
        for(int i=first[x];i;i=e[i].from)
         if(e[i].v!=fa&&e[i].v!=k)dfs2(e[i].v,e[i].v,x);
        
    }
    void build(int k,int l,int r)
    {
        t[k].l=l;t[k].r=r;t[k].delta=0;t[k].sum=0;
        if(l==r)return;
        int mid=(l+r)>>1;
        build(k<<1,l,mid);
        build(k<<1|1,mid+1,r);
        t[k].sum=t[k<<1].sum+t[k<<1|1].sum;
    }
    void add(int k,int l,int r)
    {
        int left=t[k].l,right=t[k].r;
        if(l<=left&&right<=r)
         {
             t[k].delta++;
             t[k].sum+=right-left+1;
             return;
         }
        if(t[k].delta)
         {
             t[k<<1].delta+=t[k].delta;
             t[k<<1].sum+=(t[k<<1].r-t[k<<1].l+1)*t[k].delta;
             t[k<<1|1].delta+=t[k].delta;
             t[k<<1|1].sum+=(t[k<<1|1].r-t[k<<1|1].l+1)*t[k].delta;
             t[k].delta=0;
         }
        int mid=(left+right)>>1;
        if(l<=mid)add(k<<1,l,r);
        if(r>mid)add(k<<1|1,l,r);
        t[k].sum=t[k<<1].sum+t[k<<1|1].sum;
    }
    long long query(int k,int l,int r)
    {
        int left=t[k].l,right=t[k].r;
        if(l<=left&&right<=r)return t[k].sum;
        if(t[k].delta)
         {
             t[k<<1].delta+=t[k].delta;
             t[k<<1].sum+=(t[k<<1].r-t[k<<1].l+1)*t[k].delta;
             t[k<<1|1].delta+=t[k].delta;
             t[k<<1|1].sum+=(t[k<<1|1].r-t[k<<1|1].l+1)*t[k].delta;
             t[k].delta=0;
         }
        int mid=(left+right)>>1;
        long long ans=0;
        if(l<=mid)ans=query(k<<1,l,r);
        if(r>mid)ans+=query(k<<1|1,l,r);
        return ans;
    }
    void solve_ins(int x,int y)
    {
        while(top[x]!=top[y])
         {
             if(deep[top[x]]<deep[top[y]])swap(x,y);
             add(1,pos[top[x]],pos[x]);
             x=f[top[x]];
         }
        if(pos[x]>pos[y])swap(x,y);
        add(1,pos[x],pos[y]);
    }
    long long solve_sum(int x,int y)
    {
        long long ans=0;
        while(top[x]!=top[y])
         {
             if(deep[top[x]]<deep[top[y]])swap(x,y);
             ans+=query(1,pos[top[x]],pos[x]);
             x=f[top[x]];
         }
        if(pos[x]>pos[y])swap(x,y);
        ans+=query(1,pos[x],pos[y]);
        return ans;
    }
    int main()
    {
        scanf("%d%d",&n,&q);
        int u;
        for(int i=2;i<=n;i++)
         {
             scanf("%d",&u);
             insert(u+1,i);
             insert(i,u+1);
         }
        int ll,rr,zz;
        for(int i=1;i<=q;i++)
         {
             scanf("%d%d%d",&ll,&rr,&zz);ll++;rr++;zz++;
             ansz[i]=zz;
             num[i*2-1].num=ll-1;num[i*2-1].ord=i;num[i*2-1].flag=0;
             num[i*2].num=rr;num[i*2].ord=i;num[i*2].flag=1;
         }
        sort(num+1,num+q*2+1,cmp);
        build(1,1,n);dfs1(1,-1);dfs2(1,1,-1);
        int now=0;
        memset(anss,0,sizeof(anss));
        for(int i=1;i<=q*2;i++)
         {
             if(num[i].num>now)
              for(int j=now+1;j<=num[i].num;j++)solve_ins(1,j);
             now=num[i].num;
             if(num[i].flag)anss[num[i].ord]+=solve_sum(1,ansz[num[i].ord]);
              else anss[num[i].ord]-=solve_sum(1,ansz[num[i].ord]);
         }
        
        for(int i=1;i<=q;i++)printf("%lld
    ",anss[i]%MOD);
        return 0;
    }
    View Code
  • 相关阅读:
    面向对象课程第三次博客总结
    面向对象课程多线程总结
    23种设计模式整理
    java中synchronized与lock的理解与应用
    关于MySQL查询优化
    mysql操作规范建议
    Linux中实体链接与符号链接详解
    获取本地ipv4地址方法(go语言)
    分库分表与负载均衡的一致性hash算法
    golang闭包的一个经典例子
  • 原文地址:https://www.cnblogs.com/onioncyc/p/6750363.html
Copyright © 2011-2022 走看看