能切题还是挺开心的就是说。
以至于主动请缨讲题,没讲好所以来补题解了。
换个讲法。
给出数组 \(A_i\) 、\(v_i\) 和一棵树,求下面行列式:
首先如果 \(A\) 有重复那么由行列式的性质知答案为 \(0\) 。
容易想到在 lca 处处理两个点的贡献,那么
设 \(f_x\) 为 在 \(x\) 子树内行列式的值,
设 \(g_x\) 为 在 \(f_x\) 的基础上将任意一行变成全为 \(1\) 时的值 之和。
其中 \(g\) 的作用在后文会提到。
考虑对于一个点 \(x\) ,令 \(p=v_x\) ,每次加入一个子树 \(y\) 的贡献,
相当于求行列式 \(\left | \begin{array}{cc} A & B \\ C & D \end{array} \right |\) ,其中 \(A\) 是原本答案, \(D\) 是新增的子树贡献 , \(B\) 和 \(C\) 中全是 \(p\) 。
这个不是能直接算的东西。
考虑利用性质:
和
将最下面一行拆成 \(a_{ij}=a'_{ij} + p\) 的形式。
那么我们得到了两个新的行列式,其中一个存在一行全是 \(p\) (或者说把 \(p\) 提出来就是 \(1\)),另一个是原本的最后一行全部减 \(p\) 。
-
后者还是不能直接算的,我们继续做直到这个新的矩阵全部变成 \(a'_{ij}\) ,此时左边的位置都是 \(0\) 。此时贡献就是前面算出的 \(f_x \times g_y\) 。这会使得前者出现若干个类似的。
-
前者可以直接计算,即把行列式的每个位置全部减 \(p\) ,得到一个原本为 \(p\) 的位置全部是 \(0\) 的行列式,这时的贡献为之前子树的 \(f\) 全部相乘,再乘上 \(g_y \cdot p\) 。
最后得到 \(f'_x = f_x \cdot f_y + p \cdot g_y \cdot calc\) ,其中 \(calc\) 为已贡献子树的 \(f\) 之和。
这样就知道 \(g\) 有什么作用了吧。
另外要考虑 \(g\) 的转移了,
因为有一行为 \(1\) 了,多出来的 \(p\) 就可以直接减掉,变成了一个优美的行列式:将子树的行列式拼到一起,其他位置全是 \(0\) ,有一行特殊的为 \(1\)。
将每一个 \(g\) 乘上其他的 \(f\) 再求和即可。
发现因为有减 \(p\) 的存在,子树上要求的行列式变成了以下形式:
所以每次都打个标记即可。
#include<bits/stdc++.h>
#define fo(i,a,b) for(int i=a;i<=b;++i)
#define fd(i,a,b) for(int i=a;i>=b;--i)
#define ll long long
using namespace std;
const int N=5e5+10;
const int mod=998244353;
int n,m,a[N],v[N],tot,last[N],siz[N],ans,fa[N];
int f[N],g[N];
bool bz[N],flag=0;
int tag=0;
struct edge{
int st,en,next;
}E[N<<1];
void add(int x,int y){
E[++tot]=(edge){x,y,last[x]};
last[x]=tot;
}
void init(int x){
for(int p=last[x];p;p=E[p].next){
int y=E[p].en;
if(y==fa[x])continue;
fa[y]=x;
init(y);
siz[x]+=siz[y];
}
}
void dfs(int x){
int tmp=1;
bool flag=0;
v[x]=(v[x]-tag+mod)%mod;
tag+=v[x];tag%=mod;
if(bz[x]){
flag=1;
f[x]=v[x];
g[x]=1;
tmp=0;
}
for(int p=last[x];p;p=E[p].next){
int y=E[p].en;
if(y==fa[x] || !siz[y])continue;
dfs(y);
if(!flag){
f[x]=((ll)f[y]+(ll)v[x]*g[y])%mod;
g[x]=g[y];
flag=1;
}else{
f[x]=((ll)f[x] * f[y] % mod + (ll)v[x] * g[y] % mod * tmp % mod) % mod;
g[x]=((ll)g[x]*f[y]%mod + (ll)tmp*g[y]%mod) % mod;
}
tmp=(ll)tmp*f[y]%mod;
}
tag=(tag-v[x]+mod)%mod;
}
int main(){
freopen("a.in","r",stdin);
freopen("a.out","w",stdout);
scanf("%d%d",&n,&m);
fo(i,1,n)scanf("%d",&v[i]);
fo(i,1,m){
scanf("%d",&a[i]);
flag|=bz[a[i]];
siz[a[i]]=bz[a[i]]=1;
}
if(flag){
printf("0\n");
return 0;
}
fo(i,2,n){
int x,y;
scanf("%d%d",&x,&y);
add(x,y);add(y,x);
}
init(1);
dfs(1);
printf("%d\n",f[1]);
return 0;
}