Tác giả:
Reviewer:
Note: Để hiểu được bài viết này, các bạn cần có kiến thức về:
Xét bài toán sau đây:
Cho một dãy gồm phần tử, và ta cần làm loại thao tác sau đây:
Lưu ý rằng chúng ta cần làm bài toán này online - cần đưa ra đáp án ngay sau khi được hỏi.
Nếu bỏ cụm từ "sau thao tác thứ " thì đây là một bài ứng dụng segment tree cơ bản. Tuy nhiên bài này yêu cầu chúng ta lưu lại lịch sử những lần update.
Chúng ta có thể sử dụng Persistent Data Structures để giải bài toán này. Lưu ý rằng đây không phải là một cấu trúc dữ liệu mới, mà là một kĩ thuật giúp chúng ta lưu lại thông tin về lịch sử của cấu trúc dữ liệu.
Ví dụ như dưới đây chúng ta vẫn dùng Segment Tree nhưng có thêm một số thay đổi để có thể lưu lại những trạng thái cũ của dãy trước khi thay đổi, và thêm vào trạng thái mới nhất của dãy.
Ta xử lý bài toán trên như sau:
Ta thấy rằng chúng ta chỉ thay đổi giá trị của một phần tử trong Segment Tree lần, và do đó độ phức tạp của thuật toán này là
const long long infty = 1e18 + 7;
const int MAXN = 3e5 + 5;
int n;
int a[MAXN];
// val: giá trị của node i
// le, ri: các node con của node i
// con số 20 ở đây đại diện cho log2(N), và những mảng có kích thước N * 20 là những mảng cần N * log2(N) memory
long long val[MAXN * 20];
int le[MAXN * 20], ri[MAXN * 20];
// rood_idx[i]: node gốc được tạo ra ở query thứ i
int root_idx[MAXN];
// cur[id]: node chứa version mới nhất của node id trong segment tree ban đầu
int cur[MAXN * 20];
// tol_node lưu tổng số node hiện tại
int tolnode;
long long init(int id, int l, int r){
tolnode++;
// version ban đầu của node id là node id :)
cur[id] = id;
if(l == r){
val[id] = a[l];
return a[l];
}
le[id] = (id << 1), ri[id] = (id << 1 | 1);
int mid = (l + r) >> 1;
// tính tổng đoạn (l, r)
long long sum = init(id << 1, l, mid) + init(id << 1 | 1, mid + 1, r);
val[id] = sum;
return sum;
}
void upd(int id, int l, int r, int pos, long long v){
if(l == r){
tolnode++;
val[tolnode] = v;
return;
}
tolnode++;
int nw_id = tolnode;
// trạng thái mới nhất của node id trở thành nw_id
cur[id] = nw_id;
int mid = (l + r) >> 1;
// Trong TH này, ta sẽ tạo ra một node mới là con bên trái của node hiện tại, và node con bên phải sẽ là cur[id << 1 | 1]
if(pos <= mid){
upd(id << 1, l, mid, pos, v);
// lưu ý rằng index của node con sẽ được tạo ra ngay sau node này nên index của nó là nw_id + 1
le[nw_id] = nw_id + 1;
ri[nw_id] = cur[id << 1 | 1];
val[nw_id] = val[le[nw_id]] + val[ri[nw_id]];
}
// Trong TH này, ta sẽ tạo ra một node mới là con bên phải của node hiện tại, và node con bên trái sẽ là cur[id << 1]
else{
upd(id << 1 | 1, mid + 1, r, pos, v);
le[nw_id] = cur[id << 1];
ri[nw_id] = nw_id + 1;
val[nw_id] = val[le[nw_id]] + val[ri[nw_id]];
}
}
// gọi k là thời điểm hỏi
// id ở đây, thay vì là id của segment tree ban đầu, thì là cur[id] ở thời điểm k
long long get(int id, int l, int r, int L, int R){
if(R < l || r < L) return 0;
if(l >= L && r <= R) return val[id];
int mid = (l + r) >> 1;
return get(le[id], l, mid, L, R) + get(ri[id], mid + 1, r, L, R);
}
// lưu thời điểm của lần thay đổi hiện tại
int cnt_que;
void update(int p, long long v){
cnt_que++;
root_idx[cnt_que] = tolnode + 1;
upd(1, 1, n, p, v);
}
// ta bắt đầu tại root_idx[k]
long long ans(int l, int r, int k){
return get(root_idx[k], 1, n, l, r);
}
Xét phiên bản 2D của bài toán trên
Cho một bảng gồm phần tử, và ta cần làm loại thao tác sau đây:
Việc áp dụng phương pháp tạo node mới như trên với BIT là khá khó khăn, và thay vào đó chúng ta có thể làm như dưới đây:
Với mỗi node trong BIT ta sẽ lưu lại lịch sử những lần thay đổi của node này.
// y1 là tên một biến trong C++ nên mình sử dụng tạm cách này
// tuy nhiên không khuyến khích mọi người làm theo
#define y1 y11
const long long infty = 1e18 + 7;
const int MAXN = 3e3 + 5;
// kích thước bảng
int N, M;
int a[MAXN][MAXN];
vector<pair<int, long long>> updates[MAXN][MAXN];
// tính giá trị ban đầu
long long pref[MAXN][MAXN];
int range[MAXN];
void init(){
// xác định khoảng mà vị trí i quản lý trong BIT
for(int i = 1; i <= max(N, M); i++){
range[i] = (i & (i - 1)) + 1;
}
// ta dùng prefix sum để có được ĐPT O(n * m)
for(int i = 1; i <= N; i++){
for(int j = 1; j <= M; j++) pref[i][j] = pref[i - 1][j] + pref[i][j - 1] - pref[i - 1][j - 1];
}
for(int i = 1; i <= N; i++){
for(int j = 1; j <= M; j++){
int x1 = range[i], y1 = range[j], x2 = i, y2 = j;
long long startval = pref[x2][y2] - pref[x1 - 1][y2] - pref[x2][y1 - 1] + pref[x1 - 1][y1 - 1];
updates[i][j].push_back({0, startval});
}
}
}
void upd(int x, int y, int k, long long val){
for(int i = x; i <= N; i += i & -i){
for(int j = y; j <= M; j += j & -j){
// giá trị gần nhất của phần tử (i, j)
long long lst = updates[i][j].back().second;
lst += val;
updates[i][j].push_back(make_pair(k, val));
}
}
}
// hàm get() sẽ tính giá trị của hcn con (1, 1), (x, y)
long long get(int x, int y, int k){
long long ans = 0;
for(int i = x; i; i -= i & -i){
for(int j = y; j; j -= j & -j){
// ta lower_bound cặp (k, oo) để chắc chắn ra được vị trí nhỏ nhất có thời gian > k, và sau đó ta -1 để ra vị trí cần tìm.
int pos = lower_bound(updates[i][j].begin(), updates[i][j].end(), make_pair(k, infty)) - updates[i][j].begin() - 1;
ans += updates[i][j][pos].second;
}
}
return ans;
}
long long ans(int x1, int y1, int x2, int y2, int k){
return get(x2, y2, k) - get(x1 - 1, y2, k) - get(x2, y1 - 1, k) + get(x1 - 1, y1 - 1, k);
}
Độ phức tạp của thuật toán này là:
Về mặt thời gian:
Về mặt không gian:
Cách làm này tuy độ phức tạp cao hơn nhưng lại tổng quát hơn khi có thể dùng cho BIT, IT và nhiều cấu trúc dữ liệu khác.