백준5466 Salesman

문제 링크

  • http://icpc.me/5466

문제 출처

  • 2009 IOI Day2 4번

사용 알고리즘

  • Segment Tree
  • DP

시간복잡도

  • O(n log n)

풀이

2009년 국제 정보 올림피아드 문제입니다.

풀이에 앞서 몇 가지 변수를 정의하도록 하겠습니다.

  • l[x] : x번째 시장의 위치
  • m[x] : x번째 시장에서 얻을 수 있는 이득

몇 가지 단계로 나누어 풀이를 해봅시다.

subtask 1. n≤5,000 / 같은 시간에 최대 1개의 시장 개최 (15점)

문제 해석 시간을 포함해서 이 서브테스크를 푸는데는 30분이 걸렸습니다.
dp[i] = i번째 시장까지 갈 때 최댓값 으로 정의하고, 시장의 순서는 시간 순으로 정렬해줍시다.
강을 거슬러 올라갈 때는 미터당 U달러, 반대의 경우에는 D달러가 소요됩니다.

현재 방문해야할 시장의 번호를 i, 이전에 방문했던 시장의 번호를 j라고 정의합시다. l[i] > l[j]라면 이동에 소요되는 비용은 d * (l[i] - l[j])입니다. 반대의 경우는 u * l[j] - l[i]입니다. 이를 통해 아래와 같은 점화식을 세울 수 있고, O(n2)만에 점화식을 모두 채울 수 있게 됩니다.

dp[i] = max{ dp[j] - |l[i]-l[j]| * (l[i]<=l[j]?u:d) } + m[i] (0 <= j < i)
0번째 시장의 위치를 s, 이득을 0으로 해줘서 시작점을 추가합시다.
마찬가지로 n+1번째 시장의 위치를 s, 이득을 0으로 해줘서 도착점을 추가합시다.

dp[n+1]을 구하면 답이 나오게 됩니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
#include <bits/stdc++.h>
using namespace std;

struct asdf{
	int t, l, m; //time, location, money
	bool operator < (asdf &rhs){
		return t < rhs.t;
	}
};

int dp[505050];

int main(){
	ios_base::sync_with_stdio(0); cin.tie(0);
	int n, u, s, d; cin >> n >> u >> d >> s;
	vector<asdf> v(n);
	for(int i=0; i<n; i++){
		cin >> v[i].t >> v[i].l >> v[i].m;
	}
	v.push_back({0, s, 0}); //시점
	v.push_back({500010, s, 0}); //종점

	sort(v.begin(), v.end());
	dp[0] = 0;
	for(int i=1; i<=n+1; i++){
		dp[i] = -1e9;
		for(int j=0; j<i; j++){
			int dist = abs(v[i].l - v[j].l);
			int baseCost = v[i].l <= v[j].l ? u : d;
			dp[i] = max(dp[i], dp[j] - dist*baseCost + v[i].m);
		}
	}

	cout << dp[n+1];
}
subtask 2. n≤300,000 / 같은 시간에 최대 1개의 시장 개최 (60점)

이 서브테스크를 푸는데 80분이 걸렸습니다.
O(n2)짜리 점화식을 O(n log n)까지 줄여야 합니다. Divide and Conquer Optimization(설명)을 생각해봤지만, 조건을 충족하지 못합니다.
일단 식을 분해해보았습니다.

  1. (l[i] > l[j])일 때
    dp[i] = max{ dp[j] - ( l[i] - l[j] ) * D } + m[i]
    dp[i] = max{ dp[j] + D*l[j] } - D*l[i] + m[i]
  2. (l[i] < l[j])일 때
    dp[i] = max{ dp[j] - ( l[j] - l[i] ) * V } + m[i]
    dp[i] = max{ dp[j] - U*l[j] } + U*l[i] + m[i]

