zoukankan      html  css  js  c++  java
  • BZOJ2599:[IOI2011]Race

    浅谈树分治:https://www.cnblogs.com/AKMer/p/10014803.html

    题目传送门:https://www.lydsy.com/JudgeOnline/problem.php?id=2599

    我们设(f_i)为长度为(i)的路径边数最小可以是多少,依次遍历当前根的子树,先用(cnt+f[k-dis])更新答案,再遍历第二遍当前子树更新(f)数组。(cnt)表示根到当前点一共经过了多少条边。因为(k)的范围是(10^6)级别的,每次处理当前联通块前把(f)数组全部赋成极大值会很慢,所以我们每次更新(f)数组的时候把被改动过的(dis)用栈记录下来,每次处理完当前联通块就弹栈并且把相应的(f)数组初始化,这样做就是(O(n))级别的了。

    如果用边分做的话,记得把新建的边权值赋成(-1),因为可能会有边权为(0)的边,然后统计(cnt)的时候只有在碰到权值不为(-1)的边才算。

    时间复杂度:(O(nlogn))

    空间复杂度:(O(n))

    点分治版代码如下:

    #include <cstdio>
    #include <cstring>
    #include <algorithm>
    using namespace std;
    
    const int maxn=2e5+5,maxm=1e6+5;;
    
    bool vis[maxn],insta[maxm];
    int n,m,tot,mx,rt,N,ans,top;
    int siz[maxn],f[maxm],sta[maxn];
    int now[maxn],pre[maxn*2],son[maxn*2],val[maxn*2];
    
    int read() {
        int x=0,f=1;char ch=getchar();
        for(;ch<'0'||ch>'9';ch=getchar())if(ch=='-')f=-1;
        for(;ch>='0'&&ch<='9';ch=getchar())x=x*10+ch-'0';
        return x*f;
    }
    
    void add(int a,int b,int c) {
        pre[++tot]=now[a];
        now[a]=tot,son[tot]=b,val[tot]=c;
    }
    
    void find_rt(int fa,int u) {
        int res=0;siz[u]=1;
        for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
            if(!vis[v]&&v!=fa)find_rt(u,v),siz[u]+=siz[v],res=max(res,siz[v]);
        res=max(res,N-siz[u]);
        if(res<mx)mx=res,rt=u;
    }
    
    void query(int fa,int u,int cnt,int dis) {
        if(dis<=m)ans=min(ans,cnt+f[m-dis]);siz[u]=1;
        for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
            if(!vis[v]&&v!=fa)query(u,v,cnt+1,dis+val[p]),siz[u]+=siz[v];
    }
    
    void solve(int fa,int u,int cnt,int dis) {
        if(dis>m)return;
        f[dis]=min(f[dis],cnt);
        if(!insta[dis])sta[++top]=dis,insta[dis]=1;
        for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
            if(!vis[v]&&v!=fa)solve(u,v,cnt+1,dis+val[p]);
    }
    
    void work(int u,int size) {
        N=size,mx=rt=n+1,find_rt(0,u),u=rt,vis[u]=1;
        for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
            if(!vis[v])query(u,v,1,val[p]),solve(u,v,1,val[p]);
        ans=min(ans,f[m]);
        while(top) {
            f[sta[top]]=f[m+1];
            insta[sta[top--]]=0;
        }
        for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
            if(!vis[v])work(v,siz[v]);
    }
    
    int main() {
        n=read(),m=read();
        for(int i=1;i<n;i++) {
            int a=read()+1,b=read()+1,c=read();
            add(a,b,c),add(b,a,c);
        }
        memset(f,127/3,sizeof(f));
        ans=f[0];work(1,n);
        if(ans==f[m+1])puts("-1");
        else printf("%d
    ",ans);
        return 0;
    }
    

    边分治版代码如下:

    #include <cmath>
    #include <cstdio>
    #include <vector>
    #include <cstring>
    #include <algorithm>
    using namespace std;
    typedef pair<int,int> pii;
    #define fr first
    #define sc second
    
    const int maxn=4e5+5,maxm=1e6+5;
    
    bool vis[maxn],insta[maxm];
    int f[maxm],siz[maxn],sta[maxn];
    int n,m,tot,cnt,mx,id,N,top,ans,tmp1,tmp2;
    int now[maxn],pre[maxn*2],son[maxn*2],val[maxn*2];
    
    vector<pii>to[maxn];
    vector<pii>::iterator it;
    
    int read() {
        int x=0,f=1;char ch=getchar();
        for(;ch<'0'||ch>'9';ch=getchar())if(ch=='-')f=-1;
        for(;ch>='0'&&ch<='9';ch=getchar())x=x*10+ch-'0';
        return x*f;
    }
    
    void add(int a,int b,int c) {
        pre[++tot]=now[a];
        now[a]=tot,son[tot]=b,val[tot]=c;
    }
    
    void find_son(int fa,int u) {
        for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
            if(v!=fa)to[u].push_back(make_pair(v,val[p])),find_son(u,v);
    }
    
    void rebuild() {
        tot=1;memset(now,0,sizeof(now));
        for(int i=1;i<=cnt;i++) {
            int size=to[i].size();
            if(size<=2) {
                for(it=to[i].begin();it!=to[i].end();it++) {
                    pii tmp=*it;
                    add(i,tmp.fr,tmp.sc),add(tmp.fr,i,tmp.sc);
                }
            }
            else {
                pii u1=make_pair(++cnt,-1),u2;
                if(size==3)u2=to[i].front();
                else u2=make_pair(++cnt,-1);
                add(i,u1.fr,u1.sc),add(u1.fr,i,u1.sc);
                add(i,u2.fr,u2.sc),add(u2.fr,i,u2.sc);
                if(size==3) {
                    for(int j=1;j<=2;j++)
                        to[cnt].push_back(to[i].back()),to[i].pop_back();
                }
                else {
                    int p=0;
                    for(it=to[i].begin();it!=to[i].end();it++) {
                        if(!p)to[cnt-1].push_back(*it);
                        else to[cnt].push_back(*it);p^=1;
                    }
                }
            }
        }
    }
    
    void find_edge(int fa,int u) {
        siz[u]=1;
        for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
            if(!vis[p>>1]&&v!=fa) {
                find_edge(u,v),siz[u]+=siz[v];
                if(abs(N-2*siz[v])<mx)mx=abs(N-2*siz[v]),id=p>>1;
            }
    }
    
    
    void solve(int fa,int u,int num,int dis) {
        if(dis<=m) {
            f[dis]=min(f[dis],num);
            if(!insta[dis])sta[++top]=dis,insta[dis]=1;
        }siz[u]=1;
        for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
            if(!vis[p>>1]&&v!=fa)
    			solve(u,v,num+(val[p]!=-1),dis+(val[p]!=-1)*val[p]),siz[u]+=siz[v];
    }
    
    void query(int fa,int u,int num,int dis) {
        if(dis+tmp1<=m)ans=min(ans,num+f[m-dis-tmp1]+tmp2);siz[u]=1;
        for(int p=now[u],v=son[p];p;p=pre[p],v=son[p])
            if(!vis[p>>1]&&v!=fa)
    			query(u,v,num+(val[p]!=-1),dis+(val[p]!=-1)*val[p]),siz[u]+=siz[v];
    }
    
    void work(int u,int size) {
        N=size,mx=id=cnt+1,find_edge(0,u);
        if(id==cnt+1)return;vis[id]=1;
        int u1=son[id<<1],u2=son[id<<1|1];
        tmp1=(val[id<<1]!=-1)*val[id<<1],tmp2=(val[id<<1]!=-1);
        solve(0,u1,0,0),query(0,u2,0,0);
        while(top) {
            f[sta[top]]=f[m+1];
            insta[sta[top--]]=0;
        }
        work(u1,siz[u1]),work(u2,siz[u2]);
    }
    
    int main() {
        cnt=n=read(),m=read();
        for(int i=1;i<n;i++) {
            int a=read()+1,b=read()+1,c=read();
            add(a,b,c),add(b,a,c);
        }
        find_son(0,1),rebuild();
        memset(f,127/3,sizeof(f));
        ans=f[0];work(1,cnt);
        if(ans==f[m+1])puts("-1");
        else printf("%d
    ",ans);
        return 0;
    }
    
  • 相关阅读:
    网络基础
    SQL注入
    OpenID说明
    Linux网络编程
    Linux的僵尸进程产生原因及解决方法
    计算机系统的存储层次
    Java实现SSO
    JD(转载)
    Switch的表达式的要求
    leetcode(23. Merge k Sorted Lists)
  • 原文地址:https://www.cnblogs.com/AKMer/p/10052276.html
Copyright © 2011-2022 走看看