-
Notifications
You must be signed in to change notification settings - Fork 521
/
SegmentTree.java
174 lines (154 loc) · 4.5 KB
/
SegmentTree.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
package structures;
import java.util.function.Predicate;
/**
 * Segment tree over positions 0..n-1 holding {@code long} values. Supports range add, range
 * max/sum queries and a left-to-right predicate search, using lazy propagation.
 */
public class SegmentTree {
    // Number of positions (leaves).
    int n;
    // Nodes in pre-order layout: the node for [l, r] lives at index x, its left child [l, m] at
    // x + 1 and its right child [m + 1, r] at x + 2 * (m - l + 1). Exactly 2n - 1 slots are used.
    Node[] tree;

    /**
     * Aggregate of one segment. {@code mx} and {@code sum} are the max and the sum of the
     * segment and already include this node's {@code add}. {@code add} is an increment that has
     * not yet been pushed down to the children.
     */
    public static class Node {
        // initial values for leaves
        public long mx = 0;
        public long sum = 0;
        public long add = 0;

        /** Adds {@code v} to every element of the segment [l, r] covered by this node. */
        void apply(int l, int r, long v) {
            mx += v;
            sum += v * (r - l + 1);
            add += v;
        }
    }

    /** Combines two adjacent segments into a fresh node that has no pending add. */
    static Node unite(Node a, Node b) {
        Node res = new Node();
        res.mx = Math.max(a.mx, b.mx);
        res.sum = a.sum + b.sum;
        return res;
    }

    /** Returns an independent copy of {@code src}, so callers never alias a live tree node. */
    private static Node copyOf(Node src) {
        Node res = new Node();
        res.mx = src.mx;
        res.sum = src.sum;
        res.add = src.add;
        return res;
    }

    /**
     * Pushes the pending add of node x (covering [l, r] with l < r) down to its two children
     * and clears it.
     */
    void push(int x, int l, int r) {
        int m = (l + r) >> 1;
        int y = x + ((m - l + 1) << 1);
        if (tree[x].add != 0) {
            tree[x + 1].apply(l, m, tree[x].add);
            tree[y].apply(m + 1, r, tree[x].add);
            tree[x].add = 0;
        }
    }

    /**
     * Rebuilds node x from its left child (x + 1) and right child y. The rebuilt node has
     * add = 0, so call this only after x's pending add has been pushed (or is zero).
     */
    void pull(int x, int y) {
        tree[x] = unite(tree[x + 1], tree[y]);
    }

    /**
     * Creates a tree of {@code n} zero-valued positions.
     *
     * @throws IllegalArgumentException if {@code n <= 0}
     */
    SegmentTree(int n) {
        if (n <= 0) {
            throw new IllegalArgumentException("n must be positive: " + n);
        }
        this.n = n;
        tree = new Node[2 * n - 1];
        for (int i = 0; i < tree.length; i++) tree[i] = new Node();
        build(0, 0, n - 1);
    }

    /**
     * Creates a tree whose position i initially holds {@code v[i]}.
     *
     * @throws IllegalArgumentException if {@code v} is empty
     * @throws NullPointerException if {@code v} is null
     */
    SegmentTree(long[] v) {
        if (v.length == 0) {
            throw new IllegalArgumentException("v must not be empty");
        }
        n = v.length;
        tree = new Node[2 * n - 1];
        for (int i = 0; i < tree.length; i++) tree[i] = new Node();
        build(0, 0, n - 1, v);
    }

    /** Recomputes every internal node of the subtree x = [l, r] from its (all-zero) leaves. */
    void build(int x, int l, int r) {
        if (l == r) {
            return;
        }
        int m = (l + r) >> 1;
        int y = x + ((m - l + 1) << 1);
        build(x + 1, l, m);
        build(y, m + 1, r);
        pull(x, y);
    }

    /** Builds the subtree x = [l, r] so that leaf i holds {@code v[i]}. */
    void build(int x, int l, int r, long[] v) {
        if (l == r) {
            tree[x].apply(l, r, v[l]);
            return;
        }
        int m = (l + r) >> 1;
        int y = x + ((m - l + 1) << 1);
        build(x + 1, l, m, v);
        build(y, m + 1, r, v);
        pull(x, y);
    }

    /**
     * Returns the aggregate (max and sum) of positions [ll, rr].
     *
     * <p>The range is clamped to [0, n-1]. The returned node is a snapshot: later updates do
     * not change it, and mutating it does not affect the tree.
     *
     * @throws IllegalArgumentException if the clamped range is empty
     */
    public Node get(int ll, int rr) {
        int lo = Math.max(ll, 0);
        int hi = Math.min(rr, n - 1);
        if (lo > hi) {
            throw new IllegalArgumentException(
                    "empty query range [" + ll + ", " + rr + "] for size " + n);
        }
        // The recursion may hand back a live tree node when one node covers the whole query.
        return copyOf(get(lo, hi, 0, 0, n - 1));
    }

