백준17677 케이크 3

문제 링크

  • http://icpc.me/17677

문제 출처

  • 2018/2019 JOISC Day4 1번

사용 알고리즘

  • PST
  • Divide and Conquer Optimization

시간복잡도

  • $O(N \log^2 N)$

풀이

일단 어떻게 해서 $M$개를 잘 선택했다고 합시다. $V_i$의 합은 상수이므로, $M$개를 잘 배치해서 인접한 $C_i$들의 차이의 합을 최소화해야 합니다.
조금만 생각해보면 $C_i$를 오른차순 혹은 내림차순으로 정렬하는 것이 최적임을 할 수 있고, 이때 인접한 원소의 차이의 합은 $2(\max C_i - \min C_i)$가 됩니다.

$M$개를 잘 선택해서 (V의 합 - C의 최대 + C의 최소)를 최대화하는 문제로 바뀌게 됩니다.
서브태스크를 따라가면서 문제를 풀어봅시다.

Subtask 1) N ≤ 100, 5점

선택할 $C_i$의 최댓값과 최솟값을 고정시키고, 선택 가능한 원소 중 $V_i$값이 가장 큰 $M$개를 택하면 됩니다. $O(N^3)$에 구현할 수 있습니다.

Subtask 2) N ≤ 2 000, 24점

$(C_i, V_i)$ 순으로 정렬합시다. $i < j$인 i, j를 선택하는 것은 subtask 1에서 $C_i$의 최대/최소를 선택하는 것과 똑같은 행위입니다.
$i$를 고정시키고, $j$를 하나씩 증가시키면서 priority_queue나 multiset으로 가장 큰 $M$개의 $V_i$값을 관리해주면 $O(N^2 \log N)$에 문제를 풀 수 있습니다. (자료구조에 원소를 최대 $O(N^2)$번 삽입/삭제)

Subtask 3) N ≤ 200 000, 100점

$O(N^2)$개의 구간을 모두 보는 방법으로는 절대 문제를 풀 수 없습니다. 고려해야하는 구간의 개수를 줄여야 합니다.
왼쪽 끝점 $i$를 고정했을 때 최적이 되는 가장 작은 $j$를 opt[i]라고 합시다. opt[i] ≤ opt[i+1]이 성립하는 것은 직관적으로 알 수 있습니다. 최적 위치가 단조증가하므로 Divide and Conquer Optimization을 적용할 수 있습니다.

원소들이 $C_i$ 기준으로 정렬되어 있으므로, 아래와 같은 연산만 빠르게 처리할 수 있으면 정답을 구할 수 있습니다.

  • $i \in [s, e]$인 $i$에 대해 $V_i$들 중에서 가장 큰 $M$의 합 구하기

이 연산은 Persistent Segment Tree를 이용해 $O(\log N)$만에 할 수 있습니다.

$T(N) = 2T(N/2) + O(N \log N)$이므로 $O(N \log^2 N)$만에 문제를 해결할 수 있습니다.

전체 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#include <bits/stdc++.h>
#define x first
#define y second
#define c first
#define v second
#define all(v) v.begin(), v.end()
#define compress(v) sort(all(v)), v.erase(unique(all(v)), v.end())
#define IDX(v, x) lower_bound(all(v), x) - v.begin()
using namespace std;

typedef long long ll;
typedef pair<ll, ll> p;

struct Node{ int l, r, sz; ll v; } nd[5050505];
int rt[202020], pv;
vector<int> comp;

void update(int now, int prv, int s, int e, int x){
    nd[now].sz = nd[prv].sz + 1;
    nd[now].v = nd[prv].v + comp[x];
    if(s == e) return;
    int m = s + e >> 1;
    if(x <= m){
        if(!nd[now].l) nd[now].l = ++pv, nd[now].r = nd[prv].r;
        update(nd[now].l, nd[prv].l, s, m, x);
    }
    else{
        if(!nd[now].r) nd[now].r = ++pv, nd[now].l = nd[prv].l;
        update(nd[now].r, nd[prv].r, m+1, e, x);
    }
}
ll query(int now, int prv, int s, int e, ll k){
    if(s == e) return comp[s] * k;
    int m = s + e >> 1;
    int ri = nd[nd[now].r].sz - nd[nd[prv].r].sz;
    if(k <= ri) return query(nd[now].r, nd[prv].r, m+1, e, k);
    else return query(nd[now].l, nd[prv].l, s, m, k-ri) + nd[nd[now].r].v - nd[nd[prv].r].v;
}

int n, m, opt[202020]; p a[202020];
ll mx = -1e18;

void f(int s, int e, int l, int r){
    if(s > e) return;
    int mid = s + e >> 1; ll now = -1e18; opt[mid] = r;
    for(int i=max(l, mid+m-1); i<=r; i++){
        ll t = query(rt[i], rt[mid-1], 0, comp.size()-1, m);
        t -= a[i].c - a[mid].c;
        if(now < t) now = t, opt[mid] = i;
    }
    mx = max(mx, now);
    f(s, mid-1, l, opt[mid]);
    f(mid+1, e, opt[mid], r);
}

int main(){
    ios_base::sync_with_stdio(false); cin.tie(nullptr);
    cin >> n >> m;
    for(int i=1; i<=n; i++){ cin >> a[i].v >> a[i].c; a[i].c *= 2; comp.push_back(a[i].v); }
    sort(a+1, a+n+1); compress(comp);
    for(int i=0; i<=n; i++) rt[i] = ++pv;
    for(int i=1; i<=n; i++) update(rt[i], rt[i-1], 0, comp.size()-1, IDX(comp, a[i].v));
    f(1, n, 1, n);
    cout << mx;
}