Author: lllyouo
Date: 20250320
tag: 次小生成树、树上倍增、LCA
link: https://www.luogu.com.cn/problem/P4180
问题描述
分析
略
参考代码
cpp
#include <bits/stdc++.h>
using namespace std;
// Capacity limits: N sizes the per-vertex arrays, M sizes the per-edge arrays.
const int N = 1e5 + 10, M = 3e5 + 10;
// Linked adjacency lists filled by add(): h[v] is the id of the first entry of
// vertex v (-1 marks the end of a list, see memset in main), e[id] is the
// target vertex, w[id] the weight, ne[id] the next entry id, idx the next free id.
int h[N], e[M], w[M], ne[M], idx;
// n, m: vertex and edge counts read in main(); t: highest binary-lifting level used.
int n, m, t;
// d[v]: BFS depth (vertex 1 has depth 1, 0 means "not visited");
// f[v][k]: the 2^k-th ancestor of v (0 when it does not exist);
// g[v][k]: {largest, strictly second largest} edge weight on the 2^k edges above v;
// p[v]: union-find parent used by find()/kruskal().
int d[N], f[N][25], g[N][25][2], p[N];
// One input edge of the graph; edge[] holds all m of them.
struct NODE {
    int u, v, w;  // the two endpoints and the weight
    bool used;    // set to true by kruskal() when the edge joins the spanning tree

    // Orders edges by ascending weight.
    // Const-qualified so the comparison also works on const operands, which is
    // what the Compare requirement of std::sort expects.
    bool operator< (const NODE& x) const {
        return w < x.w;
    }
} edge[M];
// Union-find lookup: returns the representative of x's set and re-points
// every vertex on the walked path directly at that representative.
int find(int x) {
    // First pass: climb to the representative.
    int root = x;
    while (p[root] != root) root = p[root];

    // Second pass: compress the path that was just walked.
    while (p[x] != root) {
        const int above = p[x];
        p[x] = root;
        x = above;
    }
    return root;
}
// Prepends a directed entry a -> b with weight c to a's adjacency list.
void add(int a, int b, int c) {
    const int slot = idx++;  // claim the next free entry id
    e[slot] = b;
    w[slot] = c;
    ne[slot] = h[a];         // old head becomes the successor
    h[a] = slot;
}
// Combines two {largest, strictly second largest} pairs a and b into c.
// c may be the same array as a (or b): every input is read before c is written.
void merge(int a[2], int b[2], int c[2]) {
    const int topA = a[0];
    const int topB = b[0];

    int runnerUp;
    if (topA == topB) {
        // Same maximum on both sides: the runner-up is the better of the two seconds.
        runnerUp = std::max(a[1], b[1]);
    } else if (topA > topB) {
        runnerUp = std::max(topB, a[1]);
    } else {
        runnerUp = std::max(topA, b[1]);
    }

    c[1] = runnerUp;
    // The maximum is stored last, after all reads of a[0]/b[0] are done.
    c[0] = std::max(topA, topB);
}
// Breadth-first walk from vertex 1 over the adjacency lists. For every vertex
// reached it fills the depth d[], the ancestor table f[][] and the
// {largest, strictly second largest} weight table g[][] for levels 0..t.
void bfs() {
    queue<int> pending;
    d[1] = 1;  // depth 0 is reserved for "not visited yet"
    pending.push(1);

    while (!pending.empty()) {
        const int cur = pending.front();
        pending.pop();

        for (int id = h[cur]; id != -1; id = ne[id]) {
            const int child = e[id];
            if (d[child] != 0) continue;  // already visited

            d[child] = d[cur] + 1;
            f[child][0] = cur;
            g[child][0][0] = w[id];

            // Level k is built from two consecutive level k-1 jumps.
            for (int k = 1; k <= t; ++k) {
                const int half = f[child][k - 1];
                f[child][k] = f[half][k - 1];
                merge(g[child][k - 1], g[half][k - 1], g[child][k]);
            }
            pending.push(child);
        }
    }
}
// Returns the lowest common ancestor of x and y. On return dis holds the
// {largest, strictly second largest} edge weight found on the tree path
// between x and y; both entries start from 0.
int lca(int x, int y, int dis[2]) {
    dis[0] = 0;
    dis[1] = 0;

    // Make x the deeper endpoint.
    if (d[y] > d[x]) swap(x, y);

    // Raise x to the depth of y, collecting weights along the way.
    for (int k = t; k >= 0; --k) {
        const int up = f[x][k];
        if (d[up] < d[y]) continue;
        merge(dis, g[x][k], dis);
        x = up;
    }
    if (x == y) return x;

    // Raise both endpoints together while their ancestors still differ.
    for (int k = t; k >= 0; --k) {
        if (f[x][k] == f[y][k]) continue;
        merge(dis, g[x][k], dis);
        merge(dis, g[y][k], dis);
        x = f[x][k];
        y = f[y][k];
    }

    // x and y are now distinct children of the common ancestor.
    merge(dis, g[x][0], dis);
    merge(dis, g[y][0], dis);
    return f[x][0];
}
long long kruskal() {
for (int i = 1; i <= n; i++) p[i] = i;
sort(edge, edge + m);
long long sum = 0;
for (int i = 0; i < m; i++) {
int fu = find(edge[i].u);
int fv = find(edge[i].v);
if (fu != fv) {
p[fu] = fv;
edge[i].used = true;
sum += edge[i].w;
int u = edge[i].u, v = edge[i].v, w = edge[i].w;
add(u, v, w); add(v, u, w);
}
}
return sum;
}
int main() {
memset(h, -1, sizeof h);
scanf("%d%d", &n, &m);
for (int i = 0; i < m; i++) {
int a, b, c; scanf("%d%d%d", &a, &b, &c);
edge[i] = {a, b, c, false};
}
// 求解最小生成树并建树
long long sum = kruskal();
// 预处理
t = int(log(n) / log(2)) + 1;
bfs();
int det = 1e9;
for (int i = 0; i < m; i++) {
if (!edge[i].used) {
int dis[2] = {0, 0}, u = edge[i].u, v = edge[i].v, w = edge[i].w;
if (u == v) continue;
lca(u, v, dis);
if (dis[0] < w) det = min(det, w - dis[0]);
else if (dis[1]) det = min(det, w - dis[1]);
}
}
printf("%lld\n", sum + det);
return 0;
}