백준17823 수열과 쿼리 33

문제 링크

  • http://icpc.me/17823

사용 알고리즘

  • Aliens Trick
  • Parallel Binary Search

시간복잡도

  • $O(N \log N \log X)$

풀이

구간의 양끝 원소의 사용 여부를 고정합시다. 각 구간의 양끝 원소의 사용 여부는 0부터 3까지의 정수로 인코딩할 수 있습니다. 또한, 구간의 양끝 사용 여부를 고정한 뒤 그래프를 잘 모델링하면 MCMF로 각 쿼리를 해결할 수 있습니다.

“구간 $[l, r]$의 양끝 사용 여부가 $bit$일 때 $k$개의 부분 수열의 합의 최댓값”을 $f(l, r, bit, k)$로 정의하면, MCMF로 문제를 해결할 수 있기 때문에 $f(l,r,bit,k)-f(l,r,bit,k-1) \geq f(l,r,bit,k+1)-f(l,r,bit,k)$가 성립해서 Aliens Trick을 적용할 수 있다는 것을 알 수 있습니다.

쿼리가 여러 개 주어지기 때문에 Parallel Binary Search를 생각해봅시다. PBS에서 우리가 수행해야 하는 연산은, “부분 수열 하나를 사용할 때 추가되는 비용이 $C$일 때 최적값”을 구하는 연산입니다.

각 정점에서 구간 $[l, r]$의 $f(l, r, bit, 1), f(l, r, bit, 2), \cdots , f(l, r, bit, r-l+1)$을 관리하는 세그먼트 트리를 만듭시다. 만약 이러한 세그먼트 트리를 전처리했다면, PBS의 각 결정 문제는 쿼리로 주어진 구간을 나타내는 $O(\log N)$개의 정점에 저장된 함수들을 보며 해결할 수 있습니다.
이 세그먼트 트리는 $O(N \log N)$에 만들 수 있습니다. 두 구간 $[l, m], [m+1, r]$을 합치는 것이 병목이 될텐데, 이 부분은 두 볼록 함수의 민코프스키 합($R_{i+j} = \max(A_i + B_j)$)를 계산하는 것이기 때문에 선형 시간에 합칠 수 있습니다. 그러므로 세그먼트 트리를 전처리하는 시간 복잡도는 $T(N) = 2T(N/2) + O(N) = O(N \log N)$이 됩니다. 다만, $m$과 $m+1$을 모두 사용하는 경우에는 두 구간을 concat해주면 사용하는 부분 수열의 개수가 한 개 감소하기 때문에 배열을 왼쪽으로 한 칸 shift 해야합니다.

이제 PBS를 합시다. 구간 개수 $k$와 추가 비용 $C$에 대해, $kC + f(k)$ 꼴의 최댓값을 계산하는 작업을 수행해야 하고, 이는 Convex Hull Trick과 매우 유사합니다. 심지어 세그먼트 트리의 각 정점에는 이미 볼록 함수가 저장되어 있기 때문에, 그냥 쿼리에 달려있는 $C$를 다 모아서 정렬한 뒤 선형 CHT 느낌으로 계산하면 됩니다.

PBS의 각 iteration마다 $O(N \log N)$ 시간이 걸리므로 전체 시간 복잡도는 $O(N \log N \log X)$가 됩니다. 물론, 다양한 bit의 조합을 계산해야 하기 때문에 세그먼트 트리를 전처리하는 과정과 PBS의 iteration에서 상수가 8 정도 더 붙습니다.

전체 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
#include <bits/stdc++.h>
#define x first
#define y second
#define all(v) v.begin(), v.end()
#define compress(v) sort(all(v)), v.erase(unique(all(v)), v.end())
#define IDX(v, x) (lower_bound(all(v), x) - v.begin())
using namespace std;

using uint = unsigned;
using ll = long long;
using ull = unsigned long long;
using PLL = pair<ll, ll>;
constexpr int SZ = 1 << 15;
constexpr int INF32 = 0xc03f3f3f;
constexpr ll INF = 0xc03f3f3f3f3f3f3f;

PLL operator + (const PLL &p1, const PLL &p2){ return {p1.x + p2.x, p1.y + p2.y}; }
PLL operator - (const PLL &p1, const PLL &p2){ return {p1.x - p2.x, p1.y - p2.y}; }

struct Node{
    int sz;
    vector<ll> V[2][2];
    Node() : Node(0) {}
    Node(int sz) : sz(sz) {
        for(int i=0; i<2; i++) for(int j=0; j<2; j++) V[i][j] = vector<ll>(sz+1, -INF);
    }
    int size() const { return sz; }
    void reset(int _sz){
        sz = _sz;
        for(int i=0; i<2; i++) for(int j=0; j<2; j++) V[i][j] = vector<ll>(sz+1, -INF);
    }
    void setValue(ll v){
        reset(1);
        for(int i=0; i<2; i++){
            for(int j=0; j<2; j++){
                V[i][j][0] = (i || j) ? -INF : 0;
                V[i][j][1] = v;
            }
        }
    }
};

