Heavy Light Decomposition

이 글로 이동해주세요!!

Heavy Light Decomposition(HLD)는 트리에서 임의의 두 정점을 잇는 경로에 대한 쿼리를 O(log2N)에 수행하는 자료구조입니다.
간략하게 설명하자면, N개의 정점으로 이루어진 트리의 간선들을 Heavy-Edge와 Light-Edge로 분류한 뒤, 해당 정보를 이용해 트리의 간선들을 O(logN)개의 직선 형태로 분할(Decomposition)해 세그먼트 트리와 결합하여 쿼리를 빠르게 수행할 수 있습니다.

트리와 쿼리1 문제를 봅시다. 이 문제에서 다루는 쿼리는 두 가지입니다.

  • 1 idx val : idx번째로 입력된 간선의 가중치를 val로 변경
  • 2 u v : u번 정점에서 v번 정점으로 가는 경로에 존재하는 간선의 가중치 중 최댓값 출력

Naive 풀이

HLD를 알아보기 전에, Naive한 풀이를 먼저 생각해봅시다.
1번 쿼리는 그냥 간선의 비용을 직접 바꿔주면 됩니다. O(1)만에 가능합니다.
2번 쿼리는 u와 v의 lca를 먼저 구해서 t에 저장을 합시다. u에서 t, t에서 v까지 차례대로 순회하면서 최댓값을 찾아주면 됩니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
int n; //정점 개수
int par[101010], cost[101010]; //부모, 가중치

void update(int u, int v, int c){
	if(par[u] == v) swap(u, v);
	cost[v] = c;
}

int query(int u, int v){
	int ret = 0;
	int t = lca(u, v);
	while(u != t) ret = max(ret, cost[u]), u = par[u];
	while(v != t) ret = max(ret, cost[v]), v = par[v];
	return ret;
}

Heavy Light Decomposition 전처리

글을 시작할 때, 트리의 간선들을 Heavy-Edge(무거운 간선)와 Light-Edge(가벼운 간선)으로 분류한다고 했습니다. 분류 기준은 보통 두 가지 중 하나를 씁니다.
u가 부모정점, v가 자식정점인 간선 {u, v}에 대해 v를 루트로 하는 서브트리의 크기(무게)가

  1. u를 루트로 하는 서브트리의 크기의 절반 이상일 때
  2. v의 모든 다른 형제 정점들을 각자 루트로 하는 서브트리의 크기보다 크거나 같을 때

두 가지 중 하나를 기준으로 간선을 분류합니다. 이 글에서는 2번 방법을 기준으로 분류하겠습니다.

변수를 몇 가지 정의합시다.

1
2
3
4
5
int depth[i]; //i번 정점의 깊이
int lcaT[]; //LCA Table
int par[i]; //i번 정점의 부모
int weight[i]; //i번 정점을 자식 정점으로 하는 간선의 무게
vector< pair<int, int> > g[]; //인접리스트 ( {도착점, 초기비용} )

dfs를 돌리면서 정점의 깊이와, 간선의 무게를 구해줍시다.

1
2
3
4
5
6
7
8
9
10
11
12
13
void dfs(int now, int prv){
	weight[now] = 1;
	for(auto &i : g[now]){
		int nxt = i.first, cost = i.second;
		if(nxt == prv) continue;
		depth[nxt] = depth[now] + 1;
		lcaT[0][nxt] = now;
		par[nxt] = now;
		newCost[nxt] = cost;
		dfs(nxt, now);
		weight[now] += weight[nxt];
	}
}

newCost[i]는 i를 자식 정점으로 하는 간선의 초기 비용을 의미합니다.

트리가 아래와 같이 생겼다고 합시다.

HLD로 쿼리를 처리하기 위해 무거운 간선만 골라내면 다음 사진처럼 됩니다.

인접한 무거운 간선끼리 묶어서 무거운 경로 라고 정의합시다. 편의상 무거운 경로에서 가장 위에 있는 정점에서 부모로 올라가는 간선도 경로에 포함시켜줍시다. 어떠한 무거운 경로에 포함되지 않은 간선들은 그 자체로 하나의 경로를 만들어주면 모든 간선은 정확히 하나의 경로에 속하게 됩니다.

DFS를 돌면서 간선들을 분류해봅시다. 나중에 쓸 데가 있으니 일단 DFS Ordering도 같이 매겨줍시다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
int comp[i]; //i번 간선이 속하는 무거운 경로의 번호
int head[j]; //j번 무거운 간선에서 맨 위에 있는 정점
int lab[v]; //v번 정점의 DFS Order
int
int piv1, piv2; //무거운 경로 번호, dfs order

