[트리] 최소 공통 조상(LCA)

LCA란?

이 글에서는 LCA를 O(log N)만에 구하는 알고리즘을 다룹니다.

LCA의 정의를 간단하게 요약하자면, 트리 상의 정점 u, v에서 u와 v의 공통 조상 중 가장 깊이 있는 정점 이라고 할 수 있습니다.
위에서 언급했듯이 이 글에서는 LCA를 O(log N)만에 구할 것입니다.
LCA를 이용하면 트리 상에 있는 두 정점 사이의 최단 거리를 정말 빠르게 찾을 수 있습니다. 또한, HLD같은 복잡한 자료구조에도 쓰이니 잘 알아두는 것이 좋습니다.


이 그림처럼 U, V의 공통 조상 중에서 가장 깊이 있는 정점을 고를 것입니다.

구현 방법

DFS 돌 때 꽤 자주 사용하던 배열 중에 부모를 저장하던 배열이 있습니다.
보통 par[v] = v의 부모 정점 처럼 정의를 했었는데, lca를 구할 때는 par 배열을 2차원으로 확대시켜 다시 정의할 것입니다.
par[v][t] = v의 2t 번째 조상 이라고 정의를 합시다.
2t+1 = 2t + 2t라는 것은 자명합니다. par[v][t] = par[ par[v][t-1] ][t-1] 이라는 수식을 사용해 Bottom-Up DP처럼 채워주면 par 배열을 구할 수 있습니다.

이제 LCA를 구해봅시다. dep[v] = v의 깊이 라고 가정하고 설명을 하겠습니다.
dep[u] > dep[v]인 상황에서는 두 단계만 거치면 lca를 구할 수 있습니다.

  1. dep[u] > dep[v] 일 때 u를 u의 조상으로 바꾼다.
  2. u와 v를 각각 자신의 조상으로 바꾼다.

1번 과정부터 구현해봅시다.
u를 하나씩 올려준다면 최악의 경우 O(N)이 걸립니다. 조금 더 빠르게 개선시킵시다.
만약 u와 v의 깊이가 13만큼 차이가 난다고 해봅시다. 이진법으로 나타내면 1101(2)입니다. 그러므로 20, 22, 23번째 부모를 따라가면 최대 O(log N)번의 비교를 통해 dep[u]와 dep[v]를 같게 만들 수 있습니다.
dep[u]와 dep[v]가 같아졌는데, u == v라면 u를 반환해주고 끝내면 됩니다.

2번 과정을 구현해봅시다.
2번 과정을 할 차례가 되었다는 것은 dep[u] == dep[v]인 상황이라는 것입니다. 2번 과정도 하나씩 올리면 O(N)이니까 더 빠른 방법을 생각해봅시다.
par[u][t] != par[v][t] 라면 lca는 최소한 2t 이상 떨어져 있다는 것을 알 수 있습니다.
par[u][t] != par[v][t] && par[u][t+1] == par[v][t+1] 라면 lca는 2^t ~ 2^(t+1) 만큼 떨어져 있다는 것을 알 수 있습니다.
그러므로 t를 큰 값부터 하나씩 내려가면서 par[u][t] != par[v][t] 라면 u와 v를 동시에 2t 만큼 올려주면 됩니다.
루프가 다 끝났다면 u != v지만 두 개의 부모는 동일합니다. par[u][0]을 반환해주면 됩니다.

구현

lca를 구하는 문제인 백준11438번의 전체 코드는 아래와 같습니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#include <bits/stdc++.h>
using namespace std;

const int size = 1e5+10;
vector<int> g[size];
int depth[size];
int dp[size][30];
int visit[size];
int n;

void dfs(int v, int d){
	visit[v] = 1;
	depth[v] = d;
	for(auto i : g[v]){
		if(!visit[i]){
			dp[i][0] = v;
			dfs(i, d+1);
		}
	}
}

void make_table(){
	for(int j=1; j<30; j++){
		for(int i=1; i<=n; i++){
			dp[i][j] = dp[ dp[i][j-1] ][j-1];
		}
	}
}

int lca(int u, int v){
	if(depth[u] < depth[v]) swap(u, v);
	int diff = depth[u] - depth[v];
	for(int i=0; diff; i++){
		if(diff & 1) u = dp[u][i];
		diff >>= 1;
	}
	if(u == v) return u;

	for(int i=29; i>=0; i--){
		if(dp[u][i] != dp[v][i]) u = dp[u][i], v = dp[v][i];
	}
	return dp[u][0];
}

int main(){
	ios_base::sync_with_stdio(0); cin.tie(0);
	cin >> n;
	for(int i=0; i<n-1; i++){
		int a, b; cin >> a >> b;
		g[a].push_back(b);
		g[b].push_back(a);
	}
	dfs(1, 0);
	make_table();

	int m; cin >> m;
	while(m--){
		int a, b; cin >> a >> b;
		cout << lca(a, b) << "\n";
	}
}

추천 문제

  • http://icpc.me/1761 dist[u] + dist[v] - 2dist[lca(u, v)]
  • http://icpc.me/3176 풀이