백준25491 Mexor tree

문제 링크

  • https://icpc.me/25491

사용 알고리즘

  • 세그먼트 트리
  • LCA

시간복잡도

  • $O((N+Q) \log N)$

풀이

트리의 두 정점 $u, v$를 잇는 경로 상의 정점들에 어떤 값 $w$를 xor하는 것이 아니라 더하는 문제를 먼저 생각해 봅시다. 이런 문제는 $A[u]$와 $A[v]$에 각각 $w$를 더하고 $A[par(LCA(u,v))]$에서 $2w$를 뺀 다음, 각 정점을 루트로 하는 서브 트리의 합을 구하면 됩니다.
$A[u]$와 $A[v]$에 $w$를 더한 것은 각각 $u$의 조상과 $v$의 조상을 따라 전파되는데, 이 값이 $LCA(u, v)$ 바깥으로 나가지 않게 하기 위해서 $A[par(LCA(u,v))]$에 $-2w$를 더하는 것입니다.
XOR도 비슷하게 처리할 수 있습니다. $A[u], A[v], A[LCA(u,v)], A[par(LCA(u,v))]$에 각각 $w$를 xor하면 됩니다.

이제, 루트 정점에서 다른 정점까지 가는 경로 상의 정점 중 mex값을 구해야 합니다. 원소가 추가/삭제되는 상황에서 mex값을 빠르게 구할 수 있는 자료구조 $T$가 있다고 합시다. 문제의 답을 구하는 것은, DFS를 하면서 정점 $v$에 들어갈 때 $T\leftarrow T\cup \left{A[v]\right}$, 정점 $v$를 빠져나올 때 $T \leftarrow T \setminus \left{A[v]\right}$를 수행하면 됩니다. 이런 연산을 빠르게 처리할 수 있는 자료구조 $T$를 만들어야 합니다.

구간의 합을 세그먼트 트리를 만듭시다. 자료구조 $T$에 $x$를 삽입할 때는 $C[x]$를 1 증가시키고, 삭제할 때는 1 감소시킬 것입니다. 그리고 세그먼트 트리의 $x$번째 리프 정점의 값은 $C[x] > 0$이면 1, $C[x] < 0$이면 0으로 설정합시다. 즉, 세그먼트 트리에서는 구간에서 등장하는 수들의 종류를 관리합니다.
이렇게 세그먼트 트리를 관리하면 $T$에 들어간 원소들의 mex는 단순히 세그먼트 트리를 타로 내려가는 것으로 $O(\log N)$ 시간에 구할 수 있습니다.

전체 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#include <bits/stdc++.h>
using namespace std;
constexpr int SZ = 1 << 19;

int N, S, Q, A[303030], B[303030], R[303030], D[303030], P[22][303030], T[SZ << 1], C[SZ << 1];
vector<int> G[303030];

void DFS(int v, int b=-1){
    for(auto i : G[v]) if(i != b) P[0][i] = v, D[i] = D[v] + 1, DFS(i, v);
}

void Build(int v, int b=-1){
    for(auto i : G[v]) if(i != b) Build(i, v), B[v] ^= B[i];
}

int LCA(int u, int v){
    if(D[u] < D[v]) swap(u, v);
    int diff = D[u] - D[v];
    for(int i=0; diff; i++, diff>>=1) if(diff & 1) u = P[i][u];
    if(u == v) return u;
    for(int i=21; i>=0; i--) if(P[i][u] != P[i][v]) u = P[i][u], v = P[i][v];
    return P[0][u];
}

void Update(int x, int v){
    x |= SZ; C[x] += v; T[x] = C[x] > 0;
    while(x >>= 1) T[x] = T[x<<1] + T[x<<1|1];
}

int Query(){
    int node = 1, sz = SZ;
    while(node < SZ){
        node = node << 1 | (T[node<<1]*2 == sz);
        sz >>= 1;
    }
    return node ^ SZ;
}

void Solve(int v, int b=-1){
    Update(A[v] ^ B[v], +1);
    R[v] = Query();
    for(auto i : G[v]) if(i != b) Solve(i, v);
    Update(A[v] ^ B[v], -1);
}

int main(){
    ios_base::sync_with_stdio(false); cin.tie(nullptr);
    cin >> N >> S;
    for(int i=1; i<=N; i++) cin >> A[i];
    for(int i=1,s,e; i<N; i++) cin >> s >> e, G[s].push_back(e), G[e].push_back(s);
    DFS(S);
    for(int i=1; i<22; i++) for(int j=1; j<=N; j++) P[i][j] = P[i-1][P[i-1][j]];
    cin >> Q;
    for(int i=1; i<=Q; i++){
        int u, v, w; cin >> u >> v >> w;
        int l = LCA(u, v);
        B[u] ^= w; B[v] ^= w; B[l] ^= w; B[P[0][l]] ^= w;
    }
    Build(S);
    Solve(S);
    for(int i=1; i<=N; i++) cout << R[i] << " ";
}