백준13519 트리와 쿼리 10

문제 링크

  • http://icpc.me/13519

사용 알고리즘

  • HLD
  • 세그 레이지

시간복잡도

  • $O(Q log^2 N)$

풀이

세그먼트 트리의 각 노드에서 구간의 왼쪽을 포함하는 최댓값, 구간의 오른쪽을 포함하는 최댓값, 구간의 최댓값, 구간의 합을 저장하고 있으면 구간의 최대 연속합을 구해줄 수 있습니다.
이 세그먼트 트리에 lazy Propagation을 얹어주고 hld를 돌려주면 문제를 풀 수 있습니다.

hld를 이용해 답을 구할 때, 각 체인의 답들을 합치는 부분을 구현하는 것이 많이 까다로울 수 있습니다. 그림을 그려가면서 구현 설계하는 것을 추천합니다.

전체 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
#include <bits/stdc++.h>
#define x first
#define y second
using namespace std;

typedef long long ll;
typedef pair<int, int> pii;
const int inf = 1e9;

int n, q;
struct Node{
	ll l, r, mx, sum;
    Node(){ l = r = mx = sum = 0; }
    Node(ll l, ll r, ll mx, ll sum) : l(l), r(r), mx(mx), sum(sum) {}
    void go(){
        l = max<ll>(0, l);
        r = max<ll>(0, r);
        mx = max<ll>(0, mx);
    }
};

Node tree[1 << 18];
int tmp[1 << 18];
const int base = 1 << 17;

Node merge(Node a, Node b){
	Node ret;
	ret.l = max(a.l, a.sum + b.l);
	ret.r = max(a.r + b.sum, b.r);
	ret.sum = a.sum + b.sum;
	ret.mx = max({a.mx, b.mx, a.r + b.l});
    ret.go();
	return ret;
}

void push(int node, int s, int e){
    if(tmp[node] == inf) return;
    tree[node].l = tree[node].r = tree[node].mx = tree[node].sum = (e-s+1) * tmp[node];
    tree[node].go();
    if(s ^ e){
        tmp[node << 1] = tmp[node];
        tmp[node << 1 | 1] = tmp[node];
    }
    tmp[node] = inf;
}

void seg_update(int l, int r, int v, int node = 1, int s = 1, int e = n){
    push(node, s, e);
    if(r < s || e < l) return;
    if(l <= s && e <= r){
        tmp[node] = v;
        push(node, s, e);
        return;
    }
    int m = s + e >> 1;
    seg_update(l, r, v, node << 1, s, m);
    seg_update(l, r, v, node << 1 | 1, m+1, e);
    tree[node] = merge(tree[node << 1], tree[node << 1 | 1]);
}

Node seg_query(int l, int r, int node = 1, int s = 1, int e = n){
    push(node, s, e);
    if(r < s || e < l) return Node(0, 0, 0, 0);
    if(l <= s && e <= r) return tree[node];
    int m = s + e >> 1;
    Node t1 = seg_query(l, r, node << 1, s, m);
    Node t2 = seg_query(l, r, node << 1 | 1, m+1, e);
    return merge(t1, t2);
}

int arr[101010];
int top[101010], in[101010], dep[101010], sz[101010], par[101010], pv;
vector<int> g[101010], inp[101010];
void dfs(int v = 1, int p = -1){
    for(auto i : inp[v]){
        if(i == p) continue;
        par[i] = v; g[v].push_back(i);
        dfs(i, v);
    }
}
void dfs1(int v = 1){
    sz[v] = 1;
    for(auto &i : g[v]){
        dep[i] = dep[v] + 1;
        dfs1(i); sz[v] += sz[i];
        if(sz[i] > sz[g[v][0]]) swap(i, g[v][0]);
    }
}
void dfs2(int v = 1){
    in[v] = ++pv;
    for(auto i : g[v]){
        top[i] = i == g[v][0] ? top[v] : i;
        dfs2(i);
    }
}

void update(int s, int e, int x){
    while(top[s] ^ top[e]){
        if(dep[top[s]] < dep[top[e]]) swap(s, e);
        seg_update(in[top[s]], in[s], x);
        s = par[top[s]];
    }
    if(dep[s] > dep[e]) swap(s, e);
    seg_update(in[s], in[e], x);
}

ll query(int s, int e){
    Node t1, t2;
    if(in[s] > in[e]) swap(s, e);
    while(top[s] ^ top[e]){
        if(dep[top[s]] > dep[top[e]]){
            t1 = merge(seg_query(in[top[s]], in[s]), t1);
            s = par[top[s]];
        }else{
            t2 = merge(seg_query(in[top[e]], in[e]), t2);
            e = par[top[e]];
        }
    }
    if(dep[s] > dep[e]) t1 = merge(seg_query(in[e], in[s]), t1);
    else t2 = merge(seg_query(in[s], in[e]), t2);
    swap(t1.l, t1.r);
    return merge(t1, t2).mx;
}

int main(){
    ios_base::sync_with_stdio(0); cin.tie(0);
    cin >> n; top[1] = 1;
    for(int i=1; i<=n; i++) cin >> arr[i];
    for(int i=1; i<n; i++){
        int s, e; cin >> s >> e;
        inp[s].push_back(e); inp[e].push_back(s);
    }
    dfs(); dfs1(); dfs2();
    for(int i=0; i<(1 << 18); i++) tmp[i] = inf;
    for(int i=1; i<=n; i++) seg_update(in[i], in[i], arr[i]);

    cin >> q;
    while(q--){
        int op, s, e; cin >> op >> s >> e;
        if(op == 1){
            cout << query(s, e) << "\n";
        }else{
            int x; cin >> x;
            update(s, e, x);
        }
    }
}