Heavy Light Decomposition

서론

약 1년 전에 이런 글이런 글을 쓴 적이 있는데, 최근에 hld 문제를 많이 풀어보면서 훨씬 쉬운 구현들을 터득했습니다.
이 글에서는 HLD의 원리와 매우 쉬운 구현 방법에 대해 다룹니다.

우리가 처리해야하는 2개의 쿼리

우리는 정점 100,000개로 이루어진 트리에서 아래 2가지 쿼리를 총 100,000번 처리해야 합니다.

  1. Update v w : 정점 v의 가중치에 w를 더해준다.
  2. Query s e : s에서 e로 가는 경로에 있는 모든 정점의 가중치의 합을 출력한다.

선형이라면 2EZ

1, 2번 쿼리를 트리가 아닌 배열에서 처리한다고 생각해보면, BOJ2042 구간 합 구하기 문제와 동일한 문제가 됩니다.
세그먼트 트리/펜윅 트리 등을 이용해 $O(Q log N)$ 시간에 풀 수 있습니다.

트리에서는 불가능

트리는 선형이 아니기 때문에 세그먼트 트리같은 자료구조를 사용할 수 없습니다. 트리에서도 세그먼트 트리를 사용할 수 있는 형태로 만들기 위해 트리를 적당히 잘라서 여러 개의 체인들로 만들어 줄 것 입니다. 각 체인 안에서의 쿼리는 세그먼트 트리를 사용하면 $O(log N)$에 처리할 수 있습니다.
만약 모든 (u, v)쌍에 대해 u에서 v로 가는 경로에 $O(log N)$개의 체인만 존재하게 할 수 있다면 트리에서 경로에 대한 쿼리를 $O(log N)$번의 구간 쿼리로 바꿀 수 있고, 세그먼트 트리 등의 자료구조를 이용하면 $O(log^2 N)$에 처리할 수 있습니다.

HLD가 뭔데?

HLD는 Heavy Light Decomposition의 약자입니다. 트리의 간선들을 Heavy Edge(무거운 간선)와 Light Edge(가벼운 간선)로 구분하는 것을 의미합니다.

보통 무겁다/가볍다를 “무게”라는 척도를 이용해 구분하듯이, Heavy Edge와 Light Edge는 “서브 트리의 크기”를 기준으로 구분합니다.
부모 정점 u에서 자식 정점 v로 가는 간선 (u, v)가 있을 때 v의 서브 트리 크기가 u의 서브 트리 크기의 절반 이상인 경우(sz[son] ≥ sz[parent]/2) 그 간선을 heavy edge라고 하고, 나머지 간선들은 light edge라고 합니다. 한 정점에서 내려가는 heavy edge는 최대 한 개만 존재해야 합니다.

이렇게 heavy edge와 light edge를 잘 구분해놓으면 좋은 점은, light edge를 타고 올라가면 무조건 트리의 크기가 2배 이상이 됩니다. 그러면 당연히 어떤 정점에서 루트로 갈 때 최대 $O(log N)$개의 light edge만 거치게 됩니다.
정점 u에서 정점 v로 가는 경로는 루트, 혹은 루트보다 아래에 있는 노드를 거쳐서 가기 때문에 최대 $2*O(log N) = O(log N)$개의 light edge만 거치게 됩니다.

보통 HLD를 코드로 구현할 때는 구현의 편의를 위해 sz[son] ≥ sz[parent]/2 대신 sz[son]이 가장 큰 간선을 heavy edge로 잡습니다. 이렇게 해도 복잡도 등의 분석은 크게 달라지지 않습니다.

이 트리에서 heavy edge들을 표시하면 아래 그림과 같이 됩니다.

위에서 이야기했듯이, 각 정점에서 아래로 뻗어나가는 heavy edge는 최대 한 개이기 때문에 인접한 heavy edge들은 한 개의 체인으로 묶어줄 수 있습니다. light edge들은 그 자체를 하나의 체인으로 보면 됩니다.

인접한 heavy edge들을 체인으로 묶어주었으니, 이것들에 대해서는 세그먼트 트리등의 자료구조를 사용할 수 있고, 모든 경로에서는 light edge를 최대 $O(log N)$개 거치기 때문에 당연히 $O(log N)$개의 체인만 보면 됩니다.

구현은 어떻게 하죠?

