zoukankan      html  css  js  c++  java
  • LOJ #2537. 「PKUWC 2018」Minimax (线段树合并 优化dp)

    题意

    (C) 有一棵 (n) 个结点的有根树,根是 (1) 号结点,且每个结点最多有两个子结点。

    定义结点 (x) 的权值为:

    1.若 (x) 没有子结点,那么它的权值会在输入里给出,保证这类点中每个结点的权值互不相同

    2.若 (x) 有子结点,那么它的权值有 (p_x) 的概率是它的子结点的权值的最大值,有 (1-p_x) 的概率是它的子结点的权值的最小值。

    现在小 (C) 想知道,假设 (1) 号结点的权值有 (m) 种可能性,权值第 (i)的可能性的权值是 (V_i) ,它的概率为 (D_i(D_i>0)) ,求:

    [displaystyle sum _{i=1} ^ {m} i cdot V_i cdot D_i^2 ]

    你需要输出答案对 (998244353) 取模的值。

    对于 (40\%) 的数据,有 (1leq nleq 5000)

    对于 (100\%) 的数据,有 (1leq nleq 3 imes 10^5, 1leq w_ileq 10^9)

    题解

    首先考虑 (O(n^2))dp , 令 (dp_{u,i})(u) 号点 , 取到排名为 (i) 权值的概率 .

    这个应该比较容易转移 , 考虑枚举一个儿子取的值 , 然后对于它的贡献 就分为它最小和它最大的两种去计数就行了 .

    (ls, rs)(u) 的左/右儿子 , 枚举两个子树的取值 (i, j) , 令 (coef = dp[ls][i] * dp[rs][j]) 方程就是 :

    [coef * prob[u] o dp[u][max(i, j)] \ coef * (1 - prob[u]) o dp[u][min(i, j)] ]

    如果我们只枚举有效状态那么就是 (O(n ^ 2)) 的啦,因为对于一对点对他们的贡献只会在 (lca) 合并。


    然后考虑优化 , 类似于这种状态数与 (size_u) 有关的 dp .

    常常可以考虑 线段树合并 or 启发式合并 来优化时间复杂度 .

    一开始想直接 启发式合并 在线段树上操作 发现细节好多 而且不好维护 ... 然后就弃掉了 看了一波 LOJ 最短代码 qwq

    诶 好像很好写啊 , 原来直接线段树合并就行了 . qwq

    考虑维护一颗线段树 , 每个点维护两个值 (sumv, mult) 代表 区间和 以及 区间乘法的标记 .

    然后每个叶子 代表一个 dp 值 , 然后每个区间就可以维护这段区间的 dp 值之和 .

    我们一边合并一边算到当前区间 , 对于两个线段树 dp 值存在的贡献 (sumx, sumy) (也就是前面方程中需要乘上后面的两个东西) .

    如果当前区间只有一个子树 , 打下乘法标记 , 直接返回就行了 . 否则继续递归下去合并解决 .

    时间复杂度就是 $ O(sum_{i=1}^{n} minsize) = O(n log n)$ .

    这是因为每个点合并上去 大小至少翻倍 . 意味着每个点最多被计算 ((log n)) 次 , 最后复杂度就是 (O(n log n)) .

    代码

    [40pts ]

    #include <bits/stdc++.h>
    #define For(i, l, r) for(register int i = (l), i##end = (int)(r); i <= i##end; ++i)
    #define Fordown(i, r, l) for(register int i = (r), i##end = (int)(l); i >= i##end; --i)
    #define Set(a, v) memset(a, v, sizeof(a))
    using namespace std;
    
    inline bool chkmin(int &a, int b) {return b < a ? a = b, 1 : 0;}
    inline bool chkmax(int &a, int b) {return b > a ? a = b, 1 : 0;}
    
    inline int read() {
        int x = 0, fh = 1; char ch = getchar();
        for (; !isdigit(ch); ch = getchar() ) if (ch == '-') fh = -1;
        for (; isdigit(ch); ch = getchar() ) x = (x << 1) + (x << 3) + (ch ^ 48);
        return x * fh;
    }
    
    void File() {
    #ifdef zjp_shadow
    	freopen ("2537.in", "r", stdin);
    	freopen ("2537.out", "w", stdout);
    #endif
    }
    
    typedef long long ll;
    const ll Mod = 998244353;
    ll fpm(ll x, int power) {
    	ll res = 1;
    	for (; power; power >>= 1, (x *= x) %= Mod)
    		if (power & 1) (res *= x) %= Mod;
    	return res;
    }
    
    typedef long long ll;
    const int N = 5100, inv = fpm(10000, Mod - 2);
    int n, fa[N], ch[N][2], tot[N], val[N], rk[N], Leaf;
    ll dp[N][N], p[N], Pre[N][2], Suf[N][2];
    
    #define ls(o) ch[o][0]
    #define rs(o) ch[o][1]
    void Dp(int u) {
    	if (!u) return ; Dp(ls(u)); Dp(rs(u));
    	if (!tot[u]) { dp[u][rk[u]] = 1; }
    	else if (tot[u] == 1) {
    		For (i, 1, Leaf) dp[u][i] = dp[ls(u)][i];
    	} else {
    		For (son, 0, 1) {
    			For (i, 1, Leaf)
    				Pre[i][son] = (Pre[i - 1][son] + dp[ch[u][son]][i]) % Mod;
    			Fordown (i, Leaf, 1)
    				Suf[i][son] = (Suf[i + 1][son] + dp[ch[u][son]][i]) % Mod;
    		}
    		For (i, 1, Leaf) For (son, 0, 1) {
    			(dp[u][i] += dp[ch[u][son]][i] * Pre[i - 1][son ^ 1] % Mod * p[u] % Mod 
    			 + dp[ch[u][son]][i] * Suf[i + 1][son ^ 1] % Mod * (Mod + 1 - p[u]) % Mod) %= Mod;
    		}
    	}
    }
    
    int main () {
    	File();
    	n = read();
    	For (i, 1, n) fa[i] = read(), ch[fa[i]][tot[fa[i]] ++] = i;
    	For (i, 1, n)
    		if (!tot[i]) rk[i] = val[++ Leaf] = read();
    		else p[i] = 1ll * read() * inv % Mod;
    
    	sort(val + 1, val + Leaf + 1);
    	For (i, 1, n) rk[i] = lower_bound(val + 1, val + Leaf + 1, rk[i]) - val;
    
    	Dp(1);
    
    	ll ans = 0;
    	For (i, 1, n)
    		(ans += 1ll * i * val[i] % Mod * dp[1][i] % Mod * dp[1][i] % Mod) %= Mod;
    	printf ("%lld
    ", ans);
        return 0;
    }
    

    [100pts ]

    #include <bits/stdc++.h>
    #define For(i, l, r) for(register int i = (l), i##end = (int)(r); i <= i##end; ++i)
    #define Fordown(i, r, l) for(register int i = (r), i##end = (int)(l); i >= i##end; --i)
    #define Set(a, v) memset(a, v, sizeof(a))
    using namespace std;
    
    inline bool chkmin(int &a, int b) {return b < a ? a = b, 1 : 0;}
    inline bool chkmax(int &a, int b) {return b > a ? a = b, 1 : 0;}
    
    inline int read() {
        int x = 0, fh = 1; char ch = getchar();
        for (; !isdigit(ch); ch = getchar() ) if (ch == '-') fh = -1;
        for (; isdigit(ch); ch = getchar() ) x = (x << 1) + (x << 3) + (ch ^ 48);
        return x * fh;
    }
    
    void File() {
    #ifdef zjp_shadow
    	freopen ("2537.in", "r", stdin);
    	freopen ("2537.out", "w", stdout);
    #endif
    }
    
    typedef long long ll;
    const ll Mod = 998244353;
    ll fpm(ll x, int power) {
    	ll res = 1;
    	for (; power; power >>= 1, (x *= x) %= Mod)
    		if (power & 1) (res *= x) %= Mod;
    	return res;
    }
    
    typedef long long ll;
    
    const int N = 3e5 + 1e3, inv = fpm(10000, Mod - 2);
    int val[N], rk[N];
    
    #define ls(o) ch[o][0]
    #define rs(o) ch[o][1]
    #define lson ls(o), l, mid
    #define rson rs(o), mid + 1, r
    const int Maxnode = 6e6 + 1e3;
    #define Mult(o, val) (sumv[o] *= (val)) %= Mod, (mult[o] *= (val)) %= Mod;
    struct Segment_Tree {
    	int rt[Maxnode], ch[Maxnode][2], Size; ll sumv[Maxnode], mult[Maxnode];
    
    	inline void push_up(int o) { sumv[o] = (sumv[ls(o)] + sumv[rs(o)]) % Mod; }
    
    	inline void push_down(int o) { 
    		if (mult[o] <= 1) return ; Mult(ls(o), mult[o]); Mult(rs(o), mult[o]); mult[o] = 1;
    	}
    
    	void Update(int &o, int l, int r, int up, ll uv) {
    		if (!o) o = (++ Size); mult[o] = 1;
    		if (l == r) { (sumv[o] += uv) %= Mod; return ; } int mid = (l + r) >> 1;
    		push_down(o); if (up <= mid) Update(lson, up, uv); else Update(rson, up, uv); push_up(o);
    	}
    
    	int Merge(int x, int y, ll sumx, ll sumy, ll probmax, ll probmin) {
    		if (!y) { Mult(x, sumy); return x; }
    		if (!x) { Mult(y, sumx); return y; }
    		push_down(x); push_down(y);
    		ll x0 = sumv[ls(x)], x1 = sumv[rs(x)], y0 = sumv[ls(y)], y1 = sumv[rs(y)];
    		ls(x) = Merge(ls(x), ls(y), (sumx + probmin * x1) % Mod, (sumy + probmin * y1) % Mod, probmax, probmin);
    		rs(x) = Merge(rs(x), rs(y), (sumx + probmax * x0) % Mod, (sumy + probmax * y0) % Mod, probmax, probmin);
    		push_up(x); return x;
    	}
    
    	inline ll Calc(int o, int l, int r) {
    		if (l == r) return 1ll * l * val[l] % Mod * sumv[o] % Mod * sumv[o] % Mod;
    		int mid = (l + r) >> 1; push_down(o);
    		return (Calc(lson) + Calc(rson)) % Mod;
    	}
    } T;
    
    int n, fa[N], ch[N][2], tot[N], Leaf;
    ll p[N];
    
    void Dp(int u) {
    	if (!u) return ; Dp(ls(u)); Dp(rs(u));
    	if (!tot[u]) T.Update(T.rt[u], 1, Leaf, rk[u], 1);
    	else if (tot[u] == 1) T.rt[u] = T.rt[ls(u)];
    	else T.rt[u] = T.Merge(T.rt[ls(u)], T.rt[rs(u)], 0, 0, p[u], (Mod + 1 - p[u]) % Mod);
    }
    
    int main () {
    	File();
    	n = read();
    	For (i, 1, n) fa[i] = read(), ch[fa[i]][tot[fa[i]] ++] = i;
    	For (i, 1, n)
    		if (!tot[i]) rk[i] = val[++ Leaf] = read();
    		else p[i] = 1ll * read() * inv % Mod;
    
    	sort(val + 1, val + Leaf + 1);
    	For (i, 1, n) rk[i] = lower_bound(val + 1, val + Leaf + 1, rk[i]) - val;
    
    	Dp(1);
    
    	printf ("%lld
    ", T.Calc(T.rt[1], 1, Leaf));
        return 0;
    }
    
  • 相关阅读:
    springcloud(十五):搭建Zuul微服务网关
    springcloud(十四)、ribbon负载均衡策略应用案例
    springcloud(十三):Ribbon客户端负载均衡实例
    springcloud(十二):Ribbon客户端负载均衡介绍
    springcloud(十):熔断监控Hystrix Dashboard
    springcloud(九):熔断器Hystrix和Feign的应用案例
    springcloud(八):熔断器Hystrix
    springcloud(七): 使用Feign调用Eureka Server客户端服务
    springcloud(六):Eureka提供数据的客户端连接Docker的mysql
    雏龙计划
  • 原文地址:https://www.cnblogs.com/zjp-shadow/p/9073578.html
Copyright © 2011-2022 走看看