Theory
Implementation
// Segment tree over an int array supporting point assignment and range-sum
// queries, both in O(log n).
//
// Internally all intervals are half-open [left, right) and the root node has
// ID 1; the children of node k are 2k and 2k + 1. The public query() takes
// inclusive bounds [a, b].
class SegmentTree {
private:
    std::vector<int> tree;
    int n;
    /*
     * Fills tree[node] and its subtree with the sums of values[left, right).
     * node   - the segment tree index
     * values - the original array
     * left   - the segment tree's left limit (inclusive)
     * right  - the segment tree's right limit (exclusive); must be > left,
     *          otherwise the recursion never reaches a leaf
     */
    void build(int node, const std::vector<int>& values, int left, int right) {
        if (left + 1 == right) {
            tree[node] = values[left];
            return;
        }
        int leftNode = 2 * node;
        int rightNode = 2 * node + 1;
        int mid = left + (right - left) / 2;
        build(leftNode, values, left, mid);
        build(rightNode, values, mid, right);
        tree[node] = tree[leftNode] + tree[rightNode];
    }
    /*
     * Sets the leaf for `index` to `value` and recomputes the sums on the
     * way back up to `node`.
     * node  - the segment tree index
     * left  - the segment tree's left limit (inclusive)
     * right - the segment tree's right limit (exclusive)
     * index - the position in the array to update; must lie in [left, right)
     * value - the new value for the position
     */
    void update(int node, int left, int right, int index, int value) {
        if (left + 1 == right) {
            tree[node] = value;
            return;
        }
        int leftNode = 2 * node;
        int rightNode = 2 * node + 1;
        int mid = left + (right - left) / 2;
        if (index < mid) {
            update(leftNode, left, mid, index, value);
        } else {
            update(rightNode, mid, right, index, value);
        }
        tree[node] = tree[leftNode] + tree[rightNode];
    }
    /*
     * Returns the sum of the elements lying in both [a, b) and [left, right).
     * node  - the segment tree index
     * left  - the segment tree's left limit (inclusive)
     * right - the segment tree's right limit (exclusive)
     * a     - the query's left limit (inclusive)
     * b     - the query's right limit (exclusive)
     */
    int query(int node, int left, int right, int a, int b) const {
        if (right <= a || b <= left) { return 0; }
        if (a <= left && right <= b) {
            return tree[node];
        }
        int mid = left + (right - left) / 2;
        int leftNode = 2 * node;
        int rightNode = 2 * node + 1;
        // Descend into a single child when the query fits entirely inside it.
        if (b <= mid) {
            return query(leftNode, left, mid, a, b);
        } else if (a >= mid) {
            return query(rightNode, mid, right, a, b);
        }
        int leftQuery = query(leftNode, left, mid, a, mid);
        int rightQuery = query(rightNode, mid, right, mid, b);
        // Merge operation. In this case, simply the sum of both subtrees, but
        // it can be something else, e.g., min(L, R), max(L, R), etc.
        return leftQuery + rightQuery;
    }
public:
    SegmentTree() : n(0) {}
    // Initializes the segment tree from an array of values. An empty array
    // yields an empty tree whose queries return 0.
    SegmentTree(const std::vector<int>& values) : n(static_cast<int>(values.size())) {
        if (n == 0) {
            // build() on [0, 0) would never reach a leaf, so leave the tree empty.
            return;
        }
        tree.resize(4 * n);
        build(1, values, 0, n);
    }
    // Sets the array element at `index` to `value`. Indices outside [0, n)
    // are ignored rather than being routed onto a boundary leaf.
    void update(int index, int value) {
        if (index < 0 || index >= n) {
            return;
        }
        update(1, 0, n, index, value);
    }
    // Returns the sum from a to b (inclusive). Positions outside [0, n)
    // contribute nothing, and an empty tree or an empty range yields 0.
    int query(int a, int b) const {
        if (n == 0) {
            return 0;
        }
        return query(1, 0, n, a, b + 1);
    }
};
With lazy propagation
// Segment tree over an int array supporting range addition and range-sum
// queries, both in O(log n), using lazy propagation.
//
// Internally all intervals are half-open [left, right) and the root node has
// ID 1; the children of node k are 2k and 2k + 1. The public update() and
// query() take inclusive bounds [a, b].
class SegmentTreeWithLazyPropagation {
private:
    // tree[node]: sum of the node's segment, not counting additions still
    // pending in lazy[node] or in its ancestors.
    std::vector<int> tree;
    // lazy[node]: per-element amount not yet folded into tree[node].
    std::vector<int> lazy;
    int n;
    /*
     * Fills tree[node] and its subtree with the sums of values[left, right).
     * node   - the segment tree index
     * values - the original array
     * left   - the segment tree's left limit (inclusive)
     * right  - the segment tree's right limit (exclusive); must be > left,
     *          otherwise the recursion never reaches a leaf
     */
    void build(int node, const std::vector<int>& values, int left, int right) {
        if (left + 1 == right) {
            tree[node] = values[left];
            return;
        }
        int leftNode = 2 * node;
        int rightNode = 2 * node + 1;
        int mid = left + (right - left) / 2;
        build(leftNode, values, left, mid);
        build(rightNode, values, mid, right);
        tree[node] = tree[leftNode] + tree[rightNode];
    }
    /*
     * Folds the pending addition lazy[node] into tree[node] and, for an
     * internal node, defers it to both children.
     * node  - the segment tree index
     * left  - the segment tree's left limit (inclusive)
     * right - the segment tree's right limit (exclusive)
     */
    void push(int node, int left, int right) {
        if (lazy[node] == 0) {
            return;
        }
        tree[node] += (right - left) * lazy[node];
        if (left + 1 != right) {
            lazy[2 * node] += lazy[node];
            lazy[2 * node + 1] += lazy[node];
        }
        lazy[node] = 0;
    }
    /*
     * node  - the segment tree index
     * left  - the segment tree's left limit (inclusive)
     * right - the segment tree's right limit (exclusive)
     * a     - the query's left limit (inclusive)
     * b     - the query's right limit (exclusive)
     * value - the amount to add to all the elements in the interval.
     */
    void update(int node, int left, int right, int a, int b, int value) {
        if (left >= right) {
            // Only reachable for an empty tree, where tree/lazy have no slots.
            return;
        }
        // Push before the disjointness check: the parent recomputes its sum
        // from tree[] of both children, so even an untouched child must be
        // up to date.
        push(node, left, right);
        if (b <= left || right <= a) {
            return;
        }
        int leftNode = 2 * node;
        int rightNode = 2 * node + 1;
        if (a <= left && right <= b) {
            tree[node] += (right - left) * value;
            if (left + 1 != right) {
                lazy[leftNode] += value;
                lazy[rightNode] += value;
            }
            return;
        }
        int mid = left + (right - left) / 2;
        update(leftNode, left, mid, a, b, value);
        update(rightNode, mid, right, a, b, value);
        tree[node] = tree[leftNode] + tree[rightNode];
    }
    /*
     * Returns the sum of the elements lying in both [a, b) and [left, right),
     * pushing pending additions down along the visited path.
     * node  - the segment tree index
     * left  - the segment tree's left limit (inclusive)
     * right - the segment tree's right limit (exclusive)
     * a     - the query's left limit (inclusive)
     * b     - the query's right limit (exclusive)
     */
    int query(int node, int left, int right, int a, int b) {
        if (left >= right || b <= left || right <= a) {
            return 0;
        }
        push(node, left, right);
        if (a <= left && right <= b) {
            return tree[node];
        }
        int mid = left + (right - left) / 2;
        int leftQuery = query(2 * node, left, mid, a, b);
        int rightQuery = query(2 * node + 1, mid, right, a, b);
        return leftQuery + rightQuery;
    }
public:
    SegmentTreeWithLazyPropagation(): n(0) {}
    // Initializes the segment tree from an array of values. An empty array
    // yields an empty tree: updates are no-ops and queries return 0.
    SegmentTreeWithLazyPropagation(const std::vector<int>& values)
        : n(static_cast<int>(values.size())) {
        if (n == 0) {
            // build() on [0, 0) would never reach a leaf, so leave the tree empty.
            return;
        }
        tree.resize(4 * n);
        lazy.resize(4 * n);
        build(1, values, 0, n);
    }
    // Adds `value` to all positions from a to b (inclusive).
    void update(int a, int b, int value) {
        update(1, 0, n, a, b + 1, value);
    }
    // Returns the sum from a to b (inclusive).
    int query(int a, int b) {
        return query(1, 0, n, a, b + 1);
    }
};
Practice problems