구현은 정말 다양한 방법이 존재합니다. 그러나 맞는 구현은 단 한 가지라는 말이 있을 정도로 많이 사용하고 있고, 그만큼 매우 간단한 구현 방법을 소개하고자 합니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
int sz[MAXV], dep[MAXV], par[MAXV], top[MAXV], in[MAXV], out[MAXV];
vector<int> g[MAXV];
/*
sz[i] = i를 루트로 하는 서브트리의 크기
dep[i] = i의 깊이
par[i] = i의 부모 정점
top[i] = i가 속한 체인의 가장 위에 있는 정점
in[i], out[i] = dfs ordering
g[i] = i의 자식 정점
*/

void dfs1(int v = 1){
	sz[v] = 1;
	for(auto &i : g[v]){
		dep[i] = dep[v] + 1; par[i] = v;
		dfs1(i); sz[v] += sz[i];
		if(sz[i] > sz[g[v][0]]) swap(i, g[v][0]);
	}
}

void dfs2(int v = 1){
	in[v] = ++pv;
	for(auto i : g[v]){
		top[i] = i == g[v][0] ? top[v] : i;
		dfs2(i);
	}
	out[v] = pv;
}

코드를 하나씩 살펴봅시다.

dfs1에서는 sz, dep, par배열을 채워주고 있습니다. 그러면서 동시에 서브트리가 가장 큰 자식을 맨 앞으로 보내는 역할을 해주고 있습니다.(if(sz[i] > sz[g[v][0]]) swap(i, g[v][0]);)
dfs1가 끝난다면 v에서 뻗어나가는 heavy edge는 $(v, g[v][0])$일 것입니다.

dfs2에서는 in, out 배열을 채워주고 있습니다. 이는 dfs를 돌면서 i번에 들어가는 시점을 in[i]에, i번에서 빠져나가는 시점을 out[i]에 저장합니다. dfs1에서 heavy edge를 인접리스트의 가장 앞으로 옮겨주었기 때문에, 인접한 heavy edge에 속한 정점들은 dfs ordering 상에서도 인접합니다!
top[i]는 체인의 가장 위에 있는 정점인데, 만약 i가 v의 0번째 자식이라면 같은 체인에 속하므로 top[v]를 물려받고, 그렇지 않다면 새로운 체인이 시작하는 것이기 때문에 i로 설정해주면 됩니다.

각 정점의 in값은 유일합니다. 또한, dfs1과 dfs2를 이용해 같은 체인에 속한 정점들은 in값도 인접하게 만들어주었기 때문에 세그먼트 트리에서 각 정점을 관리하는 인덱스를 in[i]로 해줄 수 있습니다.
만약 어떤 정점 v부터 v가 속한 체인의 가장 위에 있는 정점까지의 구간을 알고 싶다면 [ in[top[v]], in[v] ]​구간을 보면 됩니다.

dfs2에서 in과 더불어 out까지 구해놓았기 때문에 서브트리에 대한 쿼리도 처리를 해줄 수 있습니다!

Update 처리

한 정점에 대해서만 갱신을 하는 경우에는 seg.update(in[v], w) 형식으로 해주면 됩니다.
만약 경로에 대해 갱신을 할 때는 바로 아래에서 다룰 Query 처리를 참고하셔서 Lazy Propagation을 잘 섞어주시면 됩니다.

Query 처리

Query만 처리하면 끝납니다!

경로에 대한 쿼리를 처리하는 기본적인 아이디어는 경로를 여러 개의 체인으로 나눠서, 각 체인에 대해 쿼리를 날려준 뒤 모두 합치는 것입니다.

만약 두 정점이 같은 체인에 속한다면 아래 코드처럼 단순히 세그먼트 트리에 쿼리를 한 번 날리는 것으로 끝납니다.

1
2
3
if(dep[a] > dep[b]) swap(a, b);
ret += seg.query(1, 1, n, in[a], in[b]);
return ret;

만약 두 정점이 서로 다른 체인에 속한다면, 같은 체인에 속할 때까지 체인을 타고 올라가야합니다.