i가 고정이 되어있다면, max 밖에 있는 것들은 상수취급할 수 있습니다. 그러면 max안에 있는, 즉 dp[j] + D*l[j]dp[j] - U*l[j]들의 최댓값만 O(log n)에 구할 수 있게 된다면 이 문제는 O(n log n)에 풀리게 될 것 입니다.
특정 구간 내에서 최댓값을 O(log n)만에 구할 수 있는 것은 Segment Tree 가 있습니다.
세그먼트 트리를 두 개 만들어서 하나는 dp[j] + D*l[j]를 저장하고, 다른 한 곳에는 dp[j] - U*l[j]를 저장하면 2번 섭테도 해결할 수 있습니다.

세그먼트 트리의 이름을 s1과 s2로 가정하면 아래와 같은 방법으로 구현할 수 있습니다.

  1. dp[i] = max{ s1.query(0,l[j])-d*l[i], s2.query(l[j],500001)+u*l[i] } + m[i]
  2. s1.update(l[i], dp[i]+d*l[i]); s2.update(l[i], dp[i]-u*l[i])

코드는 아래와 같습니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#include <bits/stdc++.h>
using namespace std;

struct asdf{
	int t, l, m; //time, location, money
	bool operator < (asdf &rhs){
		return t < rhs.t;
	}
};

const int MAX_N = 1<<19;
struct SegTree {
	int seg[MAX_N*2-1];
	SegTree() {
		for(int i=0; i<MAX_N*2-1; i++) seg[i] = -1e9;
	}

	int query(int a, int b, int k=0, int l=0, int r=MAX_N){
    	if (b <= l || r <=a)return -1e9;
    	if (a<=l&&r<=b) return seg[k];
		return max(query(a,b,k*2+1,l,(l+r)/2), query(a,b,k*2+2,(l+r)/2,r));
	}

	void update(int k, int v) {
    	k += MAX_N-1;
    	seg[k] = v;
		while (k > 0) {
			k=(k-1)/2;
			seg[k]=max(seg[k*2+1],seg[k*2+2]);
		}
	}
};

SegTree s1, s2;
vector<asdf> v;
int dp[505050];
int n, u, s, d;

void update(int i){
	s1.update(v[i].l, dp[i] + d * v[i].l);
	s2.update(v[i].l, dp[i] - u * v[i].l);
}

int get(int i){
	int t1 = s1.query(0, v[i].l) - d*v[i].l + v[i].m;
	int t2 = s2.query(v[i].l, 500000) + u*v[i].l + v[i].m;
	return max(t1, t2);
}

int main(){
	ios_base::sync_with_stdio(0); cin.tie(0);
	cin >> n >> u >> d >> s;
	v.resize(n);

	for(int i=0; i<n; i++){
		cin >> v[i].t >> v[i].l >> v[i].m;
	}
	v.push_back({0, s, 0}); //시점
	v.push_back({500010, s, 0}); //종점
	sort(v.begin(), v.end());

	dp[0] = 0;
	update(0);

	for(int i=1; i<=n+1; i++){
		dp[i] = get(i);
		update(i);
	}

	cout << dp[n+1];
}
subtask 3. n≤300,000 / 같은 시간에 여러 개 시장 개최 가능 (100점)

이 서브테스크를 푸는데 250분이 걸렸습니다.

동시에 여러 개의 시장이 개최가 되면, 모두 갈 수는 있습니다.
(이 문제)와 마찬가지로 a와 b에서 동시에 시장이 개최되고 a와 b를 모두 방문한다면, a와 b 사이에 있는 시장은 모두 방문하는게 무조건 이득이라는 것을 알 수 있습니다. 그러므로 상인이 하루에 들린 시장들은 연속된 구간을 이룹니다.

구현을 하기에 앞서, 입력받은 데이터를 약간 다르게 저장을 해줍시다.

1
2
3
4
5
6
7
8
9
10
11
12
vector< pair<int, int> > v[500005];
for(int i=0; i<n; i++){
  int T, L, M; cin >> T >> L >> M;
  v[T].push_back({L, M});
  t = max(t, T);
}
v[0].push_back({s, 0});
v[++t].push_back({s, 0});

for(int i=0; i<=t; i++){
  sort(v[i].begin(), v[i].end());
}

