Address
Solution
因为加是单点加,乘是全体乘,所以考虑计算后面的乘对前面的加的影响。
也就是说,对于某次执行 (T_j=1) 的操作 (a_p+=v),设在它之后执行的 (T_j=2) 的操作的 (prod V_j=x)。那么计算最终答案的时候,只要把 (a_p+=v imes x) 即可。
对于 (T_j=3) 的操作,题目说保证不会出现递归(即不会直接或间接地调用本身)。因此建一张 DAG,如果函数 (u) 直接调用了函数 (v),那么连一条 (u→v) 的边。方便起见,再建一个点 (m+1),向 (Q) 个 (f_i) 都连一条边。
接下来开始暴力,我们记一个 (prod),表示当前访问过的 (T_j=2) 节点的 (prod V_j)。(重复访问就重复计算)
从 (m+1) 开始 DFS(注意出边的顺序要反过来,因为是后面的乘对前面的加的影响)。DFS 到 (u) 的时候,如果 (T_j=2),(prod imes=V_j),如果 (T_j=1),(a_{P_j}+=V_j imes prod),如果 (T_j=3),就什么都不做。
怎么优化这个暴力?
考虑对于一个点 (u),它连向的点分别为 (v_1,v_2,...,v_k)。那么 DFS 到 (u) 之后,设当前的 (prod) 为 (s),接下来肯定是 DFS (v_1),那么执行完 DFS (v_1),准备 DFS (v_2) 的时候,(prod) 是多少?
预处理出 (dp_u) 表示从 (u) 开始 DFS,经过的所有 (T_j=2) 节点的 (prod V_j),按拓扑序倒序转移即可。
那么上述的 (prod) 就是 (dp_{v_1} imes s),以此类推,准备 DFS (v_i) 的时候,(prod) 就是 (prod_{j=1}^{i-1}dp_{v_j} imes s)。
我们可以这样描述这个 DFS:从 (u) 开始,带着大小为 (prod) 的标记走下去,接下来,对于每个 (v_i),带着大小为 (prod_{j=1}^{i-1}dp_{v_j} imes s) 的标记走下去。也就是说,我们不用把 (v_1sim v_{i-1}) 都 DFS 一遍,就可以知道 (v_i) 的标记大小,它仅仅取决于所有连向它的 (u)。
我们记 (tag_u) 表示点 (u) 的标记大小。这个 (tag) 有什么用呢?我们发现所有的 (T_jin{1,2}) 的 (j) 都是底层节点,没有出边,所以如果 (T_j=1),我们求出 (tag_j) 之后,直接让 (a_{P_j}+=V_j imes tag_j) 即可。
根据上述分析,对于一个点 (v),只要知道所有连向它的 (u) 的 (tag_u),即可用形如 (tag_v=sum_{u→v}tag_u imes prod_{xin pre(u,v)}dp_x) 的式子求出 (tag_v)。按照拓扑序转移即可。
注意把所有 (m+1) 走不到的点和边删掉。
时间复杂度 (O(n+m+Q+sum C_j))。
Code
#include <bits/stdc++.h>
using namespace std;
#define ll long long
template <class t>
inline void read(t & res)
{
char ch;
while (ch = getchar(), !isdigit(ch));
res = ch ^ 48;
while (ch = getchar(), isdigit(ch))
res = res * 10 + (ch ^ 48);
}
template <class t>
inline void print(t x)
{
if (x > 9) print(x / 10);
putchar(x % 10 + 48);
}
const int N = 1e5 + 15, M = 2e6 + 15, mod = 998244353;
int adj[N], nxt[M], go[M], val[N], pos[N], typ[N], n, m, q, tag[N];
int f[N], deg[N], seq[N], cnt, a[N], num;
bool vis[N];
inline void add(int &x, int y)
{
(x += y) >= mod && (x -= mod);
}
inline void link(int x, int y)
{
nxt[++num] = adj[x];
adj[x] = num;
go[num] = y;
deg[y]++;
}
inline void dfs(int u)
{
if (vis[u]) return;
vis[u] = 1;
for (int i = adj[u]; i; i = nxt[i]) dfs(go[i]);
}
inline void topsort()
{
queue<int>q;
int i, j;
q.push(m + 1);
seq[cnt = 1] = m + 1;
while (!q.empty())
{
int u = q.front();
q.pop();
for (i = adj[u]; i; i = nxt[i])
{
int v = go[i];
if (!vis[v]) continue;
deg[v]--;
if (!deg[v]) q.push(v), seq[++cnt] = v;
}
}
for (i = cnt; i >= 1; i--)
{
int u = seq[i];
for (j = adj[u]; j; j = nxt[j])
{
int v = go[j];
f[u] = (ll)f[u] * f[v] % mod;
}
}
}
inline void solve()
{
int i, j;
for (i = 1; i <= cnt; i++)
{
int u = seq[i], pre = 1;
for (j = adj[u]; j; j = nxt[j])
{
int v = go[j];
add(tag[v], (ll)tag[u] * pre % mod);
pre = (ll)pre * f[v] % mod;
}
}
}
int main()
{
freopen("call.in", "r", stdin);
freopen("call.out", "w", stdout);
read(n);
int i, j, k, x;
for (i = 1; i <= n; i++) read(a[i]);
read(m);
for (i = 1; i <= m; i++)
{
read(typ[i]);
f[i] = 1;
if (typ[i] == 1) read(pos[i]), read(val[i]);
else if (typ[i] == 2) read(val[i]), f[i] = val[i];
else
{
read(k);
for (j = 1; j <= k; j++)
{
read(x);
link(i, x);
}
}
}
read(q);
for (i = 1; i <= q; i++)
{
read(x);
link(m + 1, x);
}
dfs(m + 1);
for (i = 1; i <= m + 1; i++)
for (j = adj[i]; j; j = nxt[j])
{
k = go[j];
if (!vis[k] || !vis[i]) deg[k]--;
}
f[m + 1] = tag[m + 1] = 1;
topsort();
solve();
for (i = 1; i <= n; i++) a[i] = (ll)a[i] * f[m + 1] % mod;
for (i = 1; i <= m; i++)
if (typ[i] == 1) add(a[pos[i]], (ll)val[i] * tag[i] % mod);
for (i = 1; i <= n; i++)
printf("%d ", a[i]);
putchar('
');
fclose(stdin);
fclose(stdout);
return 0;
}