Node operator + (const Node &A, const Node &B){
    static auto dx = [](const vector<ll> &vec){
        vector<ll> ret; ret.reserve(vec.size()-1);
        for(int i=1; i<vec.size(); i++) ret.push_back(vec[i] - vec[i-1]);
        return move(ret);
    };

    Node R(A.size() + B.size());

    for(int i=0; i<2; i++){
        for(int j=0; j<2; j++){
            for(int k=0; k<2; k++){
                auto da = dx(A.V[i][j]), db = dx(B.V[j][k]);
                vector<ll> res(da.size() + db.size() + 1);
                res[0] = A.V[i][j][0] + B.V[j][k][0];
                merge(all(da), all(db), res.begin()+1, greater<>());
                partial_sum(all(res), res.begin());
                if(j) res.erase(res.begin());
                for(int s=0; s<res.size(); s++) R.V[i][k][s] = max(R.V[i][k][s], res[s]);
            }
        }
    }
    return R;
}

constexpr PLL BASE = PLL(-INF, 0);
struct Info{
    PLL V[2][2];
    Info() : Info(BASE, BASE, BASE, BASE) {}
    Info(PLL a, PLL b, PLL c, PLL d){
        V[0][0] = a; V[0][1] = b;
        V[1][0] = c; V[1][1] = d;
    }
    PLL max() const { return std::max({V[0][0], V[0][1], V[1][0], V[1][1]}); }
};

Info Merge(const Info &A, const Info &B, const ll lambda){
    Info R;
    for(int i=0; i<2; i++){
        for(int j=0; j<2; j++){
            for(int k=0; k<2; k++){
                PLL now = A.V[i][j] + B.V[j][k];
                if(j) now = now - PLL(lambda, 1);
                R.V[i][k] = max(R.V[i][k], now);
            }
        }
    }
    return R;
}

int N, Q, A[SZ];
ll S[SZ], E[SZ], K[SZ], L[SZ], R[SZ], M[SZ];
Info X[SZ];

Node T[SZ << 1];
int pv[SZ << 1][2][2]; // for CHT
vector<int> nds[SZ];

void Init(int node, int s, int e){
    if(s == e){
        T[node].setValue(A[s]);
        return;
    }
    int m = s + e >> 1;
    Init(node << 1, s, m);
    Init(node << 1 | 1, m+1, e);
    T[node] = T[node << 1] + T[node << 1 | 1];
}
void Insert(int node, int s, int e, int l, int r, int v){
    if(r < s || e < l) return;
    if(l <= s && e <= r){
        nds[v].push_back(node);
        return;
    }
    int m = s + e >> 1;
    Insert(node << 1, s, m, l, r, v);
    Insert(node << 1 | 1, m+1, e, l, r, v);
}

PLL CHT(int node, int i, int j, int q){
    auto &hull = T[node].V[i][j];
    auto &cnt = pv[node][i][j];
    while(cnt+1 < hull.size() && hull[cnt] <= hull[cnt+1] + M[q]) cnt++;
    ll dp = hull[cnt] + M[q] * cnt;
    return PLL(dp, cnt);
}

int main(){
    ios_base::sync_with_stdio(false); cin.tie(nullptr);
    cin >> N >> Q;
    for(int i=1; i<=N; i++) cin >> A[i];
    Init(1, 1, N);
    for(int i=1; i<=Q; i++) cin >> S[i] >> E[i] >> K[i], Insert(1, 1, N, S[i], E[i], i);
    for(int i=1; i<=Q; i++) L[i] = -INF32, R[i] = INF32;

    vector<int> O(Q);
    iota(all(O), 1);

    while(true){
        bool fin = true;
        for(int i=1; i<=Q; i++){
            if(L[i] != R[i]) fin =  true;
            if(L[i] + R[i] & 1) M[i] = (L[i] + R[i] - 1) / 2;
            else M[i] = (L[i] + R[i]) / 2;
        }
        memset(pv, 0, sizeof pv);
        sort(all(O), [&](int a, int b){ return M[a] < M[b]; });
        for(auto q : O){
            bool flag = true;
            for(auto nd : nds[q]){
                Info now;
                for(int i=0; i<2; i++){
                    for(int j=0; j<2; j++){
                        now.V[i][j] = CHT(nd, i, j, q);
                    }
                }
                if(flag) X[q] = now;
                else X[q] = Merge(X[q], now, M[q]);
                flag = false;
            }
            if(X[q].max().y >= K[q]) R[q] = M[q];
            else L[q] = M[q] + 1;
        }
        if(fin) break;
    }
    for(int i=1; i<=Q; i++){
        cout << X[i].max().x - K[i] * L[i] << "\n";
    }
}