Range Modify Query (RMQ)
zkw Tree
Source: 統計的力量
cpp
// Point update (assign a new value at a single index), range-sum query
// zkw (bottom-up, non-recursive) segment tree supporting point assignment
// and inclusive range-sum queries over n slots indexed 0 .. n-1.
//
// Layout: `data` is an implicit binary tree stored 1-indexed; node i has
// children 2i and 2i+1, and the leaves occupy [base, 2 * base). Slot x is
// stored in leaf base + x + 1, so leaves base + 0 and base + n + 1 are never
// written and can serve as the exclusive sentinels used by query().
struct ZKWTree {
  std::vector<int64_t> data;
  size_t base;

  // Builds an all-zero tree with room for n slots.
  ZKWTree(size_t n) {
    // base = smallest power of two strictly greater than n + 5. This is the
    // same value as 1 << (floor(log2(n + 5)) + 1), but it is computed in
    // size_t (no int overflow for large n) and without the non-standard
    // __lg helper.
    base = 1;
    while (base <= n + 5) {
      base <<= 1;
    }
    data.assign(base << 1, int64_t{0});
  }

  // Assigns value v to slot x (this overwrites, it does not add).
  // x in [0, n)
  void update(size_t x, int64_t v) {
    size_t node = base + x + 1;  // 0-based slot -> leaf index
    data[node] = v;
    // Recompute the sum of every ancestor on the path up to the root (node 1).
    for (node >>= 1; node != 0; node >>= 1) {
      data[node] = data[node << 1] + data[(node << 1) | 1];
    }
  }

  // Returns the sum of slots l .. r, both ends inclusive.
  // l, r in [0, n), l <= r
  int64_t query(size_t l, size_t r) const {
    // Exclusive sentinels: the leaf just left of slot l and the leaf just
    // right of slot r.
    size_t lo = l + base;
    size_t hi = r + base + 2;
    int64_t ans = 0;
    // Climb until lo and hi are siblings (they differ only in the lowest bit).
    while ((lo ^ hi) != 1) {
      if ((lo & 1) == 0) {
        // lo is a left child: its right sibling lies fully inside the range.
        ans += data[lo ^ 1];
      }
      if ((hi & 1) == 1) {
        // hi is a right child: its left sibling lies fully inside the range.
        ans += data[hi ^ 1];
      }
      lo >>= 1;
      hi >>= 1;
    }
    return ans;
  }
};
Dynamic Segment Tree
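- Range modify (add a value to every position of a range) and range query (maximum over a range)
- Nodes are created dynamically, only when an update first touches their range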
cpp
// Dynamic segment tree over the index range [0, R) supporting
//   - add(l, r, v): add v to every position in [l, r)
//   - get(l, r):    maximum value over the positions in [l, r)
// Every position starts at 0. Nodes are allocated lazily, the first time an
// update touches their range; a missing subtree means "all zeros there".
//
// Each node stores `delta`, the amount added to its whole range, instead of
// pushing it down. The value of a position is therefore the sum of `delta`
// over all nodes on the path from the root to that position.
class SegmentTree {
 private:
  struct Node {
    Node(int l, int r) : l(l), r(r) {
    }

    // Split point: the children cover [l, mid()) and [mid(), r).
    // Written as l + (r - l) / 2 so that l + r cannot overflow int.
    int mid() const {
      return l + ((r - l) >> 1);
    }

    // Allocates the left child [l, mid()) if it does not exist yet.
    void makeLeft() {
      if (!left) {
        left = new Node(l, mid());
      }
    }

    // Allocates the right child [mid(), r) if it does not exist yet.
    void makeRight() {
      if (!right) {
        right = new Node(mid(), r);
      }
    }

    // Recomputes maxSumDelta from the children and this node's own delta.
    void update() {
      int best = 0;
      if (left && right) {
        // Both halves are materialised, so the maximum comes from them alone.
        // Seeding with 0 here would be wrong when both are negative.
        best = std::max(left->maxSumDelta, right->maxSumDelta);
      } else if (left) {
        // The missing right half is implicitly all zeros.
        best = std::max(0, left->maxSumDelta);
      } else if (right) {
        // The missing left half is implicitly all zeros.
        best = std::max(0, right->maxSumDelta);
      }
      maxSumDelta = best + delta;
    }

    Node* left = nullptr;
    Node* right = nullptr;
    // [l, r)
    int l;
    int r;
    // amount added to the whole range of this node
    int delta = 0;
    // max over this node's range of the deltas summed from this node downwards
    int maxSumDelta = 0;
  };

 public:
  // create [0, R)
  SegmentTree(int R) : root(new Node(0, R)) {
  }
  // The tree owns its nodes through raw pointers, so copying is disabled.
  SegmentTree(const SegmentTree&) = delete;
  SegmentTree& operator=(const SegmentTree&) = delete;
  ~SegmentTree() {
    if (root) {
      nodeDelete(root);
    }
  }

  // add v to [l, r)
  // The range is clamped to [0, R); an empty range is a no-op.
  void add(int l, int r, int v) {
    if (!clampToRoot(l, r)) {
      return;
    }
    nodeAdd(l, r, v, root);
  }

  // return max in [l, r)
  // The range is clamped to [0, R); an empty range yields 0.
  int get(int l, int r) const {
    if (!clampToRoot(l, r)) {
      return 0;
    }
    return nodeGet(l, r, root);
  }

 private:
  Node* root;

  // Clamps [l, r) to the root's range; returns false if nothing remains.
  // Without this, an out-of-range or empty request would make the recursion
  // create degenerate nodes or never terminate.
  bool clampToRoot(int& l, int& r) const {
    l = std::max(l, root->l);
    r = std::min(r, root->r);
    return l < r;
  }

  // Frees node and its whole subtree.
  void nodeDelete(Node* node) {
    if (node->right) {
      nodeDelete(node->right);
    }
    if (node->left) {
      nodeDelete(node->left);
    }
    delete node;
  }

  // Adds v to [l, r) inside node's subtree.
  // Precondition: node->l <= l < r <= node->r
  void nodeAdd(int l, int r, int v, Node* node) {
    if (node->l == l && node->r == r) {
      // The request covers the node exactly: record it here, do not descend.
      node->delta += v;
    } else {
      const int m = node->mid();
      if (l < m) {
        node->makeLeft();
        nodeAdd(l, std::min(r, m), v, node->left);
      }
      if (m < r) {
        node->makeRight();
        nodeAdd(std::max(l, m), r, v, node->right);
      }
    }
    node->update();
  }

  // Returns the max over [l, r) inside node's subtree, counting only the
  // deltas stored at node and below; callers add the ancestors' deltas.
  // Precondition when node is non-null: node->l <= l < r <= node->r
  int nodeGet(int l, int r, const Node* node) const {
    if (!node) {
      // Untouched range: every position is still 0 relative to the ancestors.
      return 0;
    }
    if (node->l == l && node->r == r) {
      return node->maxSumDelta;
    }
    const int m = node->mid();
    if (r <= m) {
      return node->delta + nodeGet(l, r, node->left);
    }
    if (m <= l) {
      return node->delta + nodeGet(l, r, node->right);
    }
    // l < m && m < r: the query straddles both children.
    return node->delta +
           std::max(nodeGet(l, m, node->left), nodeGet(m, r, node->right));
  }
};