Skip to content
Author: loop3r
Date: 20251016
tag: 树上倍增、LCA、树形DP、虚树
link: https://codeforces.com/problemset/problem/613/D

问题描述

link: https://codeforces.com/problemset/problem/613/D

分析

参考代码

cpp
#include <bits/stdc++.h>
using namespace std;

// Capacity limits: N bounds the node count, M the number of directed edge
// slots (every undirected tree edge is stored twice).
constexpr int N = 500010;
constexpr int M = 2 * N;

// Input sizes, the current query's answer, and the current query's key
// nodes key[1..k].
int n, m, k, ans;
int key[N];

// Linked-list adjacency: h[v] is v's first edge slot (-1 = none), e[] the
// edge target, ne[] the next slot of the same node, idx the next free slot.
int h[N], e[M], ne[M], idx;

// Binary lifting for LCA: fa[v][j] is the 2^j-th ancestor of v, dep[v] its
// depth. siz[] is reused: subtree sizes during dfs(), then per-query key
// marks / pending key counts.
int fa[N][20], dep[N], siz[N];

// DFS pre-order timestamps and the running counter that assigns them.
int dfn_cnt, dfn[N];

// Monotonic stack used while building the virtual tree; top indexes its top.
int s[N], top;

// Insert the directed edge a -> b at the front of a's adjacency list,
// consuming the next free edge slot idx.
void add(int a, int b) {
    const int slot = idx++;
    e[slot] = b;
    ne[slot] = h[a];
    h[a] = slot;
}

// Recursively root the tree at x, whose parent is f, and fill in for every
// node v of x's subtree:
//   dfn[v]   - pre-order timestamp taken from the running counter dfn_cnt;
//   dep[v]   - depth, computed as dep[parent] + 1;
//   fa[v][j] - the 2^j-th ancestor (j = 0..19), composed from rows of
//              ancestors that were filled earlier;
//   siz[v]   - number of nodes in v's subtree.
// Recursion depth equals the height of the tree below x.
void dfs(int x, int f) {
    fa[x][0] = f;
    for (int j = 1; j < 20; ++j) {
        const int half = fa[x][j - 1];   // the 2^(j-1)-th ancestor
        fa[x][j] = fa[half][j - 1];
    }
    dep[x] = dep[f] + 1;
    dfn[x] = ++dfn_cnt;
    siz[x] = 1;

    for (int i = h[x]; ~i; i = ne[i]) {
        const int child = e[i];
        if (child != f) {               // skip the edge back to the parent
            dfs(child, x);
            siz[x] += siz[child];
        }
    }
}

int lca(int x, int y) {
    if (dep[x] < dep[y]) swap(x, y);
    for (int i = 19; i >= 0; i--) {
        if (dep[fa[x][i]] >= dep[y]) x = fa[x][i];
    }
    if (x == y) return x;

    for (int i = 19; i >= 0; i--) {
        if (fa[x][i] != fa[y][i]) {
            x = fa[x][i];
            y = fa[y][i];
        }
    }

    return fa[x][0];
}

// Sort predicate: order nodes by their DFS pre-order timestamp dfn[].
bool cmp(int x, int y) {
    const int tx = dfn[x];
    const int ty = dfn[y];
    return tx < ty;
}

// Build the virtual (auxiliary) tree of the current query: root 1, the key
// nodes key[1..k], and the LCAs needed to connect them. Edges are added
// parent -> child into the shared adjacency arrays (h/e/ne). Because idx is
// rewound to 0, every h[] entry used by the previous query's tree must
// already be -1 (dp() resets h[x] for each node it visits).
// Side effect: key[1..k] is sorted by dfn in place.
// NOTE(review): reads key[1], so this assumes k >= 1.
void build() {
    idx = 0;  // reuse edge slots from the start for this query

    sort(key + 1, key + k + 1, cmp);

    // The stack holds a chain of nodes going down from the root (depths
    // strictly increase from bottom to top); it is seeded with root 1.
    s[top = 1] = 1;
    if (key[1] != 1) s[++top] = key[1];

    for (int i = 2; i <= k; i++) {
        int l = lca(s[top], key[i]);
        // While the node beneath the top is at least as deep as l, that node
        // is the top's parent in the virtual tree: link them and pop.
        while (top > 1 && dep[s[top - 1]] >= dep[l]) {
            add(s[top - 1], s[top]);
            top--;
        }
        // If l is not on the stack yet, it becomes the parent of the current
        // top and replaces it on the chain.
        if (l != s[top]) {
            add(l, s[top]);
            s[top] = l;
        }
        s[++top] = key[i];
    }

    // Link the chain that remains on the stack.
    while (top > 1) {
        add(s[top - 1], s[top]);
        top--;
    }
}

// Greedy DP over the virtual tree rooted at x (edges stored parent -> child).
// On entry, siz[x] != 0 marks x as a key node; other nodes are expected to
// enter with siz[x] == 0. ans is incremented once per required cut:
//   - key x: one cut for every child that still carries a connected key;
//   - non-key x: x itself is cut when two or more children carry a key.
// On return, siz[x] is nonzero iff a key node is still connected to x and
// must be handled by x's parent. Every child's siz[] is reset to 0, and h[x]
// is reset to -1 so the arrays can be reused by the next query.
void dp(int x) {
    const bool isKey = siz[x] != 0;   // decided before children are merged
    for (int i = h[x]; ~i; i = ne[i]) {
        const int y = e[i];
        dp(y);
        if (isKey) {
            if (siz[y]) ++ans;        // cut between key x and the key below y
        } else {
            siz[x] += siz[y];         // collect pending keys at x
        }
        siz[y] = 0;
    }
    if (!isKey && siz[x] > 1) {       // x joins two or more keys: cut x
        ++ans;
        siz[x] = 0;
    }
    h[x] = -1;
}

// Read the tree, preprocess it once with dfs(), then answer each query by
// building the virtual tree of its key nodes and running the greedy dp().
// A query prints -1 when some key node's parent is also a key node (two
// adjacent key nodes); otherwise it prints ans.
int main() {
    // Large inputs: decouple the C++ streams from C stdio and untie cin from
    // cout. Because of this, ALL output must go through cout (no puts/printf)
    // to keep its order.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n;

    memset(h, -1, sizeof h);
    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        add(u, v), add(v, u);
    }

    dfs(1, 0);

    // siz[] held subtree sizes. From here on it marks key nodes and pending
    // key counts. h[] becomes the (initially empty) virtual tree.
    memset(siz, 0, sizeof siz);
    memset(h, -1, sizeof h);

    cin >> m;

    while (m--) {
        cin >> k;

        // dp(1) may leave a nonzero count on the root, which no parent clears.
        siz[1] = 0;
        for (int i = 1; i <= k; i++) {
            cin >> key[i];
            siz[key[i]] = 1;
        }

        bool separable = true;
        for (int i = 1; i <= k; i++) {
            if (siz[fa[key[i]][0]]) {     // parent is a key node as well
                separable = false;
                break;
            }
        }

        if (!separable) {
            // build()/dp() are skipped, so undo the key marks here.
            for (int i = 1; i <= k; i++) siz[key[i]] = 0;
            cout << -1 << '\n';
            continue;
        }

        build();
        ans = 0;
        dp(1);
        cout << ans << '\n';              // '\n', not endl: no flush per query
    }

    return 0;
}