백준16901 XOR MST

문제 링크

  • http://icpc.me/16901

사용 알고리즘

  • 분할 정복

시간복잡도

  • $O(N \log N \log 2^{30})$

풀이

가중치를 30비트로 표현할 수 있고 XOR로 표현되는 가중치를 최소화시키는 것이 목적이므로, Binary Trie를 만든 뒤 30번째 비트부터 쭉 봅시다.

30번째 비트가 켜져있는 정점들의 집합과 꺼져있는 정점들의 집합을 잇는 간선은 한 개 있는 것이 최적입니다. 켜져있는 집합의 원소들로 Binary Trie를 구성한 뒤, 꺼져있는 원소들과 XOR했을 때 최소가 되는 값을 Trie로 찾아줄 수 있습니다. 가중치가 최소인 간선을 하나 찾아서 MST에 추가해줍시다.

각 집합에 있는 원소들을 29번째 비트 기준으로 보고 분할하고, 28번째 비트 기준으로 보고 분할하고… 반복해주면 문제를 풀 수 있습니다.

전체 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#include <bits/stdc++.h>
#define x first
#define y second
#define all(v) v.begin(), v.end()
#define compress(v) sort(all(v)), v.erase(unique(all(v)), v.end())
using namespace std;

typedef long long ll;
const int MXD = 30;

struct Node{
    int l, r;
    Node() : Node(-1, -1) {}
    Node(int l, int r) : l(l), r(r) {}
};
vector<Node> nd;

int n;
int a[202020];

void insert(int x, int node = 1, int dep = MXD){
    if(dep == -1) return;
    if((x >> dep) & 1){
        if(nd[node].r == -1) nd[node].r = nd.size(), nd.emplace_back();
        insert(x, nd[node].r, dep-1);
    }else{
        if(nd[node].l == -1) nd[node].l = nd.size(), nd.emplace_back();
        insert(x, nd[node].l, dep-1);
    }
}
// get min
int query(int x, int node = 1, int dep = MXD){
    if(dep == -1) return 0;
    if((x >> dep) & 1){
        if(nd[node].r != -1) return query(x, nd[node].r, dep-1);
        return (1 << dep) + query(x, nd[node].l, dep-1);
    }else{
        if(nd[node].l != -1) return query(x, nd[node].l, dep-1);
        return (1 << dep) + query(x, nd[node].r, dep-1);
    }
}

ll ans = 0;
void solve(int s, int e, int dep = MXD){
    if(dep == -1 || s >= e) return;

    int m = s; ll res = 1e18;
    while(m <= e && !(a[m] & (1 << dep))) m++; m--;
    solve(s, m, dep-1); solve(m+1, e, dep-1);

    if(s > m || m+1 > e) return;
    nd.clear(); nd.resize(2);
    for(int i=s; i<=m; i++) insert(a[i]);
    for(int i=m+1; i<=e; i++) res = min<ll>(res, query(a[i]));
    ans += res;
}

int main(){
    ios_base::sync_with_stdio(false); cin.tie(nullptr);
    cin >> n; for(int i=1; i<=n; i++) cin >> a[i]; sort(a+1, a+n+1);
    solve(1, n);
    cout << ans;
}