zoukankan      html  css  js  c++  java
  • 【题解】 「NOI2020」命运 树形dp+线段树合并 LOJ3340

    Legend

    给定一棵 (n) 节点的树,你需要把边黑白染色。当然还有 (m) 个限制,限制是一条简单路径,且两端满足祖先-后代关系,表示这条路径上所有的边至少有一条是黑色的,问方案数对 (998 244 353) 取模的值。

    (1 le n le 500000)(0 le m le 500000)

    Editorial

    两端满足祖先-后代关系这个条件很奇怪,不妨从此下手。

    (dp_{i,j}) 表示从子树 (i) 里向上伸出来的链 **上端点深度最大是 (j) **的方案数量。

    现在考虑到了节点 (i),看看加入一棵子树 (k) 发生了什么:

    (dp'_{i,j}=sumlimits_{s=0}^{j} dp_{i,j} imes dp_{k,s} + sumlimits_{s=0}^{j-1} dp_{i,s} imes dp_{k,j}+ sumlimits_{s=0}^{dep_{i}}dp_{i,j} imes dp_{k,s})

    其中前两个转移是 ((i,k)) 这条边没选的,最后一个是这条边选了的。

    这样就可以 (O(n)) 进行转移。可以写出一个 (O(n cdot m{maxdep})) 的做法。

    把方程写成前缀和形式:

    [egin{aligned} dp'_{i,j}&=dp_{i,j} imes S_{k,j} + dp_{k,j} imes S_{i,j-1} + dp_{i,j} imes S_{k,dep_i} \ &=dp_{i,j} imes(S_{k,j}+S_{k,dep_i})+dp_{k,j} imes S_{i,j-1} end{aligned} ]

    考虑怎么快速算这一坨东西。发现前缀和其实可以通过线段树合并的时候顺带求出来。

    于是复杂度就变成了 (O(n log n))

    Code

    算是个套路,但是第一次见还是很有趣的。

    #include <bits/stdc++.h>
    
    #define LL long long
    #define debug(...) fprintf(stderr ,__VA_ARGS__)
    #define __FILE(x)
    	freopen(#x".in" ,"r" ,stdin);
    	freopen(#x".out" ,"w" ,stdout
    
    const int MX = 5e5 + 233;
    const LL MOD = 998244353;
    
    int read(){
    	char k = getchar(); int x = 0;
    	while(k < '0' || k > '9') k = getchar();
    	while(k >= '0' && k <= '9') x = x * 10 + k - '0' ,k = getchar();
    	return x;
    }
    
    std::vector<int> limit[MX];
    
    int head[MX] ,tot ,n;
    struct edge{
    	int node ,next;
    }h[MX << 1];
    void addedge(int u ,int v ,int flg = 1){
    	h[++tot] = (edge){v ,head[u]} ,head[u] = tot;
    	if(flg) addedge(v ,u ,0);
    }
    
    struct node{
    	int l ,r;
    	LL sum ,mul;
    	node *lch ,*rch;
    }*root[MX];
    
    node *newnode(int l ,int r){
    	node *x = new node;
    	x->l = l ,x->r = r;
    	x->sum = 0 ,x->mul = 1;
    	x->lch = x->rch = nullptr;
    	return x;
    }
    
    void domul(node *x ,LL v){
    	x->sum = x->sum * v % MOD;
    	x->mul = x->mul * v % MOD;
    }
    
    void pushdown(node *x){
    	if(x->mul != 1){
    		if(x->lch != nullptr) domul(x->lch ,x->mul);
    		if(x->rch != nullptr) domul(x->rch ,x->mul);
    		x->mul = 1;
    	}
    }
    
    void pushup(node *x){
    	x->sum = 0;
    	if(x->lch != nullptr) x->sum = x->lch->sum;
    	if(x->rch != nullptr) x->sum = (x->sum + x->rch->sum) % MOD;
    }
    
    void change(node *x ,int p ,int val){
    	if(x->l == x->r) return x->sum = val ,void();
    	int mid = (x->l + x->r) >> 1;
    	pushdown(x);
    	if(p <= mid){
    		if(x->lch == nullptr) x->lch = newnode(x->l ,mid);
    		change(x->lch ,p ,val);
    	}else{
    		if(x->rch == nullptr) x->rch = newnode(mid + 1 ,x->r);
    		change(x->rch ,p ,val);
    	}return pushup(x);
    }
    
    LL sum(node *x ,int l ,int r){
    	if(x == nullptr) return 0;
    	if(l <= x->l && x->r <= r) return x->sum;
    	pushdown(x);
    	int mid = (x->l + x->r) >> 1;
    	LL s = 0;
    	if(l <= mid) s = sum(x->lch ,l ,r);
    	if(r > mid) s = (s + sum(x->rch ,l ,r)) % MOD;
    	return s;	
    }
    
    node *merge(node *x ,node *y ,LL &s1 ,LL &s2){
    	if(x == nullptr){
    		if(y != nullptr){
    			s1 = (s1 + y->sum) % MOD;
    			domul(y ,s2); 
    		}
    		return y;
    	}
    	if(y == nullptr){
    		s2 = (s2 + x->sum) % MOD;
    		domul(x ,s1);
    		return x;
    	}
    	if(x->l == x->r){
    		LL tmps2 = s2;
    		s1 = (s1 + y->sum) % MOD;
    		s2 = (s2 + x->sum) % MOD;
    		domul(x ,s1);
    		x->sum = (x->sum + y->sum * tmps2) % MOD;
    	}
    	else{
    		pushdown(x) ,pushdown(y);
    		x->lch = merge(x->lch ,y->lch ,s1 ,s2);
    		x->rch = merge(x->rch ,y->rch ,s1 ,s2);
    		pushup(x);
    	}
    	return x;
    }
    
    int dep[MX];
    void dfs(int x ,int f ,int depth){
    	dep[x] = depth;
    	for(int i = head[x] ,d ; i ; i = h[i].next){
    		if((d = h[i].node) == f) continue;
    		dfs(d ,x ,depth + 1);
    	}
    
    	root[x] = newnode(0 ,n);
    	int mx = 0;
    	for(auto i : limit[x]){
    		if(dep[i] > dep[x]) continue;
    		// change(root[x] ,dep[i] ,1);
    		mx = std::max(mx ,dep[i]);
    	}
    	change(root[x] ,mx ,1);
    
    	for(int i = head[x] ,d ; i ; i = h[i].next){
    		if((d = h[i].node) == f) continue;
    		LL tmp = sum(root[d] ,0 ,dep[x]) ,tmp2 = 0;
    		merge(root[x] ,root[d] ,tmp ,tmp2);
    	}
    	// debug("sum dp[%d] = %lld
    " ,x ,sum(root[x] ,0 ,n));
    }
    
    int main(){
    	n = read();
    	for(int i = 1 ,u ,v ; i < n ; ++i){
    		u = read() ,v = read();
    		addedge(u ,v);
    	}
    	int m = read();
    	for(int i = 1 ,u ,v ; i <= m ; ++i){
    		u = read() ,v = read();
    		limit[u].push_back(v);
    		limit[v].push_back(u);
    	}
    	dfs(1 ,0 ,1);
    	printf("%lld
    " ,sum(root[1] ,0 ,0));
    	return 0;
    }
    
  • 相关阅读:
    在yii中使用Filter实现RBAC权限自动判断
    关于WEB设计透明和阴影
    一句话扯扯数据结构的概念点
    Console API Google 浏览器开发人员工具使用
    git提交项目时候,忽略一些文件
    学习笔记 如何解决IE6 position:fixed固定定位问题{转载}
    [转载]yii jquery折叠、弹对话框、拖拽、滑动条、ol和ul列表、局部内容切换
    Jquery 常用方法经典总结【砖】
    PHP中冒号、endif、endwhile、endfor这些都是什么
    [转载]救命的PHP代码
  • 原文地址:https://www.cnblogs.com/imakf/p/13735181.html
Copyright © 2011-2022 走看看