zoukankan      html  css  js  c++  java
  • 「CSP-S 2020」函数调用(拓扑排序+DP)

    Address

    LOJ3381
    LuoguP7077

    Solution

    因为加是单点加,乘是全体乘,所以考虑计算后面的乘对前面的加的影响。

    也就是说,对于某次执行 (T_j=1) 的操作 (a_p+=v),设在它之后执行的 (T_j=2) 的操作的 (prod V_j=x)。那么计算最终答案的时候,只要把 (a_p+=v imes x) 即可。

    对于 (T_j=3) 的操作,题目说保证不会出现递归(即不会直接或间接地调用本身)。因此建一张 DAG,如果函数 (u) 直接调用了函数 (v),那么连一条 (u→v) 的边。方便起见,再建一个点 (m+1),向 (Q)(f_i) 都连一条边。

    接下来开始暴力,我们记一个 (prod),表示当前访问过的 (T_j=2) 节点的 (prod V_j)。(重复访问就重复计算)

    (m+1) 开始 DFS(注意出边的顺序要反过来,因为是后面的乘对前面的加的影响)。DFS 到 (u) 的时候,如果 (T_j=2)(prod imes=V_j),如果 (T_j=1)(a_{P_j}+=V_j imes prod),如果 (T_j=3),就什么都不做。

    怎么优化这个暴力?

    考虑对于一个点 (u),它连向的点分别为 (v_1,v_2,...,v_k)。那么 DFS 到 (u) 之后,设当前的 (prod)(s),接下来肯定是 DFS (v_1),那么执行完 DFS (v_1),准备 DFS (v_2) 的时候,(prod) 是多少?

    预处理出 (dp_u) 表示从 (u) 开始 DFS,经过的所有 (T_j=2) 节点的 (prod V_j),按拓扑序倒序转移即可。

    那么上述的 (prod) 就是 (dp_{v_1} imes s),以此类推,准备 DFS (v_i) 的时候,(prod) 就是 (prod_{j=1}^{i-1}dp_{v_j} imes s)

    我们可以这样描述这个 DFS:从 (u) 开始,带着大小为 (prod) 的标记走下去,接下来,对于每个 (v_i),带着大小为 (prod_{j=1}^{i-1}dp_{v_j} imes s) 的标记走下去。也就是说,我们不用把 (v_1sim v_{i-1}) 都 DFS 一遍,就可以知道 (v_i) 的标记大小,它仅仅取决于所有连向它的 (u)

    我们记 (tag_u) 表示点 (u) 的标记大小。这个 (tag) 有什么用呢?我们发现所有的 (T_jin{1,2})(j) 都是底层节点,没有出边,所以如果 (T_j=1),我们求出 (tag_j) 之后,直接让 (a_{P_j}+=V_j imes tag_j) 即可。

    根据上述分析,对于一个点 (v),只要知道所有连向它的 (u)(tag_u),即可用形如 (tag_v=sum_{u→v}tag_u imes prod_{xin pre(u,v)}dp_x) 的式子求出 (tag_v)。按照拓扑序转移即可。

    注意把所有 (m+1) 走不到的点和边删掉。

    时间复杂度 (O(n+m+Q+sum C_j))

    Code

    #include <bits/stdc++.h>
    
    using namespace std;
    
    #define ll long long
    
    template <class t>
    inline void read(t & res)
    {
    	char ch;
    	while (ch = getchar(), !isdigit(ch));
    	res = ch ^ 48;
    	while (ch = getchar(), isdigit(ch))
    		res = res * 10 + (ch ^ 48);
    }
    
    template <class t>
    inline void print(t x)
    {
    	if (x > 9) print(x / 10);
    	putchar(x % 10 + 48);
    }
    
    const int N = 1e5 + 15, M = 2e6 + 15, mod = 998244353;
    
    int adj[N], nxt[M], go[M], val[N], pos[N], typ[N], n, m, q, tag[N];
    int f[N], deg[N], seq[N], cnt, a[N], num;
    bool vis[N];
    
    inline void add(int &x, int y)
    {
    	(x += y) >= mod && (x -= mod);
    }
    
    inline void link(int x, int y)
    {
    	nxt[++num] = adj[x];
    	adj[x] = num;
    	go[num] = y;
    	deg[y]++;
    }
    
    inline void dfs(int u)
    {
    	if (vis[u]) return;
    	vis[u] = 1;
    	for (int i = adj[u]; i; i = nxt[i]) dfs(go[i]);
    }
    
    inline void topsort()
    {
    	queue<int>q;
    	int i, j;
    	q.push(m + 1);
    	seq[cnt = 1] = m + 1;
    	while (!q.empty())
    	{
    		int u = q.front();
    		q.pop();
    		for (i = adj[u]; i; i = nxt[i])
    		{
    			int v = go[i];
    			if (!vis[v]) continue;
    			deg[v]--;
    			if (!deg[v]) q.push(v), seq[++cnt] = v;
    		}
    	}
    	for (i = cnt; i >= 1; i--)
    	{
    		int u = seq[i];
    		for (j = adj[u]; j; j = nxt[j])
    		{
    			int v = go[j];
    			f[u] = (ll)f[u] * f[v] % mod;
    		}
    	}
    }
    
    inline void solve()
    {
    	int i, j;
    	for (i = 1; i <= cnt; i++)
    	{
    		int u = seq[i], pre = 1;
    		for (j = adj[u]; j; j = nxt[j])
    		{
    			int v = go[j];
    			add(tag[v], (ll)tag[u] * pre % mod);
    			pre = (ll)pre * f[v] % mod;
    		}
    	}
    }
    
    int main()
    {
    	freopen("call.in", "r", stdin);
    	freopen("call.out", "w", stdout);
    	read(n);
    	int i, j, k, x;
    	for (i = 1; i <= n; i++) read(a[i]);
    	read(m);
    	for (i = 1; i <= m; i++)
    	{
    		read(typ[i]);
    		f[i] = 1;
    		if (typ[i] == 1) read(pos[i]), read(val[i]);
    		else if (typ[i] == 2) read(val[i]), f[i] = val[i];
    		else
    		{
    			read(k);
    			for (j = 1; j <= k; j++)
    			{
    				read(x);
    				link(i, x);
    			}
    		}
    	}
    	read(q);
    	for (i = 1; i <= q; i++)
    	{
    		read(x);
    		link(m + 1, x);
    	}
    	dfs(m + 1);
    	for (i = 1; i <= m + 1; i++)
    		for (j = adj[i]; j; j = nxt[j])
    		{
    			k = go[j];
    			if (!vis[k] || !vis[i]) deg[k]--;
    		}
    	f[m + 1] = tag[m + 1] = 1;
    	topsort();
    	solve();
    	for (i = 1; i <= n; i++) a[i] = (ll)a[i] * f[m + 1] % mod;
    	for (i = 1; i <= m; i++)
    		if (typ[i] == 1) add(a[pos[i]], (ll)val[i] * tag[i] % mod);
    	for (i = 1; i <= n; i++)
    		printf("%d ", a[i]);
    	putchar('
    ');
    	fclose(stdin);
    	fclose(stdout);
    	return 0;
    }
    
  • 相关阅读:
    Python之二维数组(list与numpy.array)
    too many values to unpack
    python 寻找可迭代(list)的目标元素的下表方法
    zip函数
    map函数
    Sokcet代码错误类型
    PL-VIO Docker测试
    如何检索国外博士论文
    EuRoc V203数据集的坑
    Tracking of Features and Edges
  • 原文地址:https://www.cnblogs.com/cyf32768/p/15240240.html
Copyright © 2011-2022 走看看