백준10717 성벽

문제 링크

  • http://icpc.me/10717

문제 출처

  • 2015 JOI 5번

사용 알고리즘

  • 세그먼트 트리

시간복잡도

  • $O(NM + (N+M)\log(N+M))$

풀이

좌측 상단 끝점과 우측 하단 끝점만 고정한다면 성벽은 한 가지로 특정됩니다. 북서쪽과 동남쪽 방향으로 그어진 N+M개의 대각선에 대해 독립적으로 문제를 해결해도 됩니다.
성벽을 (1) 북서쪽에서 뻗어나가는 성벽과 (2) 남동쪽에서 뻗어나가는 성벽, 2개로 나눠서 생각할 것입니다.


위 그림처럼 남동쪽 성벽을 쭉 늘어놓고, 대각선 위의 각 점이 성벽의 북서쪽 점이 되는 경우의 수를 구해주면 됩니다. Inversion Counting과 비슷한 방법으로 구할 수 있다는 생각이 듭니다. 자세한 풀이를 알아봅시다.

먼저, 각 점에서 상하좌우 네 방향으로 뻗어나갈 수 있는 최대 길이를 구하고, 각각 U(i, j), D(i, j), L(i, j), R(i, j)라고 합시다.
간단한 DP로 $O(NM)$에 구할 수 있습니다.

이렇게 전처리한 값을 이용해서, 각 대각선에 대해 경우의 수를 구해봅시다.

대각선 위에 있는 모든 점 (i, j)에 대해서, 각 점이 성벽의 남동쪽 점일 때 성벽이 어디까지 뻗을 수 있는지 구합시다. $i - \min(U(i, j), L(i, j)) + 1$입니다. $(i - \min(U(i, j), L(i, j)) + 1, i)$ pair를 정렬해서 배열로 저장합시다.
점 (i, j)가 북서쪽 끝점이면 성벽은 $i + \min(D(i, j), R(i, j)) + 1$까지 뻗을 수 있습니다.
이 정보를 이용해서 Fenwick Tree 등으로 문제의 정답을 구해주면 됩니다.

전체 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#include <bits/stdc++.h>
#define x first
#define y second
#define all(v) v.begin(), v.end()
#define compress(v) sort(all(v)), v.erase(unique(all(v)), v.end())
using namespace std;

typedef long long ll;
typedef pair<int, int> p;
const int S = 4040;

int tree[S];
void init(){ memset(tree, 0, sizeof tree); }
void update(int x, int v){ for(x++; x<S; x+=x&-x) tree[x] += v; }
int query(int x){ int ret = 0; for(x++; x; x-=x&-x) ret += tree[x]; return ret; }
int query(int s, int e){ return s <= e ? query(e) - query(s-1) : 0; }

int n, m, k, l; ll ans; bool a[S][S];
int north[S][S], south[S][S], west[S][S], east[S][S];

void solve(int si, int sj){
    init(); vector<p> range; int pv = 0;
    for(int i=si, j=sj; i<=n && j<=m; i++, j++) range.emplace_back(i-min(north[i][j], west[i][j])+1, i);
    sort(all(range));

    for(int i=si, j=sj; i<=n && j<=m; i++, j++){
        while(pv < range.size() && range[pv].x <= i) update(range[pv++].y, 1);
        ans += query(i+l-1, i+min(south[i][j], east[i][j])-1);
    }
}

int main(){
    ios_base::sync_with_stdio(false); cin.tie(nullptr);
    cin >> n >> m >> l >> k; memset(a, 1, sizeof a);
    for(int i=0; i<k; i++){ int x, y; cin >> x >> y; a[x][y] = 0; }

    for(int i=1; i<=n; i++) for(int j=1; j<=m; j++) if(a[i][j]) {
        north[i][j] = north[i-1][j] + 1;
        west[i][j] = west[i][j-1] + 1;
    }
    for(int i=n; i>=0; i--) for(int j=m; j>=1; j--) if(a[i][j]) {
        south[i][j] = south[i+1][j] + 1;
        east[i][j] = east[i][j+1] + 1;
    }

    for(int i=1; i<=n; i++) solve(i, 1);
    for(int j=2; j<=m; j++) solve(1, j);
    cout << ans;
}