백준11933 공장들

문제 링크

  • http://icpc.me/11933

문제 출처

  • 2014 JOIOC 1번

사용 알고리즘

  • 트리 압축

시간복잡도

  • $O((sumS + sumT + N) \log N)$

풀이

잘 생각해보면, 쿼리로 주어진 정점들과 그들의 LCA만 봐도 답을 구할 수 있다는 것을 알 수 있습니다. 트리 압축을 해줍시다.

압축된 트리에서 DFS를 돌면사, 각 정점마다 {현재 정점에서 가장 가까운 회사 U의 공장까지의 거리, 현재 정점에서 가장 가까운 회사 V의 공장까지의 거리}를 구해주면 문제를 풀 수 있습니다.

전체 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#include <bits/stdc++.h>
#define x first
#define y second
#define all(v) v.begin(), v.end()
#define compress(v) sort(all(v)), v.erase(unique(all(v)), v.end())
using namespace std;

typedef long long ll;
typedef pair<ll, ll> p;

int n, q;
vector<p> g[505050];
int par[22][505050], dep[505050]; ll dst[505050];
int in[505050], out[505050], pv;

void dfs(int v = 1, int b = -1){
    in[v] = ++pv;
    for(auto i : g[v]) if(i.x != b){
            par[0][pv+1] = in[v];
            dep[pv+1] = dep[in[v]] + 1;
            dst[pv+1] = dst[in[v]] + i.y;
            dfs(i.x, v);
        }
    out[in[v]] = pv;
}

int lca(int u, int v){
    if(dep[u] < dep[v]) swap(u, v);
    int diff = dep[u] - dep[v];
    for(int i=0; diff; i++){
        if(diff & 1) u = par[i][u];
        diff >>= 1;
    }
    if(u == v) return u;
    for(int i=21; ~i; i--) if(par[i][u] != par[i][v]) u = par[i][u], v = par[i][v];
    return par[0][u];
}

int color[505050];
vector<int> vertex;

ll ans = 1e18;
p f(int v){
    ll a = color[v] == 1 ? 0 : 1e18; // min dst(v, S_i)
    ll b = color[v] == 2 ? 0 : 1e18; // min dst(v, T_i)
    color[v] = 0;
    while(pv < vertex.size() && vertex[pv] <= out[v]){
        int nxt = vertex[pv]; pv++;
        auto t = f(nxt);
        a = min(a, t.x + dst[nxt] - dst[v]);
        b = min(b, t.y + dst[nxt] - dst[v]);
    }
    ans = min(ans, a+b);
    return {a, b};
}

int main(){
    ios_base::sync_with_stdio(false); cin.tie(nullptr);
    cin >> n >> q;
    for(int i=1; i<n; i++){
        int s, e, x; cin >> s >> e >> x; s++; e++;
        g[s].emplace_back(e, x);
        g[e].emplace_back(s, x);
    }
    dfs();
    for(int i=1; i<22; i++) for(int j=1; j<=n; j++) par[i][j] = par[i-1][par[i-1][j]];

    while(q--){
        int s, t; cin >> s >> t;
        vertex.clear();
        for(int i=0; i<s; i++){
            int x; cin >> x; x++; vertex.push_back(in[x]); color[in[x]] = 1;
        }
        for(int i=0; i<t; i++){
            int x; cin >> x; x++; vertex.push_back(in[x]); color[in[x]] = 2;
        }
        compress(vertex);
        vector<int> lcaa;
        for(int i=1; i<vertex.size(); i++) lcaa.push_back(lca(vertex[i-1], vertex[i]));
        for(auto i : lcaa) vertex.push_back(i);
        compress(vertex);
        ans = 1e18; pv = 0; f(1);
        cout << ans << "\n";
    }
}