zoukankan      html  css  js  c++  java
  • 2019ICPC上海F A Simple Problem On A Tree(树链剖分)

    这道题明显就是告诉你就是树链剖分+线段树维护三次方和,那么显然就是拆项后发现维护一次方和,二次方和和三次方和

    这里涉及到两个操作,一个是add一个是mul

    因此我们要考虑优先级,这是洛谷的线段树模板2,要先mul再add,因为这样可以解决先加后乘的问题

    #include<bits/stdc++.h>
    using namespace std;
    typedef long long ll;
    const int N=2e5;
    const int mod=1e9+7;
    int h[N],ne[N],e[N],idx;
    int son[N],pre[N],id[N],sz[N],fa[N];
    int n;
    int depth[N],top[N],times;
    ll w[N];
    struct node{
        int l,r;
        ll mul;
        ll ad;
        ll sum1;
        ll sum2;
        ll sum3;
    }tr[N<<2];
    void add(int a,int b){
        e[idx]=b,ne[idx]=h[a],h[a]=idx++;
    }
    void dfs(int u){
        int i;
        sz[u]=1;
        for(i=h[u];i!=-1;i=ne[i]){
            int j=e[i];
            if(j==fa[u])
                continue;
            fa[j]=u;
            depth[j]=depth[u]+1;
            dfs(j);
            sz[u]+=sz[j];
            if(sz[j]>sz[son[u]]){
                son[u]=j;
            }
        }
    }
    void dfs1(int u,int x){
        pre[u]=++times;
        id[times]=u;
        top[u]=x;
        if(!son[u])
            return;
        dfs1(son[u],x);
        int i;
        for(i=h[u];i!=-1;i=ne[i]){
            int j=e[i];
            if(j==fa[u]||j==son[u])
                continue;
            dfs1(j,j);
        }
    }
    void pushup(int u){
        tr[u].sum1=(tr[u<<1].sum1+tr[u<<1|1].sum1)%mod;
        tr[u].sum2=(tr[u<<1].sum2+tr[u<<1|1].sum2)%mod;
        tr[u].sum3=(tr[u<<1].sum3+tr[u<<1|1].sum3)%mod;
    }
    void build(int u,int l,int r){
        if(l==r){
            tr[u]={l,r,1,0,w[id[l]],w[id[l]]*w[id[l]]%mod,w[id[l]]*w[id[l]]%mod*w[id[l]]%mod};
        }
        else{
            tr[u]={l,r,1,0,0,0,0};
            int mid=l+r>>1;
            build(u<<1,l,mid);
            build(u<<1|1,mid+1,r);
            pushup(u);
        }
    }
    void down(int u,ll x,ll y){
        if(y!=1){
            tr[u].sum3=(tr[u].sum3*y%mod*y%mod*y)%mod;
            tr[u].sum2=(tr[u].sum2*y%mod*y)%mod;
            tr[u].sum1=(tr[u].sum1*y%mod)%mod;
            tr[u].mul=tr[u].mul*y%mod;
            tr[u].ad=tr[u].ad*y%mod;
        }
        if(x!=0){
            tr[u].sum3=(tr[u].sum3+3ll*x*tr[u].sum2+3*x%mod*x%mod*tr[u].sum1+(tr[u].r-tr[u].l+1)*x%mod*x%mod*x)%mod;
            tr[u].sum2=(tr[u].sum2+(tr[u].r-tr[u].l+1)*x%mod*x+2*tr[u].sum1*x)%mod;
            tr[u].sum1=(tr[u].sum1+(tr[u].r-tr[u].l+1)*x%mod)%mod;
            tr[u].ad=(tr[u].ad+x)%mod;
        }
    }
    void pushdown(int u){
        ll y=tr[u].mul,x=tr[u].ad;
        down(u<<1,x,y);
        down(u<<1|1,x,y);
        tr[u].mul=1;
        tr[u].ad=0;
    }
    void modify(int u,int l,int r,ll x,int opt){
        if(tr[u].l>=l&&tr[u].r<=r){
            if(opt==1){
                tr[u].sum1=(tr[u].r-tr[u].l+1)*x%mod;
                tr[u].sum2=(tr[u].r-tr[u].l+1)*x%mod*x%mod;
                tr[u].sum3=(tr[u].r-tr[u].l+1)*x%mod*x%mod*x%mod;
                tr[u].mul=0;
                tr[u].ad=x;
            }
            else if(opt==2){
                tr[u].sum3=(tr[u].sum3+3ll*x*tr[u].sum2+3*x%mod*x%mod*tr[u].sum1+(tr[u].r-tr[u].l+1)*x%mod*x%mod*x)%mod;
                tr[u].sum2=(tr[u].sum2+(tr[u].r-tr[u].l+1)*x%mod*x+2*tr[u].sum1*x)%mod;
                tr[u].sum1=(tr[u].sum1+(tr[u].r-tr[u].l+1)*x%mod)%mod;
                tr[u].ad=(tr[u].ad+x)%mod;
            }
            else if(opt==3){
                tr[u].sum3=(tr[u].sum3*x%mod*x%mod*x)%mod;
                tr[u].sum2=(tr[u].sum2*x%mod*x)%mod;
                tr[u].sum1=(tr[u].sum1*x%mod)%mod;
                tr[u].mul=tr[u].mul*x%mod;
                tr[u].ad=(tr[u].ad*x)%mod;
            }
            return ;
        }
        pushdown(u);
        int mid=tr[u].l+tr[u].r>>1;
        if(l<=mid)
            modify(u<<1,l,r,x,opt);
        if(r>mid)
            modify(u<<1|1,l,r,x,opt);
        pushup(u);
    }
    void change(int x,int y,ll z,int opt){
        while(top[x]!=top[y]){
            if(depth[top[x]]<depth[top[y]])
                swap(x,y);
            modify(1,pre[top[x]],pre[x],z,opt);
            x=fa[top[x]];
        }
        if(depth[x]>depth[y])
            swap(x,y);
        modify(1,pre[x],pre[y],z,opt);
    }
    ll query(int u,int l,int r){
        if(tr[u].l>=l&&tr[u].r<=r){
            return tr[u].sum3;
        }
        pushdown(u);
        int mid=tr[u].l+tr[u].r>>1;
        ll ans=0;
        if(l<=mid)
            ans+=query(u<<1,l,r);
        ans%=mod;
        if(r>mid)
            ans=(ans+query(u<<1|1,l,r))%mod;
        return ans;
    }
    ll qpath(int x,int y){
        ll res=0;
        while(top[x]!=top[y]){
            if(depth[top[x]]<depth[top[y]])
                swap(x,y);
            res=(res+query(1,pre[top[x]],pre[x]))%mod;
            x=fa[top[x]];
        }
        if(depth[x]>depth[y])
            swap(x,y);
        res=res+query(1,pre[x],pre[y]);
        res%=mod;
        return res;
    }
    int main(){
        //ios::sync_with_stdio(false);
        int cas=0;
        int t;
        cin>>t;
        while(t--){
            idx=0;
            scanf("%d",&n);
            memset(h,-1,sizeof h);
            memset(sz,0,sizeof sz);
            memset(son,0,sizeof son);
            memset(depth,0,sizeof depth);
            memset(id,0,sizeof id);
            memset(fa,0,sizeof fa);
            memset(top,0,sizeof top);
            times=0;
            int i;
            printf("Case #%d: 
    ",++cas);
            for(i=1;i<n;i++){
                int a,b;
                scanf("%d%d",&a,&b);
                add(a,b);
                add(b,a);
            }
            for(i=1;i<=n;i++)
                scanf("%lld",&w[i]);
            depth[1]=1;
            fa[1]=0;
            dfs(1);
            dfs1(1,1);
            build(1,1,n);
            int q;
            scanf("%d",&q);
            while(q--){
                int opt;
                scanf("%d",&opt);
                ll u,v,w;
                if(opt==1){
                    scanf("%lld%lld%lld",&u,&v,&w);
                    change(u,v,w,1);
                }
                else if(opt==2){
                    scanf("%lld%lld%lld",&u,&v,&w);
                    change(u,v,w,2);
                }
                else if(opt==3){
                    scanf("%lld%lld%lld",&u,&v,&w);
                    change(u,v,w,3);
                }
                else{
                    scanf("%lld%lld",&u,&v);
                    printf("%lld
    ",qpath(u,v)%mod);
                }
            }
        }
        return 0;
    }
    View Code
    没有人不辛苦,只有人不喊疼
  • 相关阅读:
    深入浅出数据库索引原理
    Mysql读写分离原理及主众同步延时如何解决
    数据库连接池实现原理
    MySQL 大表优化方案(长文)
    js-ajax-03
    js-ajax-04
    js-ajax-02
    js-ajax-01
    获取html对象方式
    js-事件总结
  • 原文地址:https://www.cnblogs.com/ctyakwf/p/14076709.html
Copyright © 2011-2022 走看看