[구간쿼리] 세그먼트 트리

Segment Tree란?

세그먼트 트리는 특정 구간의 합, 곱, 최대값, 최소값 등을 효율적으로 구하는 자료구조입니다.
이진 트리의 형태를 띄고 있으며 naive한 방식보다 훨씬 효율적으로 작업을 처리할 수 있습니다.

처리해야 할 쿼리

길이가 N인 arr이라는 배열에서 다음 두 가지 쿼리를 총 M번 수행해야 한다고 가정해봅시다.

  1. 구간 [l, r]이 주어졌을 때 해당 구간의 합을 구하기
  2. i번째 수를 v로 바꾸기

Naive 풀이

naive한 방식으로 하면
1번 쿼리는 최대 O(N)이 걸립니다.
2번 쿼리는 최대 O(1)이 걸립니다.

만약 arr배열의 누적합을 구한 다음 쿼리를 받는다고 한다면
누적합 배열을 구성하는데 최대 O(N)이 걸립니다.
1번 쿼리는 최대 O(1)이 걸립니다.
2번 쿼리는 최대 O(N)이 걸립니다.

결국 naive한 방식으로 하면 O(MN)이 걸리게 됩니다.

Segment Tree의 아이디어

세그먼트 트리를 이용하면 1, 2번 쿼리 모두 O(log N)만에 수행할 수 있습니다.
세그먼트 트리의 리프 노드는 배열의 데이터를 의미하고, 리프 노드가 아닌 다른 노드들은 왼쪽 자식과 오른쪽 자식의 합을 저장합니다.
N이 12이라면 각각의 노드가 담당하는 범위는 다음과 같습니다.

각각의 노드에 번호를 붙이면 다음과 같습니다.

이진 트리와 같이 i번 노드의 자식은 i2와 i2+1입니다.

구현

코드로 구현해봅시다.

init

일단 세그먼트 트리를 생성합시다.
배열의 크기가 N이면, 트리의 높이 H는 최대 ceil(logN)+1이 되고, 트리의 사이즈는 최대 (1«H)+1이 됩니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
typedef long long ll;
typedef vector<ll> vll;

/*
 * arr : 배열
 * tree : 세그먼트 트리
 * node의 관할 영역 : [start, end]
 */
ll init(vll &arr, vll &tree, int node, int start, int end){
    if(start == end){ //리프 노드인 경우
        return tree[node] = arr[start];
    }
    return tree[node] = init(arr, tree, node*2, start, (start+end)/2) + init(arr, tree, node*2+1, (start+end)/2+1, end);
    //리프 노드가 아니면 자식들의 합 저장
}

sum

이제 [left, right]의 합을 구해봅시다.
node가 담당하는 [start, end]구간과 구하고자 하는 [left, right]구간의 위치 관계는 다음 4가지로 분류할 수 있습니다.

  1. [left, right]가 [start, end]와 겹치지 않는 경우
  2. [left, right]가 [start, end]를 완전히 포함하는 경우
  3. [start, end]가 [left, right]를 완전히 포함하는 경우
  4. [left, right]가 [start, end]와 겹쳐있는 경우(1, 2, 3을 제외한 나머지 경우) 1번의 경우에는 겹치는 부분이 없으므로 더 이상 탐색할 필요가 없습니다. 따라서 탐색을 종료합니다.
    2번의 경우에는 구해야 하는 구간이 [left, right]고, 현재 노드의 구간인 [start, end]은 이미 모두 포함이 되고, 그 노드의 자식들도 모두 포함되기 때문에 tree[node]를 리턴합니다.
    3, 4번의 경우에는 왼쪽 서브 트리와 오른쪽 서브 트리에서 탐색을 다시 시작합니다.
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    /*
     * arr : 배열
     * tree : 세그먼트 트리
     * node의 관할 영역 : [start, end]
     * 구하고자 하는 영역 : [left, right]
     */
    ll sum(vll &arr, vll &tree, int node, int start, int end, int left, int right){
     if(left>end || right<start){ //겹치는 구간이 없는 경우
         return 0;
     }
     if(left<=start && end<=right){ //[left, right]가 [start, end]를 완전히 포함하는 경우
         return tree[node];
     }
     return sum(arr, tree, node*2, start, (start+end)/2, left, right)
            + sum(arr, tree, node*2+1, (start+end)/2+1, end, left, right);
     //[start, end]가 [left, right]를 완전히 포함하거나
     //두 구간이 겹쳐 있는 경우
     //왼쪽 서브 트리와 오른쪽 서브 트리에서 다시 탐색 시작
    }
    

update

마지막으로 중간에 있는 특정 데이터의 값을 변경해봅시다.
값을 변경하면, 그 값을 포함하는 다른 모든 노드들의 값도 변경해주어야 합니다.
i번째 수를 v로 변경한다면, 그 수가 얼만큼 변했는지 알아야 합니다. 변한 정도를 diff라고 하면 diff = v - arr[i] 가 됩니다.
값을 변경하는 것은 두 가지 경우로 분류됩니다.

  1. [start, end]에 i가 포함
  2. 미포함 node 구간 안에 포함이 된다면 node의 값을 diff만큼 증가시켜 합을 변경해줍니다.
    만약 포함하지 않는 경우에는 자손들 중에 i가 없으므로 탐색을 중단합니다.
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    /*
     * arr : 배열
     * tree : 세그먼트 트리
     * node의 관할 영역 : [start, end]
     * idx : 바꾸고자 하는 위치
     * diff : 변하는 정도
     */
    void update(vll &arr, vll &tree, int node, int start, int end, int idx, ll diff){
     if(idx<start || idx>end) return; //범위를 벗어남
     tree[node] += diff; //범위에 포함된다면 diff 더함
     if(start != end){ //리프 노드가 아니면
         update(arr, tree, node*2, start, (start+end)/2, idx, diff); //왼쪽 자식도 탐색
         update(arr, tree, node*2+1, (start+end)/2+1, end, idx, diff); //오른쪽 자식도 탐색
     }
    }
    

추천 문제

  • http://icpc.me/2042 위에서 설명한 문제입니다.
  • http://icpc.me/7578 풀이
  • http://icpc.me/2517 풀이
  • http://icpc.me/2243 세그먼트 트리를 이용해 k번째 수를 구하는 문제입니다.풀이
  • http://icpc.me/11658 2D세그를 구현하는 문제입니다. 세그트리의 각 노드에 세그트리를 넣으면 됩니다. 풀이

다음 글에서는 Lazy Propagation을 다루도록 하겠습니다.