    /**
     * Recursive query on node x = [l, r]. Requires ll <= rr and that [ll, rr] intersects
     * [l, r]; otherwise the recursion would leave x's subtree.
     */
    Node get(int ll, int rr, int x, int l, int r) {
        if (ll <= l && r <= rr) {
            return tree[x];
        }
        int m = (l + r) >> 1;
        int y = x + ((m - l + 1) << 1);
        push(x, l, r);
        Node res;
        if (rr <= m) {
            res = get(ll, rr, x + 1, l, m);
        } else {
            if (ll > m) {
                res = get(ll, rr, y, m + 1, r);
            } else {
                res = unite(get(ll, rr, x + 1, l, m), get(ll, rr, y, m + 1, r));
            }
        }
        pull(x, y);
        return res;
    }

    /**
     * Adds {@code v} to every position in [ll, rr]. The range is clamped to [0, n-1], and an
     * empty clamped range is a no-op.
     */
    void modify(int ll, int rr, long v) {
        int lo = Math.max(ll, 0);
        int hi = Math.min(rr, n - 1);
        if (lo > hi) {
            return;
        }
        modify(lo, hi, v, 0, 0, n - 1);
    }

    /** Recursive range add on node x = [l, r]; [ll, rr] must be non-empty and intersect it. */
    void modify(int ll, int rr, long v, int x, int l, int r) {
        if (ll <= l && r <= rr) {
            tree[x].apply(l, r, v);
            return;
        }
        int m = (l + r) >> 1;
        int y = x + ((m - l + 1) << 1);
        push(x, l, r);
        if (ll <= m) {
            modify(ll, rr, v, x + 1, l, m);
        }
        if (rr > m) {
            modify(ll, rr, v, y, m + 1, r);
        }
        pull(x, y);
    }

    /**
     * Searches [ll, rr] from left to right and returns the position of the first leaf reached
     * by descending through nodes that satisfy {@code f}, or -1 if there is none.
     *
     * <p>The range is clamped to [0, n-1]; an empty clamped range yields -1. {@code f} is only
     * evaluated on nodes lying fully inside the range, in left-to-right order.
     *
     * <p>Calls all FALSE elements to the left of the sought position exactly once. Each node for
     * which {@code f} returns false is tested exactly once and then skipped, so {@code f} may
     * accumulate state over the skipped prefix (see {@link #sumLowerBound}). For the answer to
     * be meaningful, {@code f} should hold on a node exactly when the sought position lies
     * inside it, given the state accumulated from the nodes to its left.
     */
    int findFirst(int ll, int rr, Predicate<Node> f) {
        int lo = Math.max(ll, 0);
        int hi = Math.min(rr, n - 1);
        if (lo > hi) {
            return -1;
        }
        return findFirst(lo, hi, f, 0, 0, n - 1);
    }

    /** Recursive search on node x = [l, r]; [ll, rr] must be non-empty and intersect it. */
    int findFirst(int ll, int rr, Predicate<Node> f, int x, int l, int r) {
        if (ll <= l && r <= rr && !f.test(tree[x])) {
            return -1;
        }
        // With a valid range every visited leaf lies inside [ll, rr], and f held on it above.
        if (l == r) {
            return l;
        }
        push(x, l, r);
        int m = (l + r) >> 1;
        int y = x + ((m - l + 1) << 1);
        int res = -1;
        if (ll <= m) {
            res = findFirst(ll, rr, f, x + 1, l, m);
        }
        if (rr > m && res == -1) {
            res = findFirst(ll, rr, f, y, m + 1, r);
        }
        pull(x, y);
        return res;
    }

    /**
     * Returns min(p | ll <= p <= rr && sum[ll..p] >= sum), or -1 if no such p exists.
     *
     * <p>This is exact only when the elements of [ll, rr] are non-negative, so that prefix sums
     * never decrease. With negative elements, a qualifying prefix that ends inside a node whose
     * total falls short of {@code sum} can be missed.
     */
    static int sumLowerBound(SegmentTree t, int ll, int rr, long sum) {
        // Running total of the nodes already skipped (f returned false on them).
        long[] sumSoFar = new long[1];
        return t.findFirst(ll, rr, node -> {
            if (sumSoFar[0] + node.sum >= sum)
                return true;
            sumSoFar[0] += node.sum;
            return false;
        });
    }

    // Usage example
    public static void main(String[] args) {
        SegmentTree t = new SegmentTree(10);
        t.modify(1, 2, 10);
        t.modify(2, 3, 20);
        System.out.println(30 == t.get(1, 3).mx);
        System.out.println(60 == t.get(1, 3).sum);
        SegmentTree tt = new SegmentTree(new long[] {2, 1, 10, 20});
        System.out.println(2 == sumLowerBound(tt, 0, tt.n - 1, 12));
    }
}