Author: lllyouo
Date: 20250327
tag: 基环树、树的直径
link: https://www.acwing.com/problem/content/360/

问题描述

题目描述见原题链接：https://www.acwing.com/problem/content/360/

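分析

每座岛恰好引出一座桥，所以每个连通块的点数等于边数，恰好是一棵基环树。答案为所有基环树的直径（最长简单路径）之和。

一棵基环树的直径分两种情况：

1. 路径完全落在某个环上点所挂的子树中。对每个环上点做树形 DP，求出子树内的直径，同时得到该点向下的最长链长度 d[u]。
2. 路径经过环。取环上两个不同的点 u、v，路径长度为 d[u] + d[v] + dist(u, v)。对固定的一对点，dist 取环上两个方向中较长的那条。把环破成链并复制一倍（长度 2p），用前缀和 sum 表示链上距离。这样对位置 j，只需在 j - p < i < j 中最大化 d[s[i]] - sum[i]，用单调队列即可在 O(p) 内求出。

两种情况取最大值即为该基环树的直径，总时间复杂度 O(n)。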

参考代码

cpp
#include <bits/stdc++.h>
using namespace std;

// Capacity: up to 1e6 islands plus slack; every bridge is stored as two directed edges.
const int N = 1000010;
const int M = 2 * N;

// Adjacency list ("linked forward star"). main() adds both directions of a bridge
// back to back starting at index 0, so edges k and k ^ 1 are reverses of each other.
int h[N];   // first edge out of each vertex (main fills it with -1 = empty list)
int e[M];   // target vertex of each edge
int w[M];   // length of each edge
int ne[M];  // next edge out of the same vertex, -1 at the end
int idx;    // number of edges stored so far

// Cycle of the component currently being processed, broken into a doubled chain.
int s[M];          // s[1..p]: cycle vertices; s[p+1..2p]: the same vertices again
long long sum[M];  // sum[k]: distance along the doubled chain up to s[k]
int q[M];          // monotonic queue of positions into s / sum

// DFS bookkeeping.
int dfn[N];  // visit number of each vertex, 0 = not yet visited
int fa[N];   // index of the edge the DFS used to reach each vertex
int num;     // last visit number handed out

// Tree DP and answers.
long long d[N];  // longest downward path from a vertex into its hanging tree
long long ans;   // running total of the per-component answers
long long res;   // best path length found in the current component
bool vis[N];     // vertex lies on a cycle or has already been walked by dp()
int n;           // number of islands
int p;           // number of vertices on the current cycle

void add(int a, int b, int c) {
    e[idx] = b;
    w[idx] = c;
    ne[idx] = h[a];
    h[a] = idx++;
}

// 记录该环
void get_cycle(int x, int y, int z) {
    sum[1] = z;
    while (y != x) {
        s[++p] = y;
        sum[p + 1] = w[fa[y]];
        y = e[fa[y] ^ 1];
    }
    s[++p] = x;
    for (int i = 1; i <= p; i++) {
        vis[s[i]] = true;
        // 破环成链、复制一倍
        s[p + i] = s[i];
        sum[p + i] = sum[i];
    }
    for (int i = 1; i <= 2 * p; i++) sum[i] += sum[i - 1];
}

// Depth-first search of the component containing x.
// dfn[v] receives v's visit number and fa[v] the index of the tree edge used
// to reach v. When a vertex u meets an already-numbered descendant y through a
// non-tree edge, that edge closes the component's cycle and
// get_cycle(u, y, length) records it.
//
// Uses an explicit stack instead of recursion. A component may be a single
// path or cycle of about 1e6 vertices, and recursing that deep can overflow a
// default-sized (e.g. 8 MB) call stack. Vertices are numbered and edges are
// examined in exactly the same order as the recursive formulation.
void dfs(int x) {
    // Frames of (vertex, next edge of that vertex still to examine); -1 = exhausted.
    // Kept static so its capacity is reused across components.
    static std::vector<std::pair<int, int>> stk;
    stk.clear();

    dfn[x] = ++num;
    stk.emplace_back(x, h[x]);
    while (!stk.empty()) {
        const int u = stk.back().first;
        const int i = stk.back().second;
        if (i == -1) {  // every edge of u examined: "return" to the caller frame
            stk.pop_back();
            continue;
        }
        stk.back().second = ne[i];  // advance u's cursor before a push can reallocate
        const int y = e[i];
        if (!dfn[y]) {
            fa[y] = i;  // remember the edge index, not just the parent vertex
            dfn[y] = ++num;
            stk.emplace_back(y, h[y]);
        } else if ((i ^ 1) != fa[u] && dfn[y] > dfn[u]) {
            // y was numbered after u, so it is a descendant: edge (u, y) closes the cycle.
            get_cycle(u, y, w[i]);
        }
    }
}

// 求解每个子树的直径
void dp(int x) {
    vis[x] = true;
    for (int i = h[x]; i != -1; i = ne[i]) {
        int y = e[i];
        if (!vis[y]) {
            dp(y);
            res = max(res, d[x] + d[y] + w[i]);
            d[x] = max(d[x], d[y] + w[i]);
        }
    }
}

// Reads n, then n lines "b c": island i has a bridge of length c to island b.
// Every connected component is processed once. Its best path is either inside
// a tree hanging off the cycle (dp) or runs through part of the cycle
// (scan_cycle). Prints the sum of the per-component results.
int main() {
    memset(h, -1, sizeof h);
    scanf("%d", &n);
    for (int from = 1; from <= n; ++from) {
        int to, len;
        scanf("%d%d", &to, &len);
        add(from, to, len);  // store the bridge in both directions
        add(to, from, len);
    }

    // Longest path that uses part of the current cycle: choose positions i < j
    // on the doubled chain with j - i < p and maximise
    // d[s[j]] + sum[j] + (d[s[i]] - sum[i]). The queue keeps the candidate i's
    // inside that window with strictly decreasing d[s[i]] - sum[i], so the best
    // partner for j is always at the head.
    auto scan_cycle = []() {
        int head = 1, tail = 0;
        for (int j = 1; j <= 2 * p; ++j) {
            while (head <= tail && q[head] <= j - p) ++head;  // too far behind j
            if (head <= tail) {
                const int i = q[head];
                res = max(res, d[s[j]] + d[s[i]] + sum[j] - sum[i]);
            }
            const long long key = d[s[j]] - sum[j];
            while (head <= tail && d[s[q[tail]]] - sum[q[tail]] <= key) --tail;
            q[++tail] = j;
        }
    };

    for (int root = 1; root <= n; ++root) {
        if (dfn[root]) continue;  // component already handled
        p = 0;
        res = 0;
        dfs(root);                              // records the cycle into s / sum
        for (int j = 1; j <= p; ++j) dp(s[j]);  // trees hanging off the cycle
        scan_cycle();                           // paths through the cycle
        ans += res;
    }
    printf("%lld\n", ans);

    return 0;
}