Author: lllyouo
Date: 20250331
tag: 最近公共祖先
link: https://www.acwing.com/problem/content/359/
问题描述
见上方链接。
分析
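二分答案：军队可用时间 mid 越大越容易满足要求，判定具有单调性。对给定的 mid：
1. 每支军队利用倍增尽量向上跳，但不跳到根。若无法带着剩余时间到达根，就停在所能到达的最高点驻守；否则记录它到根后的剩余时间，以及它经过的根的孩子。
2. DFS 求出仅靠驻守军队已被控制的节点（一个点被控制，当且仅当它有军队驻守，或它不是叶子且所有孩子都被控制），并找出根的尚未被控制的孩子。
3. 对每个尚未被控制的根的孩子 s：若来自 s 的可到根军队中剩余时间最少的那支无法再返回 s（剩余时间 < dis[s]），就让它直接留守 s。
4. 将其余可到根的军队按剩余时间、尚未被控制的根的孩子按 dis 分别从小到大排序，贪心地为每个孩子分配第一支剩余时间足够的军队；全部分配成功则 mid 可行。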
参考代码
cpp
#include <bits/stdc++.h>
using namespace std;
// (remaining time, node id); pairs compare by remaining time first.
using PLLI = pair<long long, int>;

// N bounds the node count; M bounds the directed edge count
// (every undirected edge is stored twice).
constexpr int N = 50010;
constexpr int M = 2 * N;

// Adjacency list: h[u] is the first edge of u (-1 = none), e/w/ne hold the
// target, weight and next edge of each edge slot; idx is the next free slot.
int h[N];
int e[M];
int w[M];
int ne[M];
int idx;

// n: node count, m: army count, t: highest binary-lifting level,
// cnt: armies stored in a[], p: uncovered root children stored in son[].
int n, m, t, cnt, p;

// d[x]: BFS depth with the root at depth 1 (0 = not yet visited);
// f[x][k]: 2^k-th ancestor of x (0 once the jump passes the root);
// army[i]: starting node of army i; son[1..p]: uncovered children of the root.
int d[N];
int f[N][25];
int army[N];
int son[N];

// has[x]: some army settles at x; used[i]: army a[i] settles in its own subtree;
// cover[x]: every downward path from x to a leaf meets a settled army.
bool has[N];
bool used[N];
bool cover[N];

// a[1..cnt]: (time left after reaching the root, root child the army came from,
// or 1 when the army started at the root).
PLLI a[N];

// l, r: binary-search bounds; dis[x]: weighted distance from the root to x.
long long l, r;
long long dis[N];
// Append the directed edge a -> b with weight c, linking it in front of the
// edge that was previously first in a's list.
void add(int a, int b, int c) {
    const int slot = idx++;
    e[slot] = b;
    w[slot] = c;
    ne[slot] = h[a];
    h[a] = slot;
}
// Breadth-first traversal from the root (node 1). For every node reached it
// fills in the depth d[], the weighted root distance dis[] and the
// binary-lifting ancestor table f[][0..t]. A non-zero d[] marks "visited".
void bfs() {
    queue<int> frontier;
    d[1] = 1;
    frontier.push(1);
    while (!frontier.empty()) {
        const int u = frontier.front();
        frontier.pop();
        for (int edge = h[u]; edge != -1; edge = ne[edge]) {
            const int v = e[edge];
            if (d[v] != 0) continue;  // already discovered (this includes u's parent)
            d[v] = d[u] + 1;
            dis[v] = dis[u] + w[edge];
            // u's table was completed when u was discovered, so v's is built from it.
            f[v][0] = u;
            for (int k = 1; k <= t; ++k) f[v][k] = f[f[v][k - 1]][k - 1];
            frontier.push(v);
        }
    }
}
// Move node x toward the root with binary lifting, spending at most `mid`
// units of distance and never stepping onto the root (node 1) itself.
// Returns {time still left, highest node reached}.
PLLI go(int x, long long mid) {
    long long remain = mid;
    int cur = x;
    for (int k = t; k >= 0; --k) {
        const int anc = f[cur][k];
        if (anc <= 1) continue;  // the jump would land on or beyond the root
        const long long cost = dis[cur] - dis[anc];
        if (cost > remain) continue;  // not enough time for this jump
        remain -= cost;
        cur = anc;
    }
    return {remain, cur};
}
void dfs(int x, int fa) {
bool is_cover = true, is_leaf = true;
for (int i = h[x]; i != -1; i = ne[i]) {
int y = e[i];
if (y == fa) continue;
dfs(y, x);
is_cover &= cover[y];
is_leaf = false;
if (x == 1 && !cover[y]) son[++p] = y;
}
cover[x] = has[x] || (!is_leaf && is_cover);
}
// Sort predicate: nodes nearer to the root (smaller weighted distance) first.
bool cmp(int x, int y) {
    return dis[y] > dis[x];
}
bool check(long long mid) {
memset(has, 0, sizeof has);
memset(cover, 0, sizeof cover);
memset(used, 0, sizeof used);
cnt = p = 0;
// 为军队分类
for (int i = 1; i <= m; i++) {
PLLI t = go(army[i], mid);
int pos = t.second;
long long rest = t.first;
if (rest <= dis[pos]) has[pos] = true;
else a[++cnt] = {rest - dis[pos], pos};
}
// 统计一类军队控制的节点以及根节点有哪些的孩子节点
dfs(1, -1);
// 统计二类军队是否跨过根节点
sort(a + 1, a + cnt + 1);
for (int i = 1; i <= cnt; i++) {
long long rest = a[i].first;
int pos = a[i].second;
if (!cover[pos] && rest < dis[pos]) {
cover[pos] = used[i] = true;
}
}
// 令二类军队控制根节点尚未被控制的孩子节点
sort(son + 1, son + p + 1, cmp);
for (int i = 1, j = 1; i <= p; i++) {
int s = son[i];
if (cover[s]) continue;
while (j <= cnt && (used[j] || a[j].first < dis[s])) j++;
if (j > cnt) return false;
j++;
}
return true;
}
// Reads the tree and the army positions, then binary-searches the smallest
// time for which check() succeeds. Prints -1 when no amount of time suffices.
int main() {
    memset(h, -1, sizeof h);
    scanf("%d", &n);
    for (int i = 1; i < n; i++) {
        int a, b, c;
        scanf("%d%d%d", &a, &b, &c);
        add(a, b, c);
        add(b, a, c);
        r += c;  // r accumulates the total edge weight
    }
    t = int(log(n) / log(2)) + 1;
    bfs();
    scanf("%d", &m);
    for (int i = 1; i <= m; i++) scanf("%d", &army[i]);
    // A feasible answer never exceeds the total edge weight. An army either
    // stays inside its own root child's subtree, or walks up to the root and
    // down to another root child, and those two legs use disjoint edges.
    // The answer can equal that total, though. Example: a root with two
    // weight-w children and both armies on one of them needs 2w. So the total
    // cannot double as the "impossible" marker; use total + 1 instead.
    // check() is only ever called with mid < r.
    const long long impossible = r + 1;
    r = impossible;
    while (l < r) {
        long long mid = (l + r) >> 1;
        if (check(mid)) r = mid;
        else l = mid + 1;
    }
    cout << (l == impossible ? -1 : l) << '\n';
    return 0;
}