zoukankan      html  css  js  c++  java
  • day9T1改错记

    题面

    给出一棵 $n$ 个节点的树，和 $m$ 条非树边，第 $i$ 条非树边连接 $u_i$ 和 $v_i$，有 $p_i$ 的概率不出现。问：期望有多少条非树边至多出现在一个简单环中？答案对 $998244353$ 取模，输入的 $p_i$ 也是取模后的概率。

    $n, m \le 10^6$

    解析瞎扯

    对着标程看了半天没看懂，一问才反应过来：节点上的标记实际上是它父边的标记……

    于是问题就变成了树上前缀积（类比前缀和）加子树积：一条非树边 $i$ 对答案的贡献是 $(1-p_i)\prod_j p_j$，其中 $j$ 取遍树上路径与 $i$ 的树上路径有公共边的其他非树边。

    代码

    #include <cstdio>
    #include <cstring>
    #include <algorithm>
    #include <iostream>
    #include <map>
    #include <utility>
    #include <vector>
    typedef long long LL;
    typedef std::pair<int, int> pii;
    const int maxn = (int)1e6 + 6;
    // Prime modulus of the problem, as a typed constant rather than a macro;
    // it is a long long so every `% mod` expression stays 64-bit.
    const LL mod = 998244353LL;
    
    /**
     * Modular exponentiation by squaring.
     * @param x base, any value: it is reduced into [0, mod) first, so x * x
     *          always fits in 64 bits and negative bases give canonical results
     * @param y exponent, expected >= 0 (a negative exponent is not supported and yields 1)
     * @return x^y mod `mod`, in [0, mod)
     */
    LL qpower(LL x, LL y) {
    	LL res = 1;
    	x %= mod;
    	if (x < 0) x += mod;
    	for (; y > 0; y >>= 1) {
    		if (y & 1) res = res * x % mod;
    		x = x * x % mod;
    	}
    	return res;
    }
    /** One directed arc of the adjacency lists; each tree edge is stored as two arcs. */
    struct Edge {
    	int to;   // node this arc points at
    	int nxt;  // index of the next arc leaving the same node; 0 ends the list
    	Edge(int target = 0, int next = 0) {
    		to = target;
    		nxt = next;
    	}
    } edge[maxn << 1];
    int n, m;                    // number of tree nodes / number of non-tree edges
    int cnt;                     // number of arcs stored in edge[] so far
    int head[maxn];              // head[u]: index of the first arc leaving u (0 = none)
    int dep[maxn];               // depth of each node (dep[0] = 0, root 1 has depth 1)
    int f[maxn][21];             // f[u][j]: 2^j-th ancestor of u (0 above the root)
    int x[maxn];                 // endpoint of non-tree edge i (swapped to be the deeper one)
    int y[maxn];                 // other endpoint of non-tree edge i
    int lca[maxn];               // LCA of x[i], y[i]; stays 0 when y[i] is an ancestor of x[i]
    int befx[maxn];              // child of the path's top node on the x side
    int befy[maxn];              // child of the LCA on the y side (LCA case only)
    int prob[maxn];              // p_i as read: probability (mod 998244353) that edge i is absent
    /**
     * Product of residues mod `mod` that also tracks zero factors.
     *
     * A Tag stands for val * 0^zr: val is the product of the non-zero factors and
     * zr is the net number of zero factors (dividing by a zero Tag decrements it,
     * so it can become negative). Keeping zeros out of val lets a zero factor be
     * divided back out later.
     */
    struct Tag {
    	int val, zr;
    	Tag(int v, int z) : val(v), zr(z) {}
    	// Deliberately implicit: main() multiplies Tags directly by int probabilities.
    	Tag(int x_ = 1) {
    		// Reduce into [0, mod) so negative arguments (e.g. 1 - p) are stored canonically.
    		long long r = x_ % mod;
    		if (r < 0) r += mod;
    		if (!r) val = zr = 1;
    		else val = static_cast<int>(r), zr = 0;
    	}
    	// Value of the product: 0 whenever the net zero count is non-zero, otherwise val.
    	inline int v() const { return zr ? 0 : val; }
    	inline Tag operator *(const Tag &rhs) const { return Tag(1LL * val * rhs.val % mod, zr + rhs.zr); }
    	// Division multiplies by the Fermat inverse rhs.val^(mod-2); mod is prime.
    	inline Tag operator /(const Tag &rhs) const { return Tag(1LL * val * qpower(rhs.val, mod - 2) % mod, zr - rhs.zr); }
    } t[maxn], sum[maxn], res[maxn];
    typedef std::map<pii, Tag>::iterator iter;
    // mp[l][(a, b)] with a < b: product of 1 / p_i over the non-tree edges i whose
    // LCA is l and whose path enters l through its children a and b.
    std::map<pii, Tag> mp[maxn];
    
    // Helpers used by main(), which is defined before them.
    char gc();
    int read();
    void add(int bgn, int end);
    void dfs(int x, int fa);
    void calc(int x);
    
    int main() {
    	freopen("cactus.in", "r", stdin);
    	freopen("cactus.out", "w", stdout);
    
    	n = read(), m = read();
    	for (int i = 1; i < n; ++i) {
    		int u = read(), v = read();
    		add(u, v), add(v, u);
    	}
    	dfs(1, 0);
    	for (int i = 1; i <= m; ++i) {
    		x[i] = read(), y[i] = read(), prob[i] = read();
    		Tag inv = Tag(1) / Tag(prob[i]);
    		if (dep[x[i]] < dep[y[i]]) std::swap(x[i], y[i]);
    		int tmpx = x[i], tmpy = y[i];
    		res[i] = inv;
    		for (int j = 20; ~j; --j) if (dep[f[tmpx][j]] > dep[tmpy]) tmpx = f[tmpx][j];
    		if (f[tmpx][0] == tmpy) {
    			befx[i] = tmpx, sum[tmpx] = sum[tmpx] * prob[i];
    			t[x[i]] = t[x[i]] * prob[i], t[y[i]] = t[y[i]] * inv;
    		} else {
    			if (dep[tmpx] > dep[tmpy]) tmpx = f[tmpx][0];
    			for (int j = 20; ~j; --j) if (f[tmpx][j] != f[tmpy][j]) tmpx = f[tmpx][j], tmpy = f[tmpy][j];
    			befx[i] = tmpx, befy[i] = tmpy;
    			lca[i] = f[tmpx][0];
    			t[x[i]] = t[x[i]] * prob[i], t[y[i]] = t[y[i]] * prob[i];
    			t[lca[i]] = t[lca[i]] * inv * inv;
    			sum[tmpx] = sum[tmpx] * prob[i], sum[tmpy] = sum[tmpy] * prob[i];
    			if (tmpx > tmpy) std::swap(tmpx, tmpy);
    			Tag &T = mp[lca[i]][std::make_pair(tmpx, tmpy)];
    			T = T * inv;
    		}
    	}
    	calc(1);
    	LL ans = 0;
    	for (int i = 1; i <= m; ++i) {
    		Tag inv = res[i], ret = Tag(1 - prob[i]) * inv;
    		if (!lca[i]) ans = (ans + (ret * sum[x[i]] / sum[befx[i]] * t[befx[i]]).v()) % mod;
    		else {
    			Tag l = sum[x[i]] / sum[befx[i]] * t[befx[i]], r = sum[y[i]] / sum[befy[i]] * t[befy[i]];
    			Tag d = mp[lca[i]][std::make_pair(std::min(befy[i], befx[i]), std::max(befx[i], befy[i]))];
    			ans = (ans + (ret * l * r * d).v()) % mod;
    		}
    	}
    	printf("%lld
    ", ans < 0 ? ans + mod : ans);
    
    	return 0;
    }
    /**
     * Roots the tree at x (whose parent is fa). Fills dep[] and the binary-lifting
     * table f[][0..20] for every node reachable from x without passing through fa.
     *
     * Uses an explicit stack instead of recursion: with n up to 1e6, a path-shaped
     * tree would otherwise need 1e6 nested calls. The visiting order differs from
     * the recursive version. That is harmless, because a node's entries depend only
     * on its ancestors, and those are always processed first.
     */
    void dfs(int x, int fa) {
    	std::vector<std::pair<int, int> > stk;  // (node, parent) pairs still to process
    	stk.push_back(std::make_pair(x, fa));
    	while (!stk.empty()) {
    		int u = stk.back().first, par = stk.back().second;
    		stk.pop_back();
    		f[u][0] = par, dep[u] = dep[par] + 1;
    		for (int i = 1; i <= 20; ++i) f[u][i] = f[f[u][i - 1]][i - 1];
    		for (int p = head[u]; p; p = edge[p].nxt) {
    			int v = edge[p].to;
    			if (v != par) stk.push_back(std::make_pair(v, u));
    		}
    	}
    }
    /**
     * Propagates the tags collected in main() over the subtree rooted at x.
     * Requires dfs() to have filled the parent links f[][0].
     *   - sum[] top-down: when called on the root (f[root][0] == 0), sum[u] becomes
     *     the product of the original sum[] values on the path from the root to u.
     *   - t[] bottom-up: t[u] becomes the product of the original t[] values over u's subtree.
     *
     * Iterative (BFS order), so a path-shaped tree with 1e6 nodes cannot overflow
     * the call stack. Tag multiplication is commutative, so the changed order of
     * multiplications does not affect the result.
     */
    void calc(int x) {
    	std::vector<int> order(1, x);  // BFS order: every node appears after its parent
    	for (std::size_t k = 0; k < order.size(); ++k) {
    		int u = order[k], par = f[u][0];
    		if (par) sum[u] = sum[par] * sum[u];
    		for (int p = head[u]; p; p = edge[p].nxt)
    			if (edge[p].to != par) order.push_back(edge[p].to);
    	}
    	// Walking BFS order backwards finishes every child before its parent; index 0 is x itself.
    	for (std::size_t k = order.size() - 1; k > 0; --k) {
    		int u = order[k];
    		t[f[u][0]] = t[f[u][0]] * t[u];
    	}
    }
    inline void add(int bgn, int end) {
    	edge[++cnt] = edge(end, head[bgn]);
    	head[bgn] = cnt;
    }
    /**
     * Buffered single-character reader over stdin. It refills a 1,000,000-byte
     * buffer with fread() whenever the buffer is exhausted, and returns EOF
     * (converted to char) once fread() yields nothing.
     */
    inline char gc() {
    	static char buf[1000000];
    	static char *limit, *cur;  // [cur, limit) is the unread part of buf
    	if (cur == limit) {
    		cur = buf;
    		limit = buf + fread(buf, 1, 1000000, stdin);
    	}
    	if (cur == limit) return EOF;
    	return *cur++;
    }
    /**
     * Reads the next non-negative decimal integer from the gc() stream.
     * Skips non-digit characters before the number. Returns 0 if the input ends
     * before any digit appears, instead of spinning forever on EOF.
     */
    inline int read() {
    	int res = 0;
    	char ch = gc();
    	// Compare against (char)EOF so the test also works where plain char is unsigned.
    	while ((ch < '0' || ch > '9') && ch != static_cast<char>(EOF)) ch = gc();
    	while (ch >= '0' && ch <= '9') {
    		res = res * 10 + (ch - '0');
    		ch = gc();
    	}
    	return res;
    }
    //Rhein_E
    
  • 相关阅读:
    多媒体(2):WAVE文件格式分析
    多媒体(1):MCI接口编程
    EM算法(4):EM算法证明
    EM算法(3):EM算法运用
    EM算法(2):GMM训练算法
    EM算法(1):K-means 算法
    Support Vector Machine (3) : 再谈泛化误差(Generalization Error)
    Support Vector Machine (2) : Sequential Minimal Optimization
    Neural Network学习(二)Universal approximator :前向神经网络
    Neural Network学习(一) 最早的感知机:Perceptron of Rosenblatt
  • 原文地址:https://www.cnblogs.com/Rhein-E/p/10518668.html
Copyright © 2011-2022 走看看