void hld(int now, int prv){
	sort(g[now].begin(), g[now].end(), [](const p &a, const p &b){ //간선의 무게를 기준으로 내림차순 정렬
		return weight[a.first] > weight[b.first];
	});

	int heavy = -1;
	for(auto &i : g[now]){ //가장 앞에 있는 간선만 무거운 간선
		int nxt = i.first;
		if(nxt == prv) continue;
		heavy = nxt;
		comp[nxt] = comp[now];
    lab[nxt] = ++piv2;
		hld(nxt, now);
		break;
	}
	for(auto &i : g[now]){ //나머지는 가벼운 간선
		int nxt = i.first;
		if(nxt == heavy || nxt == prv) continue;
		comp[nxt] = ++piv1; //새로운 무거운 간선 생성
		head[piv1] = nxt;
    lab[nxt] = ++piv2;
		hld(nxt, now);
	}
}

트리에서 DFS Order를 매기면 대략 아래와 같은 결과가 나오게 됩니다.

사진을 보면, 같은 경로에 속하는 정점의 DFS Order는 연속되어 있다는 것을 알 수 있고, 깊이가 낮을수록 DFS Order도 작다는 것을 알 수 있습니다. 편의를 위해 이제부터 정점의 번호를 DFS Order를 이용해 정의하겠습니다.

12번 정점과 6번 정점 사이의 경로를 봅시다. 12와 6의 LCA는 1입니다.
12에서 시작해 1을 지나 6으로 가는 경로에서 마주치는 무거운 경로는 4개입니다. 또한, 이 트리의 모든 경로는 최대 4개의 무거운 경로만을 포함합니다. 이처럼 HLD는 경로상에 존재하는 무거운 경로의 개수를 O(logN)으로 유지합니다.

한 무거운 경로에 대해 쿼리를 날릴때에는, 선형이므로 세그먼트 트리를 이용하면 됩니다.
무거운 경로가 여러 개가 있더라도, 각각의 경로가 담당하는 범위는 서로 겹치지 않기 때문에 세그먼트 트리 하나를 잘 이용하면 풀 수 있습니다.

update 연산

idx번째로 입력된 간선의 가중치를 val로 변경하는 연산은 구현해야합니다.
간선의 번호를 두 정점 중 자식 정점의 DFS Order라고 정의합시다. i번째로 입력된 간선의 양쪽 두 정점을 미리 구해놓았다면, 해당 간선의 번호를 쉽게 알 수 있습니다.
해당 간선의 번호를 알았다면 세그먼트 트리에 바로 update 쿼리를 날려주면 됩니다.

1
2
3
4
5
6
int s[i], e[i]; //i번째로 입력된 간선의 시점과 종점

void update(int idx, int val){
  if(depth[s[idx]] > depth[e[idx]]) swap(s[idx], e[idx]);
	seg.update(lab[e[idx]], val);
}

query 연산

먼저 u와 v의 LCA인 t를 구한 뒤, u~t 구간에서의 최댓값과 t~v 구간에서의 최댓값을 각각 구해준 뒤, 정답을 구해주면 됩니다.

1
2
3
4
int query(int u, int v){
	int t = lca(u, v);
	return max(sub_query(u, t), sub_query(v, t));
}

sub_query의 구현 방법을 봅시다.
두 정점 s와 e가 서로 다른 무거운 경로에 속해있다면 아래의 작업을 계속 반복합니다.

  1. s와 e가 속한 무거운 경로의 맨 위 정점을 각각 ss와 ee에 저장
  2. ss와 ee 중 ss가 더 아래에 있다면 ss부터 s까지의 최댓값을 구한 뒤, s를 ss의 부모 정점으로 이동
  3. 반대의 경우에는 ee부터 e까지의 최댓값을 구한 뒤, e를 ee의 부모 정점으로 이동

s와 e가 동일한 무거운 경로 상에 존재한다면, 아래 작업을 수행합니다.

  1. s보다 e가 깊이 있다면 swap
  2. s부터 e까지 최댓값 구함
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
int sub_query(int s, int e){
	int ret = 0;
	while(comp[s] != comp[e]){
		int ss = head[comp[s]], ee = head[comp[e]];
		if(depth[ss] > depth[ee]){
			ret = max(ret, seg.query(lab[ss], lab[s]));
			s = par[ss];
		}else{
			ret = max(ret, seg.query(lab[ee], lab[e]));
			e = par[ee];
		}
	}
	if(depth[e] > depth[s]) swap(s, e);
	ret = max(ret, seg.query(lab[e]+1, lab[s]));
	return ret;
}

