zoukankan      html  css  js  c++  java
  • 【题解】彩色树 51nod 1868 虚树 树上dp

    Prelude

    题目在这里:ο(=•ω<=)ρ⌒☆


    Solution

    蒟蒻__stdcall的第一道虚树题qaq。
    首先很容易发现,这个排列是假的。
    我们只需要求出每对点之间的颜色数量,然后求个和,然后再乘以((n-1)!)再乘以(2)就好啦!
    如何求出“每对点之间的颜色数量之和”呢?
    似乎点分可以做,并且fc确实写出了点分的做法,但是有更简(ma)单(nong)的虚树做法。
    我们对每种颜色分开考虑,对于每种颜色(c),我们考虑有多少条路径经过了颜色(c),然后再求和,就可以了。
    注意到“有多少条路径经过颜色(c)”,可以转化为“总的路径条数”减去“不经过颜色(c)的路径条数”。
    “总的路径条数”等于(frac{n(n-1)}{2})
    然后我们对所有颜色(c)的点建出虚树,在虚树上dp就可以求出“不经过颜色(c)的路径条数”了。
    如何dp?
    考虑去掉所有的颜色(c)的点,剩下了一个个连通块,那么每个连通块内部的所有路径都不经过颜色(c),并且跨越连通块的路径一定经过颜色(c)
    然后就是dp求出每个连通块的大小就可以了。
    这个东西。。。应该不用再讲了叭。。。我也不知道怎么解释了,要不看代码叭QAQ。


    Code

    #include <cstring>
    #include <algorithm>
    #include <cstdio>
    #include <stack>
    #include <vector>
    #include <cassert>
    
    using namespace std;
    typedef long long ll;
    typedef vector<int>::iterator viter;
    const int MAXN = 100010;
    const int MOD = 1e9+7;
    int _w;
    
    int n, a[MAXN];
    vector<int> col[MAXN];
    
    namespace Tree {
    	int head[MAXN], nxt[MAXN<<1], to[MAXN<<1], m;
    	void init() {
    		m = 0;
    		memset(head, -1, sizeof head);
    	}
    	void adde( int u, int v ) {
    		to[m] = v, nxt[m] = head[u], head[u] = m++;
    		to[m] = u, nxt[m] = head[v], head[v] = m++;
    	}
    }
    
    namespace DFS {
    	int dfn[MAXN], dfnc, top[MAXN], son[MAXN], pa[MAXN], dep[MAXN], sz[MAXN];
    	void dfs1( int u, int fa, int d ) {
    		using namespace Tree;
    		sz[u] = 1, dep[u] = d, pa[u] = fa;
    		for( int i = head[u]; ~i; i = nxt[i] ) {
    			int v = to[i];
    			if( v == fa ) continue;
    			dfs1(v, u, d+1);
    			sz[u] += sz[v];
    			if( sz[v] > sz[son[u]] ) son[u] = v;
    		}
    	}
    	void dfs2( int u, int tp ) {
    		using namespace Tree;
    		dfn[u] = ++dfnc, top[u] = tp;
    		if( son[u] ) dfs2( son[u], tp );
    		for( int i = head[u]; ~i; i = nxt[i] ) {
    			int v = to[i];
    			if( v == pa[u] || v == son[u] ) continue;
    			dfs2(v, v);
    		}
    	}
    	void solve() {
    		dfs1(1, 0, 1);
    		dfs2(1, 1);
    	}
    	int lca( int u, int v ) {
    		while( top[u] != top[v] ) {
    			if( dep[top[u]] < dep[top[v]] )
    				swap(u, v);
    			u = pa[top[u]];
    		}
    		return dep[u] < dep[v] ? u : v;
    	}
    	int findson( int u, int v ) {
    		while( top[u] != top[v] && pa[top[v]] != u )
    			v = pa[top[v]];
    		if( top[u] == top[v] ) return son[u];
    		else return top[v];
    	}
    }
    
    int cnt[MAXN];
    void prelude() {
    	using namespace Tree;
    	using DFS::sz;
    	using DFS::pa;
    	for( int u = 1; u <= n; ++u )
    		for( int i = head[u]; ~i; i = nxt[i] ) {
    			int v = to[i];
    			if( v == pa[u] ) continue;
    			cnt[u] = int((cnt[u] + (ll)sz[v] * (sz[v]-1) / 2 % MOD) % MOD);
    		}
    }
    
    int vistm[MAXN];
    stack<int> stk;
    bool cmp_dfn( int i, int j ) {
    	using DFS::dfn;
    	return dfn[i] < dfn[j];
    }
    void vt_adde( int u, int v, int id ) {
    	if( vistm[u] != id ) {
    		vistm[u] = id;
    		Tree::head[u] = -1;
    	}
    	if( vistm[v] != id ) {
    		vistm[v] = id;
    		Tree::head[v] = -1;
    	}
    	Tree::adde(u, v);
    }
    int build( vector<int> &vec, int id ) {
    	using DFS::dep;
    	Tree::m = 0;
    	sort( vec.begin(), vec.end(), cmp_dfn );
    	for( viter it = vec.begin(); it != vec.end(); ++it ) {
    		int u = *it;
    		if( stk.empty() ) {
    			stk.push(u);
    		} else {
    			int lca = DFS::lca(u, stk.top());
    			while( !stk.empty() && DFS::dep[stk.top()] > dep[lca] ) {
    				int v = stk.top(); stk.pop();
    				if( stk.empty() || DFS::dep[stk.top()] < dep[lca] ) {
    					vt_adde(v, lca, id);
    				} else {
    					vt_adde(v, stk.top(), id);
    				}
    			}
    			if( stk.empty() || stk.top() != lca )
    				stk.push(lca);
    			stk.push(u);
    		}
    	}
    	while( !stk.empty() ) {
    		int u = stk.top(); stk.pop();
    		if( stk.empty() ) return u;
    		vt_adde(u, stk.top(), id);
    	}
    	return assert(0), 0;
    }
    
    int vt_ans, f[MAXN];
    void vt_dfs( int u, int fa, int c ) {
    	using namespace Tree;
    	using DFS::sz;
    	for( int i = head[u]; ~i; i = nxt[i] ) {
    		int v = to[i];
    		if( v == fa ) continue;
    		vt_dfs(v, u, c);
    	}
    	if( a[u] == c ) {
    		f[u] = 0;
    		vt_ans = (vt_ans + cnt[u]) % MOD;
    		for( int i = head[u]; ~i; i = nxt[i] ) {
    			int v = to[i];
    			if( v == fa ) continue;
    			int son = DFS::findson(u, v);
    			vt_ans = int((vt_ans - (ll)sz[son] * (sz[son]-1) / 2 % MOD + MOD) % MOD);
    			int tmp = sz[son] - sz[v] + f[v];
    			vt_ans = int((vt_ans + (ll)tmp * (tmp-1) / 2 % MOD) % MOD);
    		}
    	} else {
    		f[u] = sz[u];
    		for( int i = head[u]; ~i; i = nxt[i] ) {
    			int v = to[i];
    			if( v == fa ) continue;
    			f[u] = f[u] - sz[v] + f[v];
    		}
    	}
    }
    int calc( int rt, int c ) {
    	using DFS::sz;
    	vt_ans = 0;
    	vt_dfs(rt, 0, c);
    	int tmp = n - sz[rt] + f[rt];
    	vt_ans = int((vt_ans + (ll)tmp * (tmp-1) / 2 % MOD) % MOD);
    	return vt_ans;
    }
    
    int solve( int c ) {
    	if( col[c].empty() ) return 0;
    	int rt = build( col[c], c );
    	if( vistm[rt] != c ) {
    		vistm[rt] = c;
    		Tree::head[rt] = -1;
    	}
    	// printf( "rt[%d] = %d
    ", c, rt );
    	int ans = calc(rt, c);
    	ans = int(((ll)n*(n-1)/2 % MOD - ans + MOD) % MOD);
    	return ans;
    }
    
    int main() {
    	_w = scanf( "%d", &n );
    	for( int i = 1; i <= n; ++i ) {
    		_w = scanf( "%d", a+i );
    		col[a[i]].push_back(i);
    	}
    	Tree::init();
    	for( int i = 0; i < n-1; ++i ) {
    		int u, v;
    		_w = scanf( "%d%d", &u, &v );
    		Tree::adde(u, v);
    	}
    	DFS::solve(), prelude();
    	int ans = 0;
    	for( int i = 1; i <= n; ++i ) {
    		int tmp = solve(i);
    		ans = (ans + tmp) % MOD;
    		// printf( "path[%d] = %d
    ", i, tmp );
    	}
    	// printf( "path = %d
    ", ans );
    	for( int i = 2; i <= n-1; ++i )
    		ans = int((ll)ans * i % MOD);
    	ans = ans * 2 % MOD;
    	printf( "%d
    ", ans );
    	return 0;
    }
    
  • 相关阅读:
    Python3之random模块常用方法
    Go语言学习笔记(九)之数组
    Go语言学习笔记之简单的几个排序
    Go语言学习笔记(八)
    Python3之logging模块
    Go语言学习笔记(六)
    123. Best Time to Buy and Sell Stock III(js)
    122. Best Time to Buy and Sell Stock II(js)
    121. Best Time to Buy and Sell Stock(js)
    120. Triangle(js)
  • 原文地址:https://www.cnblogs.com/mlystdcall/p/7953483.html
Copyright © 2011-2022 走看看