Skip to content

最近公共祖先的定义

给定一颗有根树,若节点 z 既是节点 x 的祖先又是节点 y 的祖先,则称 zxy 的最近公共祖先,在 xy 的所有公共祖先中,如果 z 的深度最大,则称 zxy 的最近公共祖先(Lowest Common Ancestor,LCA),记为 LCA(x,y)

最近公共祖先的性质

  1. 前序遍历中,LCA(S) 出现在所有 S 中元素之前,后序遍历中,LCA(S) 出现在所有 S 中元素之后。
  2. 两点间的最近公共祖先必定处在树上两点间的最短路径上。
  3. d(x,y)=h(x)+h(y)2h(LCA(x,y)),其中 d(x,y) 表示 xy 的距离,h(x) 表示 x 的深度。

最近公共祖先的求法

1 向上标记法

x 向上走到根节点,并标记所有经过的节点。

y 向上走到根节点,遇到的第一个被标记的节点即为 LCA(x,y)

或是先将 xy 调整到同一深度,再同步向根节点走,当走到同一个节点时,该节点即为 LCA(x,y)

时间复杂度:O(n)

2 树上倍增法

fx,i 表示 x 向上跳 2i 步到达的节点,则 f(x,i)=f(f(x,i1),i1) 其中 i[1,logn]。特别地,如果该节点不存在,则令 f(x,i)=0,初始时,f(x,0)=fa(x)

在预处理 f 数组的同时求解 d 数组,其中 d(i) 表示 i 的深度。预处理的时间复杂度为 O(nlogn)

之后,对于任意 x,yLCA(x,y) 可以通过以下步骤求得:

  1. d(x)>=d(y),利用二进制拆分的思想,将 x 向上调整到与 y 同一深度。
  2. 若此时 x=y,则 x 即为 LCA(x,y)
  3. 否则,将 xy 同时向上调整,并保持深度一致且二者不相会。此时 xy 必定只差一步相会,即 LCA(x,y)=f(x,0)

每次查询的时间复杂度为 O(logn)

3 离线 Tarjan 算法

Tarjan 算法一种离线算法,即把所有询问一次性读入,一次性计算后再统一输出。

其本质上是使用并查集对向上标记法的优化。时间复杂度为 O(n+m)

在深度优先遍历的任意时刻,树中的节点分为三类:

  1. 已被访问完毕并且回溯的节点,给这些节点标记一个整数 2
  2. 已被访问但尚未回溯的节点,这些节点就是当前正在访问的节点及该节点的祖先节点,给这些节点标记一个整数 1
  3. 尚未被访问的节点,给这些节点标记一个整数 0

对于正在访问的节点 x,它到根节点的路径已经被标记为 1,若节点 y 是第一类节点,则 LCA(x,y) 就是从 y 向上走到根节点遇到的第一个标记为 1 的节点。

此时,可以使用并查集优化该过程。

4 在线 RMQ 算法

对一个有根树进行深度优先遍历,得到该有根树的欧拉序列,即递归到该节点时记录节点编号,回溯时再记录一次节点编号。欧拉序列的长度为 2n1

对于树中的两个节点 u,v,其最近公共祖先 LCA(u,v) 就是 [f(u),f(v)] 中的最小值。其中 f(x) 表示节点 x 在欧拉序列中首次出现的位置。

接下来问题就转化为了区间最小值问题,可以使用 RMQ 算法解决。时间复杂度为 O(nlogn)

经典例题

  1. 洛谷 P3379. 最近公共祖先

树上倍增法

cpp
#include <bits/stdc++.h>
using namespace std;

const int N = 500010, M = 2 * N;
int h[N], e[M], ne[M], idx;
int n, m, s, t;
int d[N], f[N][25];

void add(int a, int b) {
    e[idx] = b;
    ne[idx] = h[a];
    h[a] = idx++;
}

void bfs(int s) {
    queue<int> q;
    q.push(s);
    d[s] = 1;

    while (q.size()) {
        int x = q.front();
        q.pop();

        for (int i = h[x]; i != -1; i = ne[i]) {
            int y = e[i];
            if (d[y]) continue;
            d[y] = d[x] + 1;
            f[y][0] = x;
            for (int k = 1; k <= t; k++) {
                f[y][k] = f[f[y][k - 1]][k - 1];
            }
            q.push(y);
        }
    }
}

int lca(int x, int y) {
    if (d[x] < d[y]) swap(x, y);
    for (int i = t; i >= 0; i--) {
        if (d[f[x][i]] >= d[y]) x = f[x][i];
    }
    if (x == y) return x;

    for (int i = t; i >= 0; i--) {
        if (f[x][i] != f[y][i]) {
            x = f[x][i];
            y = f[y][i];
        }
    }

    return f[x][0];
}

int main() {
    memset(h, -1, sizeof h);
    scanf("%d%d%d", &n, &m, &s);
    for (int i = 1; i < n; i++) {
        int a, b; scanf("%d%d", &a, &b);
        add(a, b);
        add(b, a);
    }

    t = int(log(n) / log(2)) + 1;
    bfs(s);

    while(m--) {
        int a, b; scanf("%d%d", &a, &b);
        printf("%d\n", lca(a, b));
    }

    return 0;
}

离线 tarjan 算法(注意在本题中该算法存在爆栈风险)

cpp
#include <bits/stdc++.h>
using namespace std;

typedef pair<int, int> PII;
const int N = 500010, M = 2 * N;
int h[N], e[M], ne[M], idx;
int n, m, s;
int p[N], d[N], st[N];
vector<PII> query[N];
int res[M];

void add(int a, int b) {
    e[idx] = b;
    ne[idx] = h[a];
    h[a] = idx++;
}

void dfs(int u, int fa) {
    for (int i = h[u]; i != -1; i = ne[i]) {
        int j = e[i];
        if (j == fa) continue;
        d[j] = d[u] + 1;
        dfs(j, u);
    }
}

int find(int x) {
    if (p[x] != x) p[x] = find(p[x]);
    return p[x];
}

void tarjan(int u) {
    st[u] = 1;
    for (int i = h[u]; i != -1; i = ne[i]) {
        int j = e[i];
        if (!st[j]) {
            tarjan(j);
            p[j] = u;
        }
    }

    for (auto item : query[u]) {
        int v = item.first, id = item.second;
        if (st[v] == 2) res[id] = find(v);
    }

    st[u] = 2;
}

int main() {
    memset(h, -1, sizeof h);

    scanf("%d%d%d", &n, &m, &s);
    for (int i = 1; i < n; i++) {
        int a, b; scanf("%d%d", &a, &b);
        add(a, b); add(b, a);
    }
    for (int i = 1; i <= m; i++) {
        int a, b; scanf("%d%d", &a, &b);
        if (a != b) {
            query[a].push_back({b, i});
            query[b].push_back({a, i});
        }
    }

    for (int i = 1; i <= n; i++) p[i] = i;

    dfs(s, -1);
    tarjan(s);

    for (int i = 1; i <= m; i++) printf("%d\n", res[i]);

    return 0;
}