Segment Tree

Theory

Implementation

// Segment tree over an int array supporting point assignment and inclusive
// range-sum queries, each in O(log n).
//
// Internally all intervals are half-open [a, b); the public methods take
// inclusive bounds [a, b]. Nodes are stored heap-style: the root has ID 1 and
// the children of node i are 2i and 2i + 1, so 4n slots always suffice.
class SegmentTree {
private:
    vector<int> tree;  // tree[node] = sum of the segment covered by `node`
    int n = 0;         // number of elements in the original array

    /*
     * Fills the subtree rooted at `node` from `values`. Requires left < right.
     *
     * node - the segment tree index
     * values - the original array
     * left - the segment's left limit (inclusive)
     * right - the segment's right limit (exclusive)
     */
    void build(int node, const vector<int>& values, int left, int right) {
        if (left + 1 == right) {
            tree[node] = values[left];
            return;
        }

        int leftNode = 2 * node;
        int rightNode = 2 * node + 1;
        int mid = left + (right - left) / 2;
        build(leftNode, values, left, mid);
        build(rightNode, values, mid, right);
        tree[node] = tree[leftNode] + tree[rightNode];
    }

    /*
     * Sets position `index` to `value` and refreshes every sum on the path
     * from the leaf back up to `node`. Requires left <= index < right.
     *
     * node - the segment tree index
     * left - the segment's left limit (inclusive)
     * right - the segment's right limit (exclusive)
     * index - the position in the array to update
     * value - the new value for the position
     */
    void update(int node, int left, int right, int index, int value) {
        if (left + 1 == right) {
            tree[node] = value;
            return;
        }

        int leftNode = 2 * node;
        int rightNode = 2 * node + 1;
        int mid = left + (right - left) / 2;
        if (index < mid) {
            update(leftNode, left, mid, index, value);
        } else {
            update(rightNode, mid, right, index, value);
        }
        tree[node] = tree[leftNode] + tree[rightNode];
    }

    /*
     * Returns the sum of the positions in the intersection of [a, b) and
     * [left, right); 0 when the intersection is empty.
     *
     * node - the segment tree index
     * left - the segment's left limit (inclusive)
     * right - the segment's right limit (exclusive)
     * a - the query's left limit (inclusive)
     * b - the query's right limit (exclusive)
     */
    int query(int node, int left, int right, int a, int b) {
        if (right <= a || b <= left) { return 0; }
        if (a <= left && right <= b) {
            return tree[node];
        }

        int mid = left + (right - left) / 2;

        int leftNode = 2 * node;
        int rightNode = 2 * node + 1;
        // If the query lies entirely in one half, descend only into that half.
        if (b <= mid) {
            return query(leftNode, left, mid, a, b);
        } else if (a >= mid) {
            return query(rightNode, mid, right, a, b);
        }
        int leftQuery = query(leftNode, left, mid, a, mid);
        int rightQuery = query(rightNode, mid, right, mid, b);

        // Merge operation. In this case it is simply the sum of both subtrees,
        // but it can be something else, e.g., min(L, R) or max(L, R).
        return leftQuery + rightQuery;
    }
public:
    SegmentTree() = default;

    // Initializes the segment tree from an array of values. An empty array
    // yields an empty tree whose queries return 0.
    SegmentTree(const vector<int>& values) : n(static_cast<int>(values.size())) {
        // build() assumes a non-empty range; on [0, 0) it would never reach a
        // leaf and would recurse without bound.
        if (n == 0) {
            return;
        }
        tree.resize(4 * n);
        build(1, values, 0, n);
    }

    // Sets the array element at `index` to `value`. Indices outside [0, n)
    // are ignored, so they cannot corrupt the first or last element.
    void update(int index, int value) {
        if (index < 0 || index >= n) {
            return;
        }
        update(1, 0, n, index, value);
    }

    // Returns the sum of the elements from a to b (both inclusive). Bounds
    // outside [0, n - 1] are clipped; an empty range yields 0.
    int query(int a, int b) {
        if (n == 0) {
            return 0;
        }
        // Clip b before adding 1 so that b == INT_MAX cannot overflow.
        int end = (b >= n) ? n : b + 1;
        return query(1, 0, n, a, end);
    }
};

With lazy propagation

// Segment tree over an int array supporting "add a value to every element of
// a range" and inclusive range-sum queries, each in O(log n) thanks to lazy
// propagation.
//
// Internally all intervals are half-open [a, b); the public methods take
// inclusive bounds [a, b]. Nodes are stored heap-style: the root has ID 1 and
// the children of node i are 2i and 2i + 1, so 4n slots always suffice.
//
// Invariant: the true sum of a node's segment is
//     tree[node] + (segment length) * lazy[node],
// i.e. lazy[node] is a per-element addition not yet applied to tree[node].
class SegmentTreeWithLazyPropagation {
private:
    vector<int> tree;
    vector<int> lazy;
    int n = 0;  // number of elements in the original array