u에서 v로 가는 경로를 처리할 때, u, v의 lca를 기준으로 나눠서 보면 이해하기 쉽습니다.
경로를 처리하기 위해서는 lca까지 모두 봐야하고, u v가 서로 다른 체인에 있으면 각자 체인을 타고 쭉쭉 올라와서 lca와 같은 체인에서 만나는 결말이 나와야 합니다.
그것을 이루기 위해서, u와 v 중 더 아래에 있는 정점 x를 선택해서 top[x]부터 x, 즉 x부터 x가 속한 체인의 끝까지 모두 쿼리를 처리해준 뒤, x를 par[st]로 올려줍니다. 이런 방식으로 체인들을 하나씩 떼어나가면 결국 마지막에는 lca와 같은 체인에서 만나게 되고, 같은 체인에 속한 쿼리는 쉽게 해결할 수 있습니다.
구현은 아래와 같이 하면 됩니다.

1
2
3
4
5
6
7
8
9
10
11
12
int query(int a, int b){
    int ret = 0;
    while(top[a] != top[b]){
        if(dep[top[a]] < dep[top[b]]) swap(a, b);
        int st = top[a];
        ret += seg.query(in[st], in[a]);
        a = par[st];
    }
    if(dep[a] > dep[b]) swap(a, b);
    ret += seg.query(in[a], in[b]);
    return ret;
}

전체 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#include <bits/stdc++.h>
using namespace std;

struct Seg{
    int tree[1 << 18];
    int sz = 1 << 17;

    void update(int x, int v){
        x |= sz; tree[x] += v;
        while(x >>= 1){
            tree[x] = tree[x << 1] + tree[x << 1 | 1];
        }
    }

    int query(int l, int r){
        l |= sz, r |= sz;
        int ret = 0;
        while(l <= r){
            if(l & 1) ret += tree[l++];
            if(~r & 1) ret += tree[r--];
            l >>= 1, r >>= 1;
        }
        return ret;
    }
}seg;

int sz[101010], dep[101010], par[101010], top[101010], in[101010], out[101010];
vector<int> g[101010];
vector<int> inp[101010]; //입력 / 양방향 그래프

int chk[101010];
void dfs(int v = 1){
	chk[v] = 1;
	for(auto i : inp[v]){
		if(chk[i]) continue;
		chk[i] = 1;
		g[v].push_back(i);
		dfs(i);
	}
}

void dfs1(int v = 1){
	sz[v] = 1;
	for(auto &i : g[v]){
		dep[i] = dep[v] + 1; par[i] = v;
		dfs1(i); sz[v] += sz[i];
		if(sz[i] > sz[g[v][0]]) swap(i, g[v][0]);
	}
}

int pv;
void dfs2(int v = 1){
	in[v] = ++pv;
	for(auto i : g[v]){
		top[i] = i == g[v][0] ? top[v] : i;
		dfs2(i);
	}
	out[v] = pv;
}

void update(int v, int w){
    seg.update(in[v], w);
}

int query(int a, int b){
    int ret = 0;
    while(top[a] ^ top[b]){
        if(dep[top[a]] < dep[top[b]]) swap(a, b);
        int st = top[a];
        ret += seg.query(in[st], in[a]);
        a = par[st];
    }
    if(dep[a] > dep[b]) swap(a, b);
    ret += seg.query(in[a], in[b]);
    return ret;
}

int main(){
    ios_base::sync_with_stdio(0); cin.tie(0);
    int n, q; cin >> n >> q; //정점 개수, 쿼리 개수
    for(int i=1; i<n; i++){
        int s, e; cin >> s >> e;
        inp[s].push_back(e);
        inp[e].push_back(s);
    }
    dfs(); dfs1(); dfs2();
    while(q--){
        //1 v w : update v w
        //2 s e : query s e
        int op, a, b; cin >> op >> a >> b;
        if(op == 1) update(a, b);
        else cout << query(a, b) << "\n";
    }
}

연습 문제

BOJ13510 트리와 쿼리1 : 정점이 아닌 간선에 대한 쿼리입니다. 각 정점에서 위로 뻗어나가는 간선은 최대 하나라는 점을 이용해, 각 정점에 현재 정점에서 위로 올라가는 간선의 값을 저장하면 됩니다.

BOJ2927 남극탐험 : 경로 갱신/경로 쿼리이므로 Lazy Propagation을 쓰면 됩니다.

BOJ17429 국제 메시 기구 : 경로와 서브트리에 대한 쿼리, 덧셈과 곱셈 쿼리를 조합해 총 6개의 쿼리가 있는 문제입니다. in, out배열을 이용해 서브트리에 대한 쿼리를 잘 처리할 수 있습니다. mod가 2의 32승이므로 unsigned long long을 써야합니다.