2021 KOI 2차대회 고등부 문제 풀이

문제 감상

성인이라서 대회에 참가하지는 못하지만, 2021 KOI 1차대회처럼 문제 풀이는 꾸준히 작성하려고 합니다.

2019년부터 1~2번 문제 난이도가 계속해서 상승하고 있습니다. 올해 고등부 1, 2번 모두 KOI가 4문제 체제로 자리잡은 2011년 이후에 출제된 고등부 1, 2번 문제 중 가장 어렵게 느껴졌습니다. 내년에는 어떻게 출제가 될지 궁금해지네요.

한편, 고등부 3번은 KOI에 거의 출제되지 않던 문자열 문제가 출제되었습니다. Trie나 KMP와 같은 비교적 쉽고 많은 학생들이 알고 있는 알고리즘이 아닌 Suffix Array / Longest Common Prefix가 나왔다는 점에서 큰 이변이라고 생각합니다. 올해 고등부 3번 문제는 SA/LCP를 알고 있는 학생이라면 10초 만에 풀이를 떠올릴 수 있지만, SA/LCP를 잘 이해하고 있는 학생은 한 손으로 셀 수 있을 정도로 적기 때문에 순위에 큰 영향을 주지 않을 것 같습니다.

고등부 4번은 2020 SciOI에 출제된 방역이라는 문제와 비슷해 보이는데, 어려워서 아직 풀지 못했습니다.

일단 1, 2, 3번 문제의 풀이를 작성했고, 4번 문제는 풀게 된다면 추가하도록 하겠습니다.
아직 문제를 채점할 수 있는 곳이 없어서 부득이하게 검증되지 않은 코드를 올려놓았습니다. 문제를 채점할 수 있는 곳이 생기면 코드를 검증한 뒤 다시 올리겠습니다.

1번

크기가 $1, 2, \cdots , i$인 동심원을 만드는데 필요한 페인트의 양은 $i(i+1)/2$이기 때문에, 크기가 $\lfloor \sqrt{A+B} \rfloor$인 원까지만 고려해도 충분합니다. $A+B \leq 100\,000$이므로 $450$까지만 고려하면 됩니다.

이 관찰을 이용하면 41점을 받을 수 있는 $O(AB \sqrt {A+B})$ 풀이는 쉽게 찾을 수 있습니다.
빨간색 페인트를 $a$, 파란색 페인트를 $b$만큼 사용해서 크기가 $1, 2, \cdots , i$인 동심원을 만드는 경우의 수를 $D(i,a,b)$라고 정의합시다. $D(i,a,b) = D(i-1,a-i,b)+D(i-1,a,b-i)$이므로 각 상태를 상수 시간에 계산할 수 있습니다.

각 상태는 이미 상수 시간에 계산할 수 있기 때문에 더 최적화를 하는 것은 불가능하고, 상태의 개수를 줄이는 방법을 생각해보는 것이 좋습니다. 조금만 고민해보면 90점을 받을 수 있는 $O(TA\sqrt{A+B})$풀이도 쉽게 찾을 수 있습니다.
빨간색 물감을 $a$만큼 사용하고, 파란색 물감은 $B$ 이하로 사용해서 크기가 $1,2,\cdots,i$인 동심원을 만드는 경우의 수를 $D(i,a)$라고 정의합시다. $i$번째 동심원까지 빨간색 물감을 $a$만큼 사용했다면 파란색 물감은 $i*(i+1)/2-a$만큼 사용한 것이기 때문에 $a$만 저장해도 충분합니다. 조건문이 몇 개 붙긴 하지만, 41점 풀이와 비슷하게 점화식을 계산할 수 있습니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
// 1번 90점 풀이
inline void Add(int &a, int b){ a += b; if(a >= MOD) a -= MOD; }
int Score90(int A, int B){
    int ret = 0;
    D[0][0] = 1;
    for(int i=1; i<450; i++){
        int k = i * (i + 1) / 2;
        for(int j=0; j<=A; j++){
            D[i][j] = 0;
            if(j - i >= 0) Add(D[i][j], D[i-1][j-i]);
            if(k - j <= B) Add(D[i][j], D[i-1][j]);
            Add(ret, D[i][j]);
        }
    }
    return ret;
}

이 풀이는 테스트 케이스마다 $B$가 계속 달라질 수 있기 때문에, 기존에 구한 DP 테이블을 다른 테스트 케이스에서 재활용하지 못합니다. 만약 DP 테이블을 재활용할 수 있는 방법을 찾는다면 시간 복잡도에서 $T$를 없앨 수 있습니다.

