Skip to content

Range Modify Query (RMQ)

zkw Tree

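Source: 統計的力量 ("The Power of Statistics"), Zhang Kunwei's lecture on segment trees; the "zkw" tree is named after his initials.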

cpp
// Point-assign / range-sum tree in the iterative bottom-up ("zkw") style.
// Positions are 0-based in [0, n). Position i is stored in leaf data[base + i + 1]; leaf
// data[base] and every leaf past position n - 1 are sentinels that always hold 0, so the
// open-interval walk in query() may start one slot left of l and one slot right of r
// without leaving the array.
struct ZKWTree {
    vector<int64_t> data;  // data[1] is the root; children of k are 2k and 2k + 1
    size_t base;           // index of the first leaf slot (a power of two)

    // Builds an all-zero tree able to hold n positions.
    // base is the smallest power of two strictly greater than n + 5, i.e.
    // 2^(floor(log2(n + 5)) + 1). It is computed with a size_t loop rather than
    // `1 << __lg(...)`, which is GCC-specific and shifts a plain int (undefined once
    // n + 5 reaches 2^30).
    ZKWTree(size_t n) {
        base = 1;
        while (base <= n + 5) {
            base <<= 1;
        }
        data.assign(base << 1, int64_t{0});
    }

    // Sets position x (x in [0, n)) to v and recomputes the sums of all its ancestors.
    void update(size_t x, int64_t v) {
        size_t node = base + x + 1;  // +1 skips the left sentinel leaf data[base]
        data[node] = v;
        for (node >>= 1; node > 0; node >>= 1) {
            data[node] = data[node << 1] + data[(node << 1) | 1];
        }
    }

    // Returns the sum over the closed range [l, r]; requires l <= r and l, r in [0, n).
    // Walks the open interval (s, t), where s is the leaf just left of l and t is the leaf
    // just right of r. Until s and t become siblings, a left-child s adds its right sibling
    // and a right-child t adds its left sibling; then both move up one level.
    int64_t query(size_t l, size_t r) const {
        int64_t ans = 0;
        for (size_t s = base + l, t = base + r + 2; (s ^ t) != 1; s >>= 1, t >>= 1) {
            if (~s & 1) {
                ans += data[s ^ 1];
            }
            if (t & 1) {
                ans += data[t ^ 1];
            }
        }
        return ans;
    }
};

Dynamic Segment Tree

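  • Range modification (add a value to every position in a range) and range query (maximum over a range)
  • Nodes are created dynamically, only where updates touch the tree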
cpp
// Dynamic (lazily allocated) segment tree over the integer positions [0, R).
// Every position starts at 0. Supported operations:
//   add(l, r, v): add v to every position in [l, r)
//   get(l, r):    maximum value over [l, r)
// Nodes are allocated only along the paths that add() touches.
//
// Representation: each node stores `delta`, an amount added to its whole range, and
// `maxSumDelta`, the maximum over its range of the sum of deltas from that node downward.
// The value of a position is the sum of `delta` along its root-to-deepest-node path.
// A missing child is an untouched half whose positions all contribute 0 below its parent.
class SegmentTree {
private:
    struct Node {
        Node(int l, int r): l(l), r(r) {
        }

        // Split point of [l, r). Written as l + (r - l) / 2 so that l + r cannot overflow int.
        int mid() const {
            return l + ((r - l) >> 1);
        }

        // Allocates the child for [l, mid) if it does not exist yet.
        void makeLeft() {
            if (!left) {
                left = new Node(l, mid());
            }
        }

        // Allocates the child for [mid, r) if it does not exist yet.
        void makeRight() {
            if (!right) {
                right = new Node(mid(), r);
            }
        }

        // Recomputes maxSumDelta from the children and this node's own delta.
        // 0 is a candidate only when a child is missing: that half is untouched and therefore
        // all zeros. When both children exist, the maximum must come from them, even if it is
        // negative.
        void update() {
            int best;
            if (left && right) {
                best = max(left->maxSumDelta, right->maxSumDelta);
            } else if (left) {
                best = max(left->maxSumDelta, 0);
            } else if (right) {
                best = max(right->maxSumDelta, 0);
            } else {
                best = 0;
            }
            maxSumDelta = best + delta;
        }

        Node* left = nullptr;   // owned; covers [l, mid)
        Node* right = nullptr;  // owned; covers [mid, r)

        // [l, r)
        int l;
        int r;

        // delta of the whole range
        int delta = 0;

        // max over [l, r) of the sum of deltas from this node downward
        int maxSumDelta = 0;
    };

public:
    // create [0, R)
    SegmentTree(int R): root(new Node(0, R)) {
    }

    // Nodes are raw owning pointers, so copying would double-free; copying is disabled.
    SegmentTree(const SegmentTree&) = delete;
    SegmentTree& operator=(const SegmentTree&) = delete;

    ~SegmentTree() {
        if (!root) {
            return;
        }
        nodeDelete(root);
    }

    // add v to [l, r)
    // The interval is clipped to [0, R). An empty or fully out-of-range interval is a no-op;
    // passing one unclipped would make nodeAdd recurse without end.
    void add(int l, int r, int v) {
        l = max(l, root->l);
        r = min(r, root->r);
        if (l >= r) {
            return;
        }
        nodeAdd(l, r, v, root);
    }

    // return max in [l, r)
    // Requires 0 <= l < r <= R.
    int get(int l, int r) const {
        return nodeGet(l, r, root);
    }

private:
    Node* root;

    // Frees node and its whole subtree (post-order).
    void nodeDelete(Node* node) {
        if (node->right) {
            nodeDelete(node->right);
        }
        if (node->left) {
            nodeDelete(node->left);
        }
        delete node;
    }

    // Adds v to [l, r) within node's subtree, then refreshes node->maxSumDelta.
    // Requires node->l <= l < r <= node->r (add() guarantees this for the root).
    void nodeAdd(int l, int r, int v, Node* node) {
        if (node->l == l && node->r == r) {
            // Whole node covered: record the add on this node only, without descending.
            node->delta += v;
            node->update();
            return;
        }

        int m = node->mid();
        if (r <= m) {
            node->makeLeft();
            nodeAdd(l, r, v, node->left);
        } else if (m <= l) {
            node->makeRight();
            nodeAdd(l, r, v, node->right);
        } else {
            // l < m < r: the interval straddles the split point.
            node->makeLeft();
            node->makeRight();
            nodeAdd(l, m, v, node->left);
            nodeAdd(m, r, v, node->right);
        }
        node->update();
    }

    // Returns the max over [l, r) of the sum of deltas from node downward.
    // Requires node->l <= l < r <= node->r when node is non-null. A null node is an untouched
    // subtree, whose positions are all 0 relative to its parent.
    int nodeGet(int l, int r, Node* node) const {
        if (!node) {
            return 0;
        }

        if (node->l == l && node->r == r) {
            return node->maxSumDelta;
        }

        int m = node->mid();
        if (r <= m) {
            return node->delta + nodeGet(l, r, node->left);
        }
        if (m <= l) {
            return node->delta + nodeGet(l, r, node->right);
        }
        return node->delta + max(nodeGet(l, m, node->left), nodeGet(m, r, node->right));
    }
};

Changelog

Just observe 👀