마지막에 lab[e]+1에서 lab[s]까지 쿼리를 날리는 이유는 lab[e] 간선은 u ~ v경로상에 존재하지 않는, u와 v의 lca와 그의 부모를 잇는 간선이기 때문입니다.

최종 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
#include <bits/stdc++.h>
using namespace std;

struct seg{
	int tree[270000], lim;
	void init(int n){
		for(lim = 1; lim <= n; lim <<= 1);
	}
	void update(int x, int v){
		x += lim;
		tree[x] = v;
		while(x > 1){
			x >>= 1;
			tree[x] = max(tree[2*x], tree[2*x+1]);
		}
	}
	int query(int s, int e){
		s += lim;
		e += lim;
		int ret = -1e9;
		while(s < e){
			if(s%2 == 1) ret = max(ret, tree[s++]);
			if(e%2 == 0) ret = max(ret, tree[e--]);
			s >>= 1;
			e >>= 1;
		}
		if(s == e) ret = max(ret, tree[s]);
		return ret;
	}
}seg;

typedef pair<int, int> p;

int n, m;
int depth[101010], lcaT[20][101010], par[101010], weight[101010];
int comp[101010], head[101010], lab[101010];
int piv1, piv2;
vector<p> g[101010];
int newCost[101010];
int s[101010], e[101010], c[101010];

void dfs(int now, int prv){
	weight[now] = 1;
	for(auto &i : g[now]){
		int nxt = i.first, cost = i.second;
		if(nxt == prv) continue;
		depth[nxt] = depth[now] + 1;
		lcaT[0][nxt] = now;
		par[nxt] = now;
		newCost[nxt] = cost;
		dfs(nxt, now);
		weight[now] += weight[nxt];
	}
}

void hld(int now, int prv){
	sort(g[now].begin(), g[now].end(), [](const p &a, const p &b){
		return weight[a.first] > weight[b.first];
	});

	int heavy = -1;
	for(auto &i : g[now]){
		int nxt = i.first;
		if(nxt == prv) continue;
		heavy = nxt;
		comp[nxt] = comp[now];
		lab[nxt] = ++piv2;
		hld(nxt, now);
		break;
	}

	for(auto &i : g[now]){
		int nxt = i.first;
		if(nxt == heavy || nxt == prv) continue;
		comp[nxt] = ++piv1;
		head[piv1] = nxt;
		lab[nxt] = ++piv2;
		hld(nxt, now);
	}
}

int lca(int s, int e){
	if(depth[s] > depth[e]) swap(s, e);
	int diff = depth[e] - depth[s];
	for(int i=0; i<20; i++){
		if((diff >> i) & 1) e = lcaT[i][e];
	}
	for(int i=19; i>=0; i--){
		if(lcaT[i][s] != lcaT[i][e]) s = lcaT[i][s], e = lcaT[i][e];
	}
	if(s == e) return s;
	return lcaT[0][s];
}

void update(int idx, int val){
	if(depth[s[idx]] > depth[e[idx]]) swap(s[idx], e[idx]);
	seg.update(lab[e[idx]], val);
}

int sub_query(int s, int e){
	int ret = 0;
	while(comp[s] != comp[e]){
		int ss = head[comp[s]], ee = head[comp[e]];
		if(depth[ss] > depth[ee]){
			ret = max(ret, seg.query(lab[ss], lab[s]));
			s = par[ss];
		}else{
			ret = max(ret, seg.query(lab[ee], lab[e]));
			e = par[ee];
		}
	}
	if(depth[e] > depth[s]) swap(s, e);
	ret = max(ret, seg.query(lab[e]+1, lab[s]));
	return ret;
}

int query(int u, int v){
	int t = lca(u, v);
	return max(sub_query(u, t), sub_query(v, t));
}

int main(){
	ios_base::sync_with_stdio(0); cin.tie(0);
	cin >> n;
	seg.init(n);
	for(int i=1; i<n; i++){
		cin >> s[i] >> e[i] >> c[i];
		g[s[i]].push_back({e[i], c[i]});
		g[e[i]].push_back({s[i], c[i]});
	}

	dfs(1, 0);
	for(int i=1; i<20; i++){
		for(int j=1; j<=n; j++){
			lcaT[i][j] = lcaT[i-1][lcaT[i-1][j]];
		}
	}

	head[1] = comp[1] = lab[1] = piv1 = piv2 = 1;
	hld(1, 0);
	for(int i=2; i<=n; i++){
		seg.update(lab[i], newCost[i]);
	}
	cin >> m;
	while(m--){
		int op; cin >> op;
		if(op == 1){
			int idx, val; cin >> idx >> val;
			update(idx, val);
		}else{
			int u, v; cin >> u >> v;
			cout << query(u, v) << "\n";
		}
	}
}