$B$가 매번 다른 것이 문제이므로 $B$ 조건을 없앤 상태로 점화식을 세워봅시다. 빨간색 물감을 $a$만큼 사용하고, 파란색 물감은 자유롭게 사용해서 크기가 $1, 2, \cdots , i$인 동심원을 만드는 경우의 수를 $D(i,a)$라고 정의합시다. $D(i,a) = D(i-1,a-i) + D(i-1,a)$이고, $O(A \sqrt {A+B})$에 계산할 수 있습니다.
위 점화식을 계산한 DP 테이블을 갖고 있다면, 문제의 정답을 계산하는 것은 간단합니다. 아래 코드는 매번 $O(A \sqrt {A+B})$ 시간이 걸리기 때문에 여전히 90점을 받는 코드입니다.

1
2
3
4
5
6
7
8
int Query(int A, int B){
    int res = 0;
    for(int i=1; i<SZ; i++){
        int req = i * (i + 1) / 2;
        for(int j=0; j<=A; j++) if(req - j <= B) Add(res, D[i][j]);
    }
    return res;
}

위 코드에서 req - j <= B를 약간 변형하면 req - B <= j가 됩니다. 그러므로 위 코드는 사실 $i$마다 $D(i,\text{req}-B)$부터 $D(i,A)$까지의 합, 즉 $D(i, \ast)$의 부분합을 구하는 코드입니다. 그러므로 DP 테이블의 각 행마다 Prefix Sum을 계산하면 Query 함수를 $O(\sqrt {A+B})$에 작성할 수 있습니다.

전체 문제를 $O((T+AB) \sqrt {A+B})$에 해결할 수 있습니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#include <bits/stdc++.h>
using namespace std;

constexpr int SZ = 450;
constexpr int MOD = 1e9+7;
inline void Add(int &a, int b){ a += b; if(a >= MOD) a -= MOD; }

int D[SZ][50001];
int Sum(int i, int s, int e){
    if(s > e) return 0;
    int res = D[i][e] - (s > 0 ? D[i][s-1] : 0);
    return res >= 0 ? res : res + MOD;
}

void Init(){
    D[0][0] = 1;
    for(int i=1; i<SZ; i++){
        for(int j=0; j<=50000; j++){
            D[i][j] = D[i-1][j];
            if(j - i >= 0) Add(D[i][j], D[i-1][j-i]);
        }
    }
    for(int i=1; i<SZ; i++) for(int j=1; j<=50000; j++) Add(D[i][j], D[i][j-1]);
}

int Query(int A, int B){
    int res = 0;
    for(int i=1; i<SZ; i++){
        int req = i * (i + 1) / 2;
        Add(res, Sum(i, req - B, A));
    }
    return res;
}

int main(){
    ios_base::sync_with_stdio(false); cin.tie(nullptr);
    int T; cin >> T; Init();
    while(T--){
        int A, B; cin >> A >> B;
        cout << Query(A, B) << "\n";
    }
}

2번

정점 $u$에 배정된 가중치를 $x$라고 합시다. $u$와 가중치가 $e$인 간선으로 연결된 정점 $v$에는 $e-x$가 배정되고, $v$와 가중치가 $f$인 간선으로 연결된 정점 $w$에는 $f-(e-x) = -x+(f-e)$가 배정될 것입니다. 이렇게 가중치를 배정하는 과정을 생각해보면 아래와 같은 관찰을 할 수 있습니다.

한 정점의 가중치를 $x$라고 하면, 다른 모든 정점의 가중치를 $x+b$ 혹은 $-x+b$ 꼴로 표현할 수 있다.

위 관찰을 통해, $ax+b$ 꼴의 일차 함수가 $N$개 주어졌을 때 $\vert ax+b \vert$의 합을 최소로 하는 $x$를 찾는 문제고 바꿀 수 있습니다. (단, $\vert a \vert$ = 1) 먼저 일차 함수 $N$개를 찾는 방법을 알아봅시다.

임의의 정점 $v$를 잡은 뒤, $v$의 함수를 $1x+0(=x)$로 설정합시다. $v$부터 DFS를 하면서, 만약 $ax+b$를 함수로 갖는 $u$와 $w$가 가중치가 $e$인 간선으로 연결되어 있다면 $w$의 함수를 $-ax + (e-b)$로 설정합니다. 만약 그래프에 사이클이 없다면 이 과정이 순조롭게 진행되겠지만, 사이클이 있는 경우에는 몇 가지 문제가 발생하게 됩니다.

