백준13517 트리와 쿼리 8

문제 링크

  • http://icpc.me/13517

사용 알고리즘

  • 모스 알고리즘

시간복잡도

  • $O((N+Q) \sqrt N log N)$

풀이

전체 구간에서 k번째 원소를 $O(log N)$에 찾는 세그먼트 트리는 잘 알려져 있습니다.
트리 위에서 모스 알고리즘을 돌리면서, k번째 원소를 찾는 세그먼트 트리를 이용해주면 $O((N+Q) \sqrt N log N)$에 풀 수 있습니다.

TLE가 날 수 있는데, 가중치를 좌표압축해주고 모스 알고리즘 쿼리 처리 순서를 Hilbert Curve(링크)를 이용해 정해주면 TLE를 피할 수 있습니다.

전체 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
#pragma GCC optimize("O3")
#include <bits/stdc++.h>
#define all(v) v.begin(), v.end()
#define compress(v) sort(all(v)), v.erase(unique(all(v)), v.end())
using namespace std;

const int sq = 400;

inline int64_t hilbertOrder(int x, int y, int pow, int rotate) {
	if (pow == 0) {
		return 0;
	}
	int hpow = 1 << (pow-1);
	int seg = (x < hpow) ? (
		(y < hpow) ? 0 : 3
	) : (
		(y < hpow) ? 1 : 2
	);
	seg = (seg + rotate) & 3;
	const int rotateDelta[4] = {3, 0, 0, 1};
	int nx = x & (x ^ hpow), ny = y & (y ^ hpow);
	int nrot = (rotate + rotateDelta[seg]) & 3;
	int64_t subSquareSize = int64_t(1) << (2*pow - 2);
	int64_t ans = seg * subSquareSize;
	int64_t add = hilbertOrder(nx, ny, pow-1, nrot);
	ans += (seg == 1 || seg == 2) ? add : (subSquareSize - add - 1);
	return ans;
}

struct Query{
    int s, e, x, l, k;
    long long order;
    void init(){
        order = hilbertOrder(s, e, 21, 0);
    }
    Query(int s, int e, int x, int l, int k) : s(s), e(e), x(x), l(l), k(k) {}
    bool operator < (const Query &t) const {
        return order < t.order;
    }
};

struct Seg{
    int tree[1 << 18];
    int bias = 1 << 17;
    void update(int x, int v){
        x |= bias; tree[x] += v;
        while(x >>= 1){
            tree[x] = tree[x << 1] + tree[x << 1 | 1];
        }
    }
    int query(int k){
        int now = 1;
        while(now < bias){
            if(k <= tree[now << 1]) now = now << 1;
            else k -= tree[now << 1], now = now << 1 | 1;
        }
        return now - bias;
    }
}seg;

int n, q;
int arr[101010];
vector<int> comp;
vector<int> g[101010];
int dep[101010], dp[18][101010];
vector<Query> qry;
int in[101010], out[101010], id[202020], pv;
int chk[101010];
int ans[101010];

void dfs(int v, int p){
    in[v] = ++pv; id[pv] = v;
    for(auto i : g[v]){
        if(i == p) continue;
        dp[0][i] = v; dep[i] = dep[v] + 1;
        dfs(i, v);
    }
    out[v] = ++pv;
    id[pv] = v;
}

int lca(int u, int v){
    if(dep[u] > dep[v]) swap(u, v);
    int diff = dep[v] - dep[u];
    for(int i=0; diff; i++){
        if(diff & 1) v = dp[i][v];
        diff >>= 1;
    }
    if(u == v) return u;
    for(int i=17; i>=0; i--){
        if(dp[i][u] ^ dp[i][v]) u = dp[i][u], v = dp[i][v];
    }
    return dp[0][v];
}

void update(int x){
    x = id[x];
    chk[x] ^= 1;
    if(chk[x]) seg.update(arr[x], 1);
    else seg.update(arr[x], -1);
}

int main(){
    ios_base::sync_with_stdio(0); cin.tie(0);
    cin >> n; comp.reserve(n);
    for(int i=1; i<=n; i++){
        cin >> arr[i]; comp.push_back(arr[i]);
    }
    for(int i=1; i<n; i++){
        int s, e; cin >> s >> e;
        g[s].push_back(e); g[e].push_back(s);
    }
    compress(comp);
    for(int i=1; i<=n; i++){
        arr[i] = lower_bound(all(comp), arr[i]) - comp.begin();
    }
    dfs(1, -1);
    for(int j=1; j<18; j++) for(int i=1; i<=n; i++){
        dp[j][i] = dp[j-1][dp[j-1][i]];
    }

    cin >> q; qry.reserve(q);
    for(int i=0; i<q; i++){
        int a, b, k; cin >> a >> b >> k;
        if(in[a] > in[b]) swap(a, b);
        int l = lca(a, b);
        if(l == a) qry.emplace_back(in[a], in[b], i, -1, k);
        else qry.emplace_back(out[a], in[b], i, in[l], k);
        qry.back().init();
    }
    sort(all(qry));

    int s = qry[0].s, e = qry[0].s-1;
    for(int i=0; i<q; i++){
        while(qry[i].s < s) update(--s);
        while(e < qry[i].e) update(++e);
        while(s < qry[i].s) update(s++);
        while(qry[i].e < e) update(e--);

        if(qry[i].l != -1) update(qry[i].l);
        ans[qry[i].x] = seg.query(qry[i].k);
        if(qry[i].l != -1) update(qry[i].l);
    }

    for(int i=0; i<q; i++) cout << comp[ans[i]] << "\n";
}