백준14240 부분 수열의 점수

문제 링크

  • http://icpc.me/14240

사용 알고리즘

  • DP
  • CHT

시간복잡도

  • O(N log N)

풀이

A_n = sum(S_1, S_2, S_3, … , S_n)으로 정의합시다.
비슷한 방식으로 B_n = sum(A_1, A_2, A_3, … , A_n)으로 정의합시다.

[i, j] 구간의 점수 score(i, j)는 ( A_j - A_(j-1) ) + ( A_j - A_(j-2) ) + ( A_j - A_(j-3) ) + … + ( A_j - A_(i-1) )입니다.
score(i, j) = A_j * (j-1+i) - B_(j-1) + B_(i-2)로 나타낼 수 있습니다.

dp[n]을 n으로 끝나는 부분 수열의 최댓값이라고 하면, dp[n] = max(score(i, n))입니다. (단, 1 <= i < n)
dp[n] = max( A_n * (n-i+1) - B_(n-1) + B_(i-2) )가 되고,
dp[n] = max( A_n * (-i) + B_(i-2) ) + A_n * (n+1) - B_(n-1)이 됩니다.

max() 안에 있는 식을 기울기가 -i, y절편이 B_(i-2)인 일차 함수로 나타내면 CHT를 적용할 수 있게 됩니다.
CHT를 이용해 O(N log N)만에 답을 구할 수 있습니다.

전체 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#include <bits/stdc++.h>
#define x first
#define y second
#define int long long
using namespace std;

typedef long long ll;
typedef pair<ll, ll> p;
const ll inf = 3e17;

struct CHT{
	bool isInc;
	CHT(){}
	CHT(bool _isInc){
		isInc = _isInc;
	}
	deque<p> line;
	double inter(int i, int j){
		return 1.00 * (line[i].y - line[j].y) / (line[j].x - line[i].x);
	}
	ll calc(ll i, ll x){
		return line[i].x * x + line[i].y;
	}
	void insert(ll a, ll b){
		line.push_back({a, b});
		int i = line.size() - 1;
		while(i > 1 && inter(i, i-1) < inter(i-1, i-2)){
			line[i-1] = line.back();
			line.pop_back();
			i--;
		}
	}
	int bin(ll k){
		int l = 0;
		int r = line.size() - 1;
		while(l < r){
			int m = l + r >> 1;
			if (k < inter(m, m+1)) r = m;
			else l = m + 1;
		}
		return r;
	}
	ll get(ll k){
		if(isInc){
			if(line.empty()) return inf; //assert
			while(line.size() > 1 && calc(0, k) > calc(1, k)){
				line.pop_front();
			}
			return calc(0, k);
		}
		if(line.empty()) return inf; //assert
		if(line.size() == 1) return calc(0, k);
		return calc(bin(k), k);
	}
} cht;

ll s[202020];
ll a[202020];
ll b[202020];
int n;

signed main(){
	ios_base::sync_with_stdio(0); cin.tie(0);
	cin >> n;
	for(int i=1; i<=n; i++){
		cin >> s[i]; s[i] = -s[i];
		a[i] = a[i-1] + s[i];
		b[i] = b[i-1] + a[i];
	}

	ll ans = min(0ll, s[1]);

	cht = CHT(0);
	cht.insert(-1, 0);

	for(int i=2; i<=n; i++){
		ll now = cht.get(a[i]) + a[i] * (i+1) - b[i-1];
		if (ans > s[i]) ans = s[i];
		if (ans > now) ans = now;
		cht.insert(-i, b[i-2]);
	}
	cout << -ans;
}