백준15977 조화로운 행렬

문제 링크

  • http://icpc.me/15977

문제 출처

  • 2018 KOI 고등부 3번

시간복잡도

  • O(Nlg2N)

풀이

1행을 기준으로 정렬을 해봅시다.
M = 2인 경우에는 O(NlgN)만에 LIS를 구해주면 쉽게 풀 수 있습니다.
M = 2인 경우는 너무 간단하니까, M = 3인 경우의 풀이만 살펴봅시다.

첫 번째 행을 A[], 두 번째 행을 B[], 세 번째 행을 C[]라고 합시다.
A[i] < A[j] and B[i] < B[j] and C[i] < C[j]를 만족하는 부분 수열의 최대를 구해야합니다.
이미 A[] 기준으로 이미 정렬을 했기 때문에 B[i] < B[j] and C[i] < C[j]만 고려해주면 됩니다.
pair에 대한 LIS를 구하는 문제로 환원이 됩니다.

LIS를 세그먼트 트리로 구할 수 있듯이 pair의 LIS도 2D 세그먼트 트리로 O(Nlg2N)에 구할 수 있지만, 다른 방법을 생각해봅시다.

lis[i] = i번째로 끝나는 lis의 길이로 정의를 하고 관찰을 해봅시다.
어떤 K에 대해, lis[i] = K를 만족하는 (B[i], C[i])쌍을 모두 모아서 보면 B[i]가 증가할수록 항상 C[i]가 감소합니다.
B[i]가 C[i]가 감소한다는 것은, 2차원 좌표 평면 상에 점을 찍었을 때 계단 모양이 나온다는 것을 의미합니다.
이런 점들은 set/map을 통해 관리해줄 수 있습니다.
lis는 최대 20만까지 set을 20만 개 만들어서 각각 점들을 관리해주면 됩니다.

lis[i] = K가 되기 위해서는 j < i and lis[j] = K-1를 만족하는 j가 존재해야 합니다. lis는 최장 증가 부분 수열이므로 이진탐색을 통해 i보다 앞에 올 수 있으면서 lis의 값이 가장 큰 값 X를 찾아준다면, lis[i]는 X+1이 됩니다.

전체 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#include <bits/stdc++.h>
#define x first
#define y second
#define all(v) v.begin(), v.end()
#define pb push_back
using namespace std;

typedef long long ll;
typedef pair<ll, ll> p;

struct asdf{
    int a, b, c;
    asdf(int a = 0, int b = 0, int c = 0) : a(a), b(b), c(c) {}
    bool operator < (const asdf &t) const {
        if(a != t.a) return a < t.a;
        if(b != t.b) return b < t.b;
        return c < t.c;
    }
};

int k, n;
int arr[4][202020];
vector<asdf> tmp;
vector<p> v;
set<p> st[202020];

int solve(){
    int ans = 0;
    for(int i=0; i<n; i++){
        int l = 1, r = ans + 1;
        while(l < r){
            int m = l + r >> 1;
            auto it = st[m].lower_bound(v[i]);
            if(it == st[m].begin()) r = m;
            else{
                if(prev(it)->y < v[i].y) l = m + 1;
                else r = m;
            }
        }
        ans = max(ans, l);
        auto it = st[l].insert(v[i]).first;
        it++;
        while(it != st[l].end() && it->y >= v[i].y){
            auto tmp = it; it++;
            st[l].erase(tmp);
        }
    }
    return ans;
}

int main(){
    ios_base::sync_with_stdio(0); cin.tie(0);
    cin >> k >> n;
    for(int i=1; i<=3; i++) for(int j=1; j<=n; j++){
        if(k < i) arr[i][j] = arr[i-1][j];
        else cin >> arr[i][j];
    }

    for(int i=1; i<=n; i++) tmp.pb({arr[1][i], arr[2][i], arr[3][i]});
    sort(all(tmp));
    for(auto &i : tmp) v.pb({i.b, i.c});
    cout << solve();
}