백준14460 소가 길을 건너간 이유 12

문제 링크

  • http://icpc.me/14460

문제 출처

  • 2017 Feb USACO 3번

사용 알고리즘

  • 분할 정복
  • 세그먼트 트리

시간복잡도

  • $O(N \log^2 N)$

풀이

$i$번째 소의 왼쪽 목초지 자리를 $A_i$, 오른쪽 목초지의 자리를 $B_i$라고 합시다. 두 소 $i, j$가 교차하면 $A_i < A_j$ 일 때 $B_i > B_j$입니다.
즉, $(A_i, B_i)$를 좌표 평면 위에 플로팅했을 때 자신보다 오른쪽 아래에 있는 점들 중, $|i - j| \leq K$인 개수를 구해서 더하면 됩니다.

2D Segment Tree를 쓰면 쉽게 풀리지만, Dynamic 2D Segment Tree를 구현하는 건 귀찮으므로 분할 정복을 합시다.

$[s, e]$ 구간의 중간 $m = \lfloor \frac{s + e}{2} \rfloor$를 잡읍시다.
$[s, m]$과 $[m+1, e]$ 구간 내부에서 발생하는 교차 횟수는 재귀적으로 처리하면 되므로, 두 구간 사이에서 발생하는 교차 횟수만 셀 것입니다.

먼저, $[s, m]$ 구간에 있는 원소와 $[m+1, e]$ 구간에 있는 원소들을 각각 y좌표 기준으로 정렬을 합시다.
구간의 합을 관리하는 Segment Tree를 만든 뒤, $[s, m]$, $[m+1, e]$ 구간의 원소들을 merge sort 하듯이 훑어주면서

  • $[m+1, e]$ 구간에 속하는 원소의 경우에는 세그먼트 트리에서 현재 원소의 “인덱스” 위치에 1을 증가시키고
  • $[s, m]$ 구간에 속하는 원소의 경우에는 세그먼트 트리에서 $[1, i-k-1]$, $[i+k+1, N]$ 범위에 쿼리를 날려서 정답을 구해주면 됩니다.

자세한 구현은 코드를 참고해주세요.

전체 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
// 14460 소가 길을 건너간 이유 12

#include <bits/stdc++.h>
#define x first
#define y second
#define all(v) v.begin(), v.end()
#define compress(v) sort(all(v)), v.erase(unique(all(v)), v.end())
using namespace std;

typedef long long ll;
typedef pair<ll, ll> p;

int tree[1 << 18];
const int sz = 1 << 17;
void update(int x, int v){
    x |= sz; tree[x] += v;
    while(x >>= 1) tree[x] = tree[x << 1] + tree[x << 1 | 1];
}
int query(int l, int r){
    if(l > r) return 0;
    l |= sz; r |= sz;
    int ret = 0;
    while(l <= r){
        if(l & 1) ret += tree[l++];
        if(~r & 1) ret += tree[r--];
        l >>= 1; r >>= 1;
    }
    return ret;
}

int n, k;
int a[101010], b[101010];
p v[101010];
ll ans;
vector<p> l, r;

void f(int s, int e){
    if(s == e) return;
    int m = s + e >> 1, j = 0;
    l.clear(); r.clear();
    for(int i=s; i<=m; i++) l.push_back(v[i]);
    for(int i=m+1; i<=e; i++) r.push_back(v[i]);
    sort(all(l)); sort(all(r));

    for(int i=0; i<l.size(); i++){
        while(j < r.size() && r[j].x < l[i].x) update(r[j].y, 1), j++;
        ans += query(1, l[i].y-k-1);
        ans += query(l[i].y+k+1, n);
    }
    for(int i=0; i<j; i++) update(r[i].y, -1);
    f(s, m); f(m+1, e);
}

int main(){
    ios_base::sync_with_stdio(false); cin.tie(nullptr);
    cin >> n >> k;
    for(int i=1; i<=n; i++){
        int t; cin >> t; a[t] = i;
    }
    for(int i=1; i<=n; i++){
        int t; cin >> t; b[t] = i;
    }
    for(int i=1; i<=n; i++) v[a[i]] = {b[i], i};
    f(1, n);
    cout << ans;
}