    /*
     * Fills the subtree rooted at `node` from `values`. Requires left < right.
     *
     * node - the segment tree index
     * values - the original array
     * left - the segment's left limit (inclusive)
     * right - the segment's right limit (exclusive)
     */
    void build(int node, const vector<int>& values, int left, int right) {
        if (left + 1 == right) {
            tree[node] = values[left];
            return;
        }
        int leftNode = 2 * node;
        int rightNode = 2 * node + 1;
        int mid = left + (right - left) / 2;
        build(leftNode, values, left, mid);
        build(rightNode, values, mid, right);
        tree[node] = tree[leftNode] + tree[rightNode];
    }

    /*
     * Applies the addition pending at `node` to tree[node] and defers it to
     * the node's children (a leaf has no children to defer to). Afterwards
     * tree[node] holds the segment's true sum.
     *
     * node - the segment tree index
     * left - the segment's left limit (inclusive)
     * right - the segment's right limit (exclusive)
     */
    void push(int node, int left, int right) {
        if (lazy[node] == 0) {
            return;
        }
        tree[node] += (right - left) * lazy[node];
        if (left + 1 != right) {
            lazy[2 * node] += lazy[node];
            lazy[2 * node + 1] += lazy[node];
        }
        lazy[node] = 0;
    }

    /*
     * Adds `value` to every position in the intersection of [a, b) and
     * [left, right). Requires left < right.
     *
     * node - the segment tree index
     * left - the segment's left limit (inclusive)
     * right - the segment's right limit (exclusive)
     * a - the update's left limit (inclusive)
     * b - the update's right limit (exclusive)
     * value - the amount to add to all the elements in the interval.
     */
    void update(int node, int left, int right, int a, int b, int value) {
        // Push before the out-of-range check: the caller recomputes its sum
        // from both children, so even a child outside [a, b) must have its
        // pending addition reflected in tree[child].
        push(node, left, right);

        if (left >= right || b <= left || right <= a) {
            return;
        }

        int leftNode = 2 * node;
        int rightNode = 2 * node + 1;
        if (a <= left && right <= b) {
            // Fully covered: apply here, defer to the children.
            tree[node] += (right - left) * value;
            if (left + 1 != right) {
                lazy[leftNode] += value;
                lazy[rightNode] += value;
            }
            return;
        }

        int mid = left + (right - left) / 2;
        update(leftNode, left, mid, a, b, value);
        update(rightNode, mid, right, a, b, value);
        tree[node] = tree[leftNode] + tree[rightNode];
    }

    /*
     * Returns the sum of the positions in the intersection of [a, b) and
     * [left, right); 0 when the intersection is empty.
     *
     * node - the segment tree index
     * left - the segment's left limit (inclusive)
     * right - the segment's right limit (exclusive)
     * a - the query's left limit (inclusive)
     * b - the query's right limit (exclusive)
     */
    int query(int node, int left, int right, int a, int b) {
        if (left >= right || b <= left || right <= a) {
            return 0;
        }
        push(node, left, right);

        if (a <= left && right <= b) {
            return tree[node];
        }

        int mid = left + (right - left) / 2;
        int leftQuery = query(2 * node, left, mid, a, b);
        int rightQuery = query(2 * node + 1, mid, right, a, b);
        return leftQuery + rightQuery;
    }

public:
    SegmentTreeWithLazyPropagation() = default;

    // Initializes the segment tree from an array of values. An empty array
    // yields an empty tree whose queries return 0.
    SegmentTreeWithLazyPropagation(const vector<int>& values)
        : n(static_cast<int>(values.size())) {
        // build() assumes a non-empty range; on [0, 0) it would never reach a
        // leaf and would recurse without bound.
        if (n == 0) {
            return;
        }
        tree.resize(4 * n);
        lazy.resize(4 * n);
        build(1, values, 0, n);
    }

    // Adds `value` to all positions from a to b (both inclusive). Positions
    // outside [0, n - 1] are ignored.
    void update(int a, int b, int value) {
        // An empty tree has no node 1; touching lazy[1] would be out of bounds.
        if (n == 0) {
            return;
        }
        // Clip b before adding 1 so that b == INT_MAX cannot overflow.
        int end = (b >= n) ? n : b + 1;
        update(1, 0, n, a, end, value);
    }

    // Returns the sum of the elements from a to b (both inclusive). Bounds
    // outside [0, n - 1] are clipped; an empty range yields 0.
    int query(int a, int b) {
        if (n == 0) {
            return 0;
        }
        int end = (b >= n) ? n : b + 1;
        return query(1, 0, n, a, end);
    }
};

Practice problems


by

Tags: