zoukankan      html  css  js  c++  java
  • 『ZJOI2019 D2T2』语言

    ~~ 话说,本题考场想出三只(log)的暴力做法,被卡成暴力了。~~

    题目分析

    首先考虑枚举每一个点,计算这个点可以和多少点进行交易。

    将所有经过该点的路径(s,t)拿出,那么这些极远的(s,t)构成的连通块大小(sz - 1)就是答案。

    (Codeforces)(异象石)那题可以想到,若一些点集按照(dfs)序排序,那么这些点构成连通块大小就是

    (frac{1}{2} (dist(a_1 , a_2) + dist(a_2,a_3) + ... + dist(a_{k-1} , a_k) + dist(a_k,a_1)))

    考虑对于每一个节点开一棵线段树,其叶子节点(i)表示(dfs)序为(i)的极远点出现次数。

    线段树中存储(3)个值(lp,rp,Sum)分别表示当前存在的大于(0)的最小下标和最大下标,和不算头尾的连通块大小。

    由于路径条数为(m),显然我们可以用可持久化线段树来维护这(n)棵线段树,使得空间复杂度为(O(m log_2 n))

    利用树上差分的思想,对于每一条(s,t)的路径,我们先在(s)(t)所在的线段树中将(dfn[s])(dfn[t])两个点单点(+1)

    然后在(father(lca(s,t)))的节点,将(dfn[s])(dfn[t])两个点单点(-2)

    于是,我们可以自下往上去统计每个节点的答案。

    每一次,我们需要对该节点的所有子树进行线段树合并,然后询问这个节点的答案,将其累加进总个数中。

    这样,我们就完成了无序数对的统计,那么此时答案除以(2)就是最终的答案。

    复杂度分析

    由于(n)次线段树合并节点总数是(m)个,所以需要时间复杂度为(O(m log_2 n))

    由于(m)次线段单点修改,使用(O(1))(LCA)实现,所以需要时间复杂度为(O(m log_2 n))

    所以,本题的总时间复杂度就是(O(m log_2 n))

    # include<bits/stdc++.h>
    # define int long long
    # define inf (1e9)
    using namespace std;
    const int N=1e5+10;
    struct rec{ int pre,to;}a[N<<1];
    int dep[N],head[N],dfn[N],root[N],acr[N],g[N];
    int n,m,tot,ans;
    namespace fast_IO{
        const int IN_LEN = 10000000, OUT_LEN = 10000000;
        char ibuf[IN_LEN], obuf[OUT_LEN], *ih = ibuf + IN_LEN, *oh = obuf, *lastin = ibuf + IN_LEN, *lastout = obuf + OUT_LEN - 1;
        inline char getchar_(){return (ih == lastin) && (lastin = (ih = ibuf) + fread(ibuf, 1, IN_LEN, stdin), ih == lastin) ? EOF : *ih++;}
        inline void putchar_(const char x){if(oh == lastout) fwrite(obuf, 1, oh - obuf, stdout), oh = obuf; *oh ++= x;}
        inline void flush(){fwrite(obuf, 1, oh - obuf, stdout);}
        int read(){
            int x = 0; int zf = 1; char ch = ' ';
            while (ch != '-' && (ch < '0' || ch > '9')) ch = getchar_();
            if (ch == '-') zf = -1, ch = getchar_();
            while (ch >= '0' && ch <= '9') x = x * 10 + ch - '0', ch = getchar_(); return x * zf;
        }
        void write(int x){
            if (x < 0) putchar_('-'), x = -x;
            if (x > 9) write(x / 10);
            putchar_(x % 10 + '0');
        }
    }
    using namespace fast_IO;
    namespace LCA {
        int ST[N << 1][22], value[N << 1], depth[N << 1], first[N], dist[N], cnt;
        inline int calc(int x, int y) {
            return depth[x] < depth[y] ? x : y;
        }
        inline void dfs(int u, int p, int d) {
            value[++cnt] = u; depth[cnt] = d; first[u] = cnt;
            for (int i = head[u]; i; i = a[i].pre) {
                int v = a[i].to;
                if (v == p) continue;
                dist[v] = dist[u] + 1;
                dfs(v, u, d + 1);
                value[++cnt] = u; depth[cnt] = d;
            }
        }
        inline void init(int root, int node_cnt) {
        	cnt = 0; dist[root] = 0;
            dfs(root, 0, 1);
            int n = 2 * node_cnt - 1;
            for (int i = 1; i <= n; i++) ST[i][0] = i;
            for (int j = 1; j < 22; j++) 
                for (int i = 1; i + (1 << j) - 1 <= n; i++) 
                    ST[i][j] = calc(ST[i][j - 1], ST[i + (1 << (j - 1))][j - 1]);
        }
        inline int query(int x, int y) {
            int l = first[x], r = first[y];
            if (l > r) std::swap(l, r);
            int k = log2(r - l + 1);
            return value[calc(ST[l][k], ST[r - (1 << k) + 1][k])];
        }
    }
    void adde(int u,int v) {
    	a[++tot].pre=head[u];
    	a[tot].to=v;
    	head[u]=tot;
    }
    void dfs1(int u,int fa) {
    	dfn[u]=++dfn[0]; acr[dfn[u]]=u;
    	dep[u]=dep[fa]+1,g[u]=fa;
    	for (int i=head[u];i;i=a[i].pre) {
    		int v=a[i].to; if (v==fa) continue;
    		dfs1(v,u);
    	}
    }
    int lca(int u,int v) {
    	return LCA::query(u,v);	
    }
    int dist(int u,int v) {
    	int l=lca(u,v);
    	return dep[u]+dep[v]-2*dep[l];
    }
    struct Seg {
    	int ls,rs,lp,rp,sum,val;
    	Seg() { sum=ls=rs=0; lp=inf; rp=-inf;}
    }tr[N*70];
    # define ls(x) tr[x].ls
    # define rs(x) tr[x].rs
    # define mid (l+r>>1)
    # define lson ls(x),l,mid
    # define rson rs(x),mid+1,r
    int cnt=0;
    void up(int &x) {
    	if (ls(x)!=0) tr[x].lp=min(tr[x].lp,tr[ls(x)].lp),tr[x].rp=max(tr[x].rp,tr[ls(x)].rp);
    	if (rs(x)!=0) tr[x].lp=min(tr[x].lp,tr[rs(x)].lp),tr[x].rp=max(tr[x].rp,tr[rs(x)].rp);
    	int ret=0;
    	if (ls(x)) ret+=tr[ls(x)].sum;
    	if (rs(x)) ret+=tr[rs(x)].sum;
    	if (ls(x) && rs(x) && tr[ls(x)].rp!=-inf && tr[rs(x)].lp!=inf) ret+=dist(acr[tr[ls(x)].rp],acr[tr[rs(x)].lp]);
    	tr[x].sum=ret;
    }
    void update(int &x,int l,int r,int pos,int d) {
    	if (!x) x=++cnt;
    	if (l==r) {
    		tr[x].val+=d;
    		if (tr[x].val>0) tr[x].lp=tr[x].rp=l;
    		else tr[x].lp=inf,tr[x].rp=-inf; 
    		tr[x].sum=0; 
    		return;
    	}
    	if (pos<=mid) update(lson,pos,d);
    	else update(rson,pos,d);
    	up(x);
    }
    void merge(int &x,int y,int l,int r) {
    	if (!x || !y) {x=x+y; return;}
    	if (l==r) {
    		tr[x].val+=tr[y].val;
    		if (tr[x].val>0) {
    			tr[x].lp=min(tr[x].lp,tr[y].lp);
    			tr[x].rp=max(tr[x].rp,tr[y].rp);
    		} else tr[x].lp=inf,tr[x].rp=-inf; 
    		tr[x].sum=0; 
    		return;
    	}
    	merge(ls(x),ls(y),l,mid);
    	merge(rs(x),rs(y),mid+1,r);
    	up(x);
    }
    void dfs2(int u,int fa) {
    	for (int i=head[u];i;i=a[i].pre) {
    		int v=a[i].to; if (v==fa) continue;
    		dfs2(v,u);merge(root[u],root[v],1,n);
    	}
    	ans+=tr[root[u]].sum;
    	if (tr[root[u]].lp!=inf && tr[root[u]].rp!=-inf)
    		ans+=dist(acr[tr[root[u]].lp],acr[tr[root[u]].rp]);
    }
    signed main()
    {
    	n=read();m=read();
    	memset(root,0,sizeof(root));
    	for (int i=2;i<=n;i++) {
    		int u=read(),v=read(); 
    		adde(u,v); adde(v,u);
    	}
    	dfs1(1,0);
    	LCA::init(1,n);
    	for (int i=1;i<=m;i++) {
    		int u=read(),v=read(); 
    		update(root[u],1,n,dfn[u],1);
    		update(root[u],1,n,dfn[v],1);
    		update(root[v],1,n,dfn[u],1);
    		update(root[v],1,n,dfn[v],1);
    		int l = g[lca(u,v)];
    		if (l) {
    			update(root[l],1,n,dfn[u],-2);
    			update(root[l],1,n,dfn[v],-2);
    		}
    	}
    	dfs2(1,0);
    	write(ans/4); putchar_('
    ');
    	flush(); 
    	return 0;
     } 
    
  • 相关阅读:
    sys、os 模块
    sh 了解
    TCP协议的3次握手与4次挥手过程详解
    python argparse(参数解析)模块学习(二)
    python argparse(参数解析)模块学习(一)
    Day17--Python--面向对象--成员
    Day16--Python--初识面向对象
    Day14--Python--函数二,lambda,sorted,filter,map,递归,二分法
    Day013--Python--内置函数一
    Day12--Python--生成器,生成器函数,推导式,生成器表达式
  • 原文地址:https://www.cnblogs.com/ljc20020730/p/11623520.html
Copyright © 2011-2022 走看看