Author: lllyouo
Date: 20250325
tag: 最近公共祖先、树上前缀和
link: https://www.luogu.com.cn/problem/P4427
问题描述
分析
略
参考代码
cpp
#include <bits/stdc++.h>
using namespace std;
// Capacities and modulus shared by every routine in this solution.
constexpr int N = 300000 + 10;   // per-vertex array bound
constexpr int M = 55;            // second dimension of pre[][] (exponent slots)
constexpr int MOD = 998244353;   // all sums are reported modulo this value

// Adjacency lists in linked-forward-star form: h[v] is the index of v's first
// edge, e[] holds edge targets, ne[] links to the next edge of the same tail,
// and idx is the next free edge slot. Each undirected edge uses two slots.
int h[N];
int e[2 * N];
int ne[2 * N];
int idx;

int f[N][25];          // f[v][k]: 2^k-th ancestor of v (0 when none)
long long pre[N][M];   // pre[v][k]: prefix sum (root..v) of depth^k, mod MOD
long long d[N];        // d[v]: 1-based depth of v; 0 means not yet visited

int n;  // number of vertices
int m;  // number of queries
int t;  // highest binary-lifting level used with f[][]
// Inserts the directed edge a -> b at the head of a's adjacency list:
// the new slot records target b and chains to a's previous first edge.
void add(int a, int b) {
    const int slot = idx++;
    e[slot] = b;
    ne[slot] = h[a];
    h[a] = slot;
}
// Computes a^b mod MOD by binary (square-and-multiply) exponentiation.
// The base is first reduced into [0, MOD), so negative or oversized bases
// still produce a proper residue in [0, MOD). qmi(a, 0) == 1.
// b must be non-negative: for b < 0 the loop body never runs and 1 is
// returned (previously the arithmetic right shift of a negative b never
// reached 0, so the loop did not terminate).
int qmi(int a, int b) {
    long long base = ((long long)a % MOD + MOD) % MOD;
    long long ans = 1;
    while (b > 0) {
        if (b & 1) ans = ans * base % MOD;
        base = base * base % MOD;  // both factors < MOD, fits in long long
        b >>= 1;
    }
    return (int)ans;
}
void bfs() {
queue<int> q;
q.push(1);
d[1] = 1;
while (q.size()) {
int x = q.front();
q.pop();
for (int i = h[x]; i != -1; i = ne[i]) {
int y = e[i];
if (d[y]) continue;
d[y] = d[x] + 1;
f[y][0] = x;
for (int i = 1; i <= 50; i++) {
pre[y][i] = (pre[x][i] + qmi(d[x], i) + MOD) % MOD;
}
for (int k = 1; k <= t; k++) {
f[y][k] = f[f[y][k - 1]][k - 1];
}
q.push(y);
}
}
}
// Lowest common ancestor of x and y by binary lifting over f[][].
// Jumping to vertex 0 (a missing ancestor) is never accepted while leveling,
// because d[0] is never assigned and keeps its zero value while every real
// vertex has d >= 1.
int lca(int x, int y) {
    int deep = x;
    int shallow = y;
    if (d[deep] < d[shallow]) std::swap(deep, shallow);

    // Lift the deeper vertex until both sit at the same depth.
    for (int k = t; k >= 0; --k) {
        const int up = f[deep][k];
        if (d[up] >= d[shallow]) deep = up;
    }
    if (deep == shallow) return deep;

    // Lift both while their 2^k-th ancestors differ; they stop just below the LCA.
    for (int k = t; k >= 0; --k) {
        const int pu = f[deep][k];
        const int pv = f[shallow][k];
        if (pu == pv) continue;
        deep = pu;
        shallow = pv;
    }
    return f[deep][0];
}
// Reads an n-vertex tree and m queries (a, b, c). Each query prints the sum
// over vertices z on the a..b path of (d[z] - 1)^c, modulo MOD.
// With p = lca(a, b), that path is root..a plus root..b minus root..p minus
// root..parent(p). For the root, f[1][0] == 0 and pre[0][*] is never written,
// so no special case is needed.
// NOTE(review): assumes 1 <= a, b <= n and 0 <= c < M, as the problem guarantees.
int main() {
    memset(h, -1, sizeof h);
    if (scanf("%d", &n) != 1) return 0;
    for (int i = 1; i < n; i++) {
        int a, b;
        if (scanf("%d%d", &a, &b) != 2) return 0;
        add(a, b);
        add(b, a);
    }
    // t = bit width of n = floor(log2 n) + 1, computed exactly with integers.
    // The former int(log(n) / log(2)) + 1 could truncate one too low when n
    // is an exact power of two. 1LL keeps the shift well-defined for any int n.
    t = 0;
    while ((1LL << t) <= n) t++;
    bfs();
    if (scanf("%d", &m) != 1) return 0;
    while (m--) {
        int a, b, c;
        if (scanf("%d%d%d", &a, &b, &c) != 3) break;
        const int p = lca(a, b);
        // Every pre[][] entry is already in [0, MOD), so no +MOD is needed here.
        const long long s1 = (pre[a][c] + pre[b][c]) % MOD;
        const long long s2 = (pre[p][c] + pre[f[p][0]][c]) % MOD;
        printf("%lld\n", (s1 - s2 + MOD) % MOD);
    }
    return 0;
}