기존에 방문했던 정점 $c$로 돌아온 상황을 가정해봅시다. 또한, 기존에 구한 $c$의 함수를 $ax+b$, 이번에 새로 구한 함수를 $a’x+b’$이라고 합시다. 3가지 경우로 나눌 수 있습니다.

  1. $(a, b) = (a’, b’)$ : 동일한 정보가 들어온 경우입니다. 추가적인 처리 없이 넘어가면 됩니다.
  2. $a = a’ \land b \neq b’$ : $ax+b = a’x+b’$이 성립하는 $x$가 존재하지 않습니다. 모순이 생긴 것이므로 No를 출력하고 프로그램을 종료합니다.
  3. 그 외의 경우 : $ax+b = a’x+b’$이 성립하는 $x$가 정확히 하나 존재합니다. $x = (b’-b)/(a-a’)$이며, 어차피 $\vert a-a’ \vert = 2$이므로 2를 곱해서 저장해도 상관 없습니다.

위 3가지 경우를 처리하며 DFS를 모두 수행하고 나면 아래 3가지 상황 중 한 가지에 놓이게 됩니다.

  1. 3번이 여러 번 발생했고, 이때 $x$를 2가지 이상 발견한 경우 : 모순이 생긴 것이므로 No를 출력하고 프로그램을 종료합니다.
  2. 3번이 한 번 이상 발생해서 $x$를 정확히 하나 찾은 경우 : $x$를 찾았으므로 문제의 정답을 구할 수 있습니다.
  3. 3번이 한 번도 발생하지 않은 경우 : $\sum \vert ax+b \vert$가 최소가 되는 $x$(극점)를 직접 찾아야 합니다. 일차 함수의 절댓값은 unimodal하고, 이들의 합도 unimodal하기 때문에 삼분 탐색을 이용해 극점을 찾을 수 있습니다. 또는 $x$가 $-b/a$들의 중앙값이면 $\sum \vert ax+b \vert$가 최소가 된다는 것을 이용해 $x$를 구할 수도 있습니다.

DFS는 $O(N+M)$에 할 수 있고, $x$를 찾는 것은 $O(N \log N)$ 정도에 가능하므로 제한 시간 내에 문제를 해결할 수 있습니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
#include <bits/stdc++.h>
#define x first
#define y second
using namespace std;

using ll = long long;
using ld = long double;
using PLL = pair<ll, ll>;
inline void DIE(){ cout << "No"; exit(0); }

int N, M;
vector<PLL> G[101010];
PLL A[101010];
bool C[101010];
vector<ll> Fix, Can, Vis;
ll R[101010];

void DFS(int v, int mul, ll add){
    if(!C[v]) A[v] = {mul, add};
    else if(A[v].x == mul && A[v].y == add) return;
    else if(A[v].x == mul && A[v].y != add) DIE();
    else{
        Fix.push_back(2 * (A[v].y - add) / (mul - A[v].x));
        return;
    }
    Vis.push_back(v); C[v] = true;
    if(mul == 1) Can.push_back(-add);
    else Can.push_back(add);
    for(auto [i,w] : G[v]) DFS(i, -mul, w-add);
}

int main(){
    ios_base::sync_with_stdio(false); cin.tie(nullptr);
    cin >> N >> M;
    for(int i=1; i<=M; i++){
        int s, e, x; cin >> s >> e >> x;
        G[s].emplace_back(e, x);
        G[e].emplace_back(s, x);
    }
    for(int i=1; i<=N; i++){
        if(C[i]) continue;
        Fix.clear(); Can.clear(); Vis.clear();
        DFS(i, 1, 0);
        sort(Can.begin(), Can.end());
        sort(Fix.begin(), Fix.end());
        Fix.erase(unique(Fix.begin(), Fix.end()), Fix.end());
        if(Fix.size() > 1) DIE();
        if(Fix.size() == 1 && Fix[0] % 2 != 0) DIE();
        double X = Fix.size() == 1 ? Fix[0]/2 : Can[Can.size()/2];
        for(auto j : Vis) R[j] = A[j].x * X + A[j].y;
    }
    cout << "Yes\n";
    for(int i=1; i<=N; i++) cout << R[i] << " ";
}

3번

