백준17960 3차원 점과 쿼리

문제 링크

  • https://icpc.me/17960

사용 알고리즘

  • 세그먼트 트리
  • PST

시간복잡도

  • $O((N+Q) \log^2 N)$

풀이

2차원이면 PST를 이용해 문제를 해결할 수 있습니다. 하지만 이 문제는 3차원입니다.

3차원 공간을 $z$좌표에 따라 2차원 평면 $N$개로 분할합시다. 각 평면은 2차원이므로 PST를 이용해 쿼리를 처리할 수 있습니다.
3차원 쿼리는 연속한 몇 개의 평면에 대한 쿼리를 처리해야 합니다. 이는 평면을 관리하는 PST를 세그먼트 트리로 관리하면 $O(\log^2 N)$ 시간에 해결할 수 있습니다.

전체 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
constexpr ll MOD = 1e9+1;

struct Point{
    int x, y, z;
    Point() = default;
    Point(int x, int y, int z) : x(x), y(y), z(z) {}
};

struct RectangleCount{
    struct Node{
        int l, r, v;
        Node() : Node(0) {}
        Node(int v) : l(0), r(0), v(v) {}
    };

    vector<int> cx, cy, root;
    vector<Node> tree;

    void init(vector<Point> points){
        sort(points.begin(), points.end(), [&](auto p1, auto p2){ return p1.x < p2.x; });

        cx.reserve(points.size()); cy.reserve(points.size());
        for(const auto &p : points) cx.push_back(p.x), cy.push_back(p.y);
        sort(cx.begin(), cx.end()); cx.erase(unique(cx.begin(), cx.end()), cx.end());
        sort(cy.begin(), cy.end()); cy.erase(unique(cy.begin(), cy.end()), cy.end());
        for(auto &p : points){
            p.x = lower_bound(cx.begin(), cx.end(), p.x) - cx.begin() + 1;
            p.y = lower_bound(cy.begin(), cy.end(), p.y) - cy.begin() + 1;
        }

        int h = __lg(cy.size()) + 1;
        root.reserve(cx.size() + 1);
        tree.reserve(cx.size() * (h + 2));

        tree.emplace_back();
        tree.emplace_back();
        root.push_back(1);

        for(int i=0; i<points.size(); i++){
            int prv = root.back(), now = tree.size();
            tree.emplace_back();
            if(i == 0 || points[i-1].x != points[i].x) root.push_back(now);
            else root.back() = now;
            update(prv, now, 0, cy.size(), points[i].y, 1);
        }
    }

    void update(int prv, int now, int s, int e, int x, int v){
        if(s == e){ tree[now].v = tree[prv].v + v; return; }
        int m = (s + e) / 2;
        if(x <= m){
            if(tree[now].l == 0 || tree[now].l == tree[prv].l) tree[now].l = tree.size(), tree.emplace_back();
            if(tree[now].r == 0) tree[now].r = tree[prv].r;
            update(tree[prv].l, tree[now].l, s, m, x, v);
        }
        else{
            if(tree[now].r == 0 || tree[now].r == tree[prv].r) tree[now].r = tree.size(), tree.emplace_back();
            if(tree[now].l == 0) tree[now].l = tree[prv].l;
            update(tree[prv].r, tree[now].r, m+1, e, x, v);
        }
        tree[now].v = tree[tree[now].l].v + tree[tree[now].r].v;
    }

    int query(int prv, int now, int s, int e, int l, int r){
        if(r < s || e < l || prv == now || now == 0) return 0;
        if(l <= s && e <= r) return tree[now].v - tree[prv].v;
        int m = (s + e) / 2;
        return query(tree[prv].l, tree[now].l, s, m, l, r) + query(tree[prv].r, tree[now].r, m+1, e, l, r);
    }

    int query(int x1, int x2, int y1, int y2){
        int xs = lower_bound(cx.begin(), cx.end(), x1) - cx.begin() + 1;
        int xe = upper_bound(cx.begin(), cx.end(), x2) - cx.begin();
        int ys = lower_bound(cy.begin(), cy.end(), y1) - cy.begin() + 1;
        int ye = upper_bound(cy.begin(), cy.end(), y2) - cy.begin();
        return query(root[xs-1], root[xe], 0, cy.size(), ys, ye);
    }
};

int N, Q;
vector<Point> V;
vector<int> Z;

constexpr int SZ = 1 << 17;
vector<Point> P[SZ << 1];
RectangleCount T[SZ << 1];

void AddPoint(int x, const Point &v){
    for(x|=SZ; x; x>>=1) P[x].push_back(v);
}

int Query(int l, int r, int x1, int x2, int y1, int y2){
    l |= SZ; r |= SZ; int ret = 0;
    while(l <= r){
        if(l & 1) ret += T[l++].query(x1, x2, y1, y2);
        if(~r & 1) ret += T[r--].query(x1, x2, y1, y2);
        l >>= 1; r >>= 1;
    }
    return ret;
}

int main(){
    ios_base::sync_with_stdio(false); cin.tie(nullptr);
    cin >> N >> Q; V.resize(N); Z.reserve(N);
    for(auto &[x,y,z] : V) cin >> x >> y >> z, Z.push_back(z);
    sort(Z.begin(), Z.end()); Z.erase(unique(Z.begin(), Z.end()), Z.end());
    for(auto &[x,y,z] : V) z = lower_bound(Z.begin(), Z.end(), z) - Z.begin() + 1;
    for(int i=0; i<N; i++) AddPoint(V[i].z, V[i]);
    for(int i=0; i<SZ*2; i++) T[i].init(P[i]);

    ll lst = 0;
    for(int q=0; q<Q; q++){
        ll a, b, c, d, e, f; cin >> a >> b >> c >> d >> e >> f;
        int x1 = (a ^ lst) % MOD, y1 = (b ^ lst) % MOD, z1 = (c ^ lst) % MOD;
        int x2 = (d ^ lst) % MOD, y2 = (e ^ lst) % MOD, z2 = (f ^ lst) % MOD;
        if(x1 > x2) swap(x1, x2);
        if(y1 > y2) swap(y1, y2);
        if(z1 > z2) swap(z1, z2);

        z1 = lower_bound(Z.begin(), Z.end(), z1) - Z.begin() + 1;
        z2 = upper_bound(Z.begin(), Z.end(), z2) - Z.begin();

        ll res = Query(z1, z2, x1, x2, y1, y2);
        cout << res << "\n";
        lst += res;
    }
}