题解
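对于每个 k(1 ≤ k ≤ 50)都预处理一个深度幂的前缀和:a[k][d] = 1^k + 2^k + … + d^k(mod 998244353)。
查询 u、v 时先求 f = lca(u, v)。u→f 与 v→f 两段(均不含 f)上结点的深度各自构成一段连续区间,因此这两段的贡献分别是 a[k][dep[u]] − a[k][dep[f]] 与 a[k][dep[v]] − a[k][dep[f]];最后再加上 f 本身的 dep[f]^k(根的深度为 0,贡献为 0)。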
代码
#include <bits/stdc++.h>
// Short-hand macros used throughout the solution.
#define fi first
#define se second
#define pii pair<int,int>
#define space putchar(' ')
// Writes one newline to stdout. The escape must stay on a single line: a raw
// line break between the quotes is an unterminated character literal.
#define enter putchar('\n')
#define mp make_pair
// Capacity of the node-indexed arrays (nodes are numbered 1..N, so N < MAXN).
#define MAXN 300005
#define pb push_back
// Uncomment to make main() read its input from "7.in" when debugging locally.
//#define ivorysi
using namespace std;
typedef long long int64;
typedef unsigned int u32;
/**
 * Reads one decimal integer from stdin (via getchar) into `res`.
 *
 * Characters before the first digit are skipped; a '-' seen while skipping
 * makes the result negative. Reading stops at (and consumes) the first
 * non-digit after the number. If EOF is reached before any digit, `res` is
 * left as 0 instead of looping forever.
 */
template<class T>
void read(T &res) {
    res = 0;
    // int, not char: getchar() returns EOF (a negative int) that a char
    // cannot represent reliably.
    int c = getchar();
    T f = 1;
    while(c < '0' || c > '9') {
        if(c == EOF) return;  // no more input: stop instead of spinning forever
        if(c == '-') f = -1;
        c = getchar();
    }
    while(c >= '0' && c <= '9') {
        res = res * 10 + c - '0';
        c = getchar();
    }
    res *= f;
}
/**
 * Writes `x` in decimal to stdout via putchar, with a leading '-' when
 * negative. No separator or newline is written.
 */
template<class T>
void out(T x) {
    if(x < 0) {
        putchar('-');
        // Emit the digits of |x| without computing -x, which overflows
        // (undefined behavior) for the minimum value of a signed T.
        if(x <= -10) out(-(x / 10));
        putchar('0' - x % 10);
        return;
    }
    if(x >= 10) {
        out(x / 10);
    }
    putchar('0' + x % 10);
}
// Modulus applied to every answer and to the power/prefix tables.
const int MOD = 998244353;
// a[k][d] = (1^k + 2^k + ... + d^k) mod MOD for k = 1..50 (filled in Init);
// a[k][0] stays 0. Row 51 ends up holding raw d^51 values from the construction.
int a[55][MAXN];
int N;  // number of tree nodes
int M;  // number of queries
// dep[x]: depth of node x (root 1 keeps the zero-initialized depth 0).
int dep[MAXN];
// st[i][0]: i-th vertex of the Euler tour; st[i][j]: a vertex of minimum depth
// among tour entries i .. i + 2^j - 1 (sparse table for LCA).
int st[MAXN * 2][20];
// len[x] = floor(log2(x)) for x >= 1, used to choose the sparse-table level.
int len[MAXN * 2];
// pos[x]: position of x's first occurrence in the Euler tour.
int pos[MAXN];
// Current length of the Euler tour.
int idx;
// Adjacency lists as singly linked edge chains: head[u] is the id of u's most
// recently added edge, E[id].next the previously added one (0 ends the chain).
struct node {
    int to;
    int next;
};
node E[MAXN * 2];
int head[MAXN];
int sumE;  // number of edges stored so far (edge ids start at 1)
// Appends the directed edge u -> v to the front of u's adjacency chain.
void add(int u,int v) {
    const int id = ++sumE;
    E[id].to = v;
    E[id].next = head[u];
    head[u] = id;
}
// Modular addition: returns (a + b) mod MOD, expecting a and b in [0, MOD).
int inc(int a,int b) {
    int sum = a + b;
    if(sum >= MOD) sum -= MOD;
    return sum;
}
// Modular multiplication: returns (a * b) mod MOD, widened to 64 bits so the
// product cannot overflow.
int mul(int a,int b) {
    return static_cast<int>(static_cast<long long>(a) * b % MOD);
}
// Returns whichever of the two nodes is shallower; on equal depth returns b.
int min_dep(int a,int b) {
    if(dep[b] <= dep[a]) return b;
    return a;
}
void dfs(int u,int fa) {
pos[u] = ++idx;
st[idx][0] = u;
for(int i = head[u] ; i ; i = E[i].next) {
int v = E[i].to;
if(v != fa) {
dep[v] = dep[u] + 1;
dfs(v,u);
st[++idx][0] = u;
}
}
}
// Lowest common ancestor of u and v: the shallowest vertex on the Euler tour
// between their first occurrences, found with two overlapping sparse-table cells.
int lca(int u,int v) {
    int lo = pos[u];
    int hi = pos[v];
    if(lo > hi) swap(lo,hi);
    const int level = len[hi - lo + 1];
    const int fromLeft = st[lo][level];
    const int fromRight = st[hi - (1 << level) + 1][level];
    return min_dep(fromLeft,fromRight);
}
// Reads the tree (N, then N-1 undirected edges) and precomputes all query data:
//   a[k][d] = (1^k + ... + d^k) mod MOD for k = 1..50, d = 1..N;
//   the Euler tour from root 1 (dfs), its sparse table st, and the log table len.
void Init() {
    read(N);
    for(int e = 1 ; e < N ; ++e) {
        int x,y;
        read(x);
        read(y);
        add(x,y);
        add(y,x);
    }
    for(int d = 1 ; d <= N ; ++d) {
        int p = d;  // d^k for the current k, starting at k = 1
        for(int k = 1 ; k <= 50 ; ++k) {
            a[k][d] = inc(a[k][d - 1],p);
            p = mul(p,d);
        }
        a[51][d] = p;  // d^51 (raw power, not a prefix sum)
    }
    dfs(1,0);
    for(int j = 1 ; j <= 19 ; ++j) {
        const int half = 1 << (j - 1);
        // st[i][j] spans tour entries i .. i + 2^j - 1; stop once that passes idx.
        for(int i = 1 ; i + 2 * half - 1 <= idx ; ++i)
            st[i][j] = min_dep(st[i][j - 1],st[i + half][j - 1]);
    }
    for(int x = 2 ; x <= idx ; ++x) len[x] = len[x >> 1] + 1;
}
// Reads M queries "u v k". For each one, prints (mod MOD) the sum of dep[x]^k
// over all nodes x on the path between u and v, one answer per line.
// k is expected in [1, 50]: only those rows of `a` are filled as prefix sums.
void Solve() {
    read(M);
    for(int q = 1 ; q <= M ; ++q) {
        int u,v,k;
        read(u);
        read(v);
        read(k);
        const int f = lca(u,v);
        const int *pre = a[k];
        // Nodes strictly below f on each side have depths dep[f]+1 .. dep[u] / dep[v].
        const int leftPart = inc(pre[dep[u]],MOD - pre[dep[f]]);
        const int rightPart = inc(pre[dep[v]],MOD - pre[dep[f]]);
        int res = inc(leftPart,rightPart);
        // f itself contributes dep[f]^k; at depth 0 (the root) that term is skipped.
        if(dep[f] != 0) res = inc(res,inc(pre[dep[f]],MOD - pre[dep[f] - 1]));
        out(res);
        putchar('\n');
    }
}
// Entry point: build all tables from the input, then answer the queries.
int main() {
#ifdef ivorysi
    // Local debugging only: take input from a file instead of stdin.
    freopen("7.in","r",stdin);
#endif
    Init();
    Solve();
    return 0;
}