문제를 10초만 쳐다보면 풀이를 알 수 있습니다. 잘 모르겠다면 Suffix Array와 Longest Common Prefix를 구하는 방법을 공부한 뒤, BOJ 9249 최장 공통 부분 문자열을 풀어보면 됩니다. SA와 LCP를 이용해 최장 공통 부분 문자열을 구하는 방법을 응용하면 최장 공통 괄호 부분 문자열도 구할 수 있습니다.

SA 상에서 인접한 접미사(Suffix) 간의 최장 공통 접두사(LCP)를 모두 확인합시다. 각 LCP마다 가장 긴 괄호 부분 문자열을 구하면, 이들 중 최댓값이 문제의 정답이 됩니다.

문자열이 주어졌을 때 임의의 부분 문자열에서 최장 괄호 부분 문자열을 구하는 쿼리는 $O(\log N)$에 할 수 있으므로 SA와 LCP를 구해놓았다면 문제를 $O(N \log N)$에 해결할 수 있습니다.

SA와 LCP를 구하는 것은 $O(N), O(N \log N), O(N log^2 N)$ 등 여러가지 방법이 있는데, 실제 대회에 참가하지 않아서 어디까지 허용되는지 모르겠습니다. 아래 코드는 $O(N \log^2 N)$으로 구현된 코드입니다.

해싱을 이용한 똑똑한 풀이가 있다고 하던데 저는 잘 모르겠습니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#include <bits/stdc++.h>
using namespace std;

vector<int> getSA(const string &s){
    int n = s.size(), len = 1;
    vector<int> sa(n), pos(n), tmp(n);
    for(int i=0; i<n; i++) sa[i] = i, pos[i] = s[i];

    auto compare = [&](int a, int b){
        if(pos[a] != pos[b]) return pos[a] < pos[b];
        if(a+len < n && b+len < n) return pos[a+len] < pos[b+len];
        return a > b;
    };

    for(len=1; ; len<<=1){
        sort(sa.begin(), sa.end(), compare);
        for(int i=0; i+1<n; i++) tmp[i+1] = tmp[i] + compare(sa[i], sa[i+1]);
        for(int i=0; i<n; i++) pos[sa[i]] = tmp[i];
        if(tmp.back() == n-1) break;
    }
    return move(sa);
}

vector<int> getLCP(const string &s, const vector<int> &sa){
    int n =s.size();
    vector<int> pos(n), lcp(n-1);
    for(int i=0; i<n; i++) pos[sa[i]] = i;
    for(int i=0, k=0; i<n; i++, k=max(k-1, 0)){
        if(pos[i] == n-1) continue;
        for(int j=sa[pos[i]+1]; s[i+k]==s[j+k]; k++);
        lcp[pos[i]] = k;
    }
    return move(lcp);
}

struct Node{
    int res, opn, cls;
    void init(char c){ res = 0; opn = c == '('; cls = ')'; }
} T[1 << 21];

Node Merge(const Node &l, const Node &r){
    int add = min(l.opn, r.cls);
    return { l.res + r.res + add*2, l.opn + r.opn - add, l.cls + r.cls - add };
}

void Build(int node, int s, int e, const string &str){
    if(s == e){ T[node].init(str[s]); return; }
    int m = s + e >> 1;
    Build(node << 1, s, m, str);
    Build(node << 1 | 1, m+1, e, str);
    T[node] = Merge(T[node << 1], T[node << 1 | 1]);
}

Node Query(int node, int s, int e, int l, int r){
    if(r < s || e < l) return { 0, 0, 0 };
    if(l <= s && e <= r) return T[node];
    int m = s + e >> 1;
    return Merge(Query(node << 1, s, m, l, r), Query(node << 1 | 1, m+1, e, l, r));
}

void Solve(){
    string a, b; cin >> a >> b;
    string s = a + "$" + b;
    int N = s.size(), A = a.size(), B = b.size();

    auto sa = getSA(s), lcp = getLCP(s, sa);
    vector<int> pos(s.size());
    for(int i=0; i<N; i++) pos[sa[i]] = i;
    Build(1, 0, N-1, s);

    int res = 0;
    for(int i=0; i+1<N; i++){
        if(sa[i] == A || sa[i+1] == A) continue;
        if((sa[i] < A) == (sa[i+1] < A)) continue;
        int st = min(sa[i], sa[i+1]), len = lcp[i];
        int now = Query(1, 0, N-1, st, st+len-1).res;
        res = max(res, now);
    }
    cout << res << "\n";
}

int main(){
    ios_base::sync_with_stdio(false); cin.tie(nullptr);
    int TC; cin >> TC;
    while(TC--) Solve();
}