각 날짜별로 다른 배열에 저장해두고, 같은 날짜라면 시장의 위치 순으로 정렬을 해주었습니다.

각 날짜별로 처리를 할건데, 두 가지 배열을 선언하도록 하겠습니다.

  1. dpl[j] = 상인이 해당 날짜에서 j번째 시장에 위치 / 상인은 j-1번째 시장에서 왔음
  2. dpr[j] = 상인이 해당 날짜에서 j번째 시장에 위치 / 상인은 j+1번째 시장에서 왔음

dpl[j]와 dpr[j]는 각각 dpl[j-1]과 dpr[j+1]에서 식을 유도해낼 수 있습니다.

j-1번째 시장의 위치를 prv, j+1번째 시장의 위치를 nxt라 할 때, 조금 더 구체적인 방법은 아래와 같습니다. (pos : 시장의 위치, cost : 해당 시장의 이득)

  1. dpl[j]와 dpr[j]는 max{ s1.query(0, pos)-d*pos, s2.query(pos, 500002)+u*pos } + cost
  2. j-1 >= 0이라면 dpl[j] = max{ dpl[j], dpl[j-1] - D*(pos-prv) + cost }
  3. j+1 < size라면 dpr[j] = max{ dpr[j], dpr[j+1] - U*(nxt-pos) + cost }
  4. 해당 날짜에서 j번째 시장까지 방문할 때 최대 이득은 max{ dpl[j], dpr[j] }

전체 코드는 아래와 같습니다.

전체 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#include <bits/stdc++.h>
using namespace std;

typedef pair<int, int> p;

const int MAX_N = 1<<19;
struct SegTree {
	int seg[MAX_N*2-1];
	SegTree() {
		for(int i=0; i<MAX_N*2-1; i++) seg[i] = -1e9;
	}

	int query(int a, int b, int k=0, int l=0, int r=MAX_N){
    	if (b <= l || r <=a)return -1e9;
    	if (a<=l&&r<=b) return seg[k];
		return max(query(a,b,k*2+1,l,(l+r)/2), query(a,b,k*2+2,(l+r)/2,r));
	}

	void update(int k, int v) {
    	k += MAX_N-1;
    	seg[k] = v;
		while (k > 0) {
			k=(k-1)/2;
			seg[k]=max(seg[k*2+1],seg[k*2+2]);
		}
	}
};

SegTree s1, s2;
vector<p> v[500005];

int n, u, d, s, t;

inline int get(int pos){
	return max(s1.query(0, pos)-d*pos, s2.query(pos, 500002)+u*pos);
}

inline void update(int pos, int dp){
	s1.update(pos, dp + d * pos);
	s2.update(pos, dp - u * pos);
}

void step(int i){
	int num = v[i].size();
	if(!num) return;
	vector<int> dpl, dpr;
	for(int j=0; j<num; j++){
		int tmp = get(v[i][j].first) + v[i][j].second;
		dpl.push_back(tmp); dpr.push_back(tmp);
	}
	for(int j=1; j<num; j++){
		dpl[j] = max(dpl[j], dpl[j-1] - d * (v[i][j].first - v[i][j-1].first) + v[i][j].second);
	}
	for(int j=num-2; j>=0; j--){
		dpr[j] = max(dpr[j], dpr[j+1] - u * (v[i][j+1].first - v[i][j].first) + v[i][j].second);
	}
	for(int j=0; j<num; j++){
		int big = max(dpl[j], dpr[j]);
		update(v[i][j].first, big);
	}
}

int main(){
	ios_base::sync_with_stdio(0); cin.tie(0);
	cin >> n >> u >> d >> s;
	for(int i=0; i<n; i++){
		int a, b, c; cin >> a >> b >> c;
		v[a].push_back({b, c});
		t = max(t, a);
	}
	v[0].push_back({s, 0});
	v[++t].push_back({s, 0});

	for(int i=0; i<=t; i++){
		sort(v[i].begin(), v[i].end());
	}

	s1.update(s, d*s);
	s2.update(s, -u*s);

	for(int i=1; i<=t; i++){
		step(i);
	}

	cout << get(s);
}