백준7578 공장

문제 링크

  • http://icpc.me/7578

문제 출처

  • 2013 KOI 지역 본선 고등부3

사용 알고리즘

  • segment tree

시간복잡도

  • O(n log n)

풀이

전형적인 inversion counting문제입니다.

예제로 주어진 상황은 아래 그림과 같습니다.


inversion counting문제는 원소는 같고 배치 순서만 다른 두 배열에서 무작위로 원소 두 개를 선택했을 때, 전후 관계가 다른 것의 쌍의 수를 의미합니다.
이 정의를 이용하여 문제를 해결할 것입니다.
위에 있는 배열은 앞에 있는 원소부터 차례로 보고, 아래에 있는 배열은 위에 있는 배열에서 선택된 원소와 같은 값을 가진 원소부터 볼 것입니다.
탐색을 하기에 앞서, 아래에 있는 배열의 모든 원소를 “방문 안함”으로 표시합시다.

첫 번째 원소를 봅시다.

아래 배열에서는 세 번째 원소와 대응되네요. 일단 “방문 했음”으로 표시해줍시다.
세 번째 원소보다 오른쪽에 있는 원소 중 “방문 했음”으로 표시된 원소의 개수는 0개입니다.

두 번째 원소를 봅시다.

아래에 있는 배열에서는 첫 번째 원소와 대응됩니다. “방문 했음”으로 표시해줍시다.
첫 번째 원소보다 오른쪽에 위치한 원소 중 이미 방문한 원소의 개수가 1개 있습니다. 이것은, 위에 있는 배열과 전후 관계가 다른 원소가 1개 존재한다는 것이기 때문에 정답을 1 증가 시킵시다.

세 번째 원소를 봅시다.

네 번째 원소와 대응됩니다.
오른쪽에 있는 원소 중 이미 방문한 원소가 없으니 그냥 넘어갑니다.

네 번째 원소를 봅시다.

두 번째 원소와 대응됩니다.
오른쪽에 있는 원소 중 이미 방문한 원소가 2개 있습니다.
정답을 2 증가시킵시다.

마지막 원소를 봅시다.

다섯 번째 원소와 대응됩니다.
가장 오른쪽에 있는 원소이기 때문에 탐색을 종료합니다.

이렇게 하면 정답은 0+1+0+2+0 = 3입니다.

자신보다 오른쪽에 있는 원소 중 방문한 원소의 개수를 segment tree나 fenwick tree를 이용해 구할 수 있습니다.

전체 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;

vector<ll> tree, arr;
int h[1000010];

void update(int node, int start, int end, int idx, int diff){
	if(idx<start || idx>end) return;
	tree[node] += diff;
	if(start^end){
		update(node*2, start, (start+end)/2, idx, diff);
     update(node*2+1, (start+end)/2+1, end, idx, diff);
	}
}

ll sum(int node, int start, int end, int left, int right){
	if(left>end || right<start){
    	return 0;
	}
	if(left<=start && end<=right){
    	return tree[node];
	}
	return sum(node*2, start, (start+end)/2, left, right) + sum(node*2+1, (start+end)/2+1, end, left, right);
}

int main(){
	ios_base::sync_with_stdio(0); cin.tie(0);
	int n; cin >> n;
	arr = vector<ll>(n+10);
	tree = vector<ll>(4*n+40);
	for(int i=1; i<=n; i++){
		ll t; cin >> t;
		h[t] = i;
	}
	for(int i=1; i<=n; i++){
		ll t; cin >> t;
		arr[i] = h[t];
	}

	ll ans = 0;

	for(int i=1; i<=n; i++){
		int j = arr[i];
		ans += sum(1, 1, n, j+1, n);
		update(1, 1, n, j, 1);
	}

	cout << ans;
}