Tác giả:
Reviewer:
MITM là một kỹ thuật tìm kiếm được sử dụng khi đầu vào nhỏ nhưng không đủ nhỏ để có thể quay lui (backtracking). Trước khi tiếp tục về kỹ thuật MITM, chúng ta cần xem xét bài toán đơn giản sau:
Cho mảng có phần tử. Hỏi có bao nhiêu cách chọn tập con sao cho tổng bằng .
Giới hạn:
Ta duyệt qua tất cả các tập con có thể có rồi cập nhật kết quả bằng đệ quy (một cách khác để duyệt qua các tập con là sử dụng bitmask <insert bài bitmask>).
long long cnt;
// quay lui đến phần tử thứ i
// trong i-1 phần tử đầu, tổng các t[i] trong tập là sum
void Try(int i, int sum) {
// tiếp tục quay lui với tập có sum > x là không cần thiết
if (sum > x) return;
if (i > n) {
if (sum == x) ++cnt;
}
else {
// không lấy phần tử thứ i
Try(i + 1, sum);
// lấy phần tử thứ i
Try(i + 1, sum + t[i]);
}
}
long long solve() {
cnt = 0;
Try(1, 0);
return cnt;
}
Thuật toán trên có độ phức tạp thời gian là , không đủ nhanh để giải bài toán bởi vì khá lớn. Do đó, ta cần tìm một phương án tối ưu hơn.
Kỹ thuật MITM được mô tả như sau:
#include <bits/stdc++.h>
using namespace std;
const int N = 40 + 2;
int n, x;
int t[N];
vector<int> A, B;
void TryX(int i, int sum) {
if (sum > x) return;
if (i > n / 2) A.push_back(sum);
else {
TryX(i + 1, sum);
TryX(i + 1, sum + t[i]);
}
}
void TryY(int i, int sum) {
if (sum > x) return;
if (i > n) B.push_back(sum);
else {
TryY(i + 1, sum);
TryY(i + 1, sum + t[i]);
}
}
int main() {
cin >> n >> x;
for (int i = 1; i <= n; ++i) cin >> t[i];
// Quay lui 2 tập X và Y
TryX(1, 0);
TryY(n / 2 + 1, 0);
// Sắp xếp mảng B
sort(B.begin(), B.end());
// Lặp qua mảng A và tìm kiếm nhị phân:
// - Đếm số lượng phần tử trong B có giá trị bằng x - A[i]
long long cnt = 0;
for (int sum : A) {
cnt += upper_bound(B.begin(), B.end(), x - sum)
- lower_bound(B.begin(), B.end(), x - sum);
}
cout << cnt << '\n';
}
// Quay lui 2 tập X và Y
TryX(1, 0);
TryY(n / 2 + 1, 0);
// Sắp xếp mảng A và B
sort(A.begin(), A.end(), greater<int>());
sort(B.begin(), B.end());
// Sử dụng kỹ thuật 2 con trỏ
long long cnt = 0;
for (int i = 0, j1 = 0, j2 = 0; i < A.size(); ++i) {
int s = x - A[i]; // cần đếm lượng B[j] thoả B[j] = s
while (j1 < B.size() && B[j1] < s) ++j1;
while (j2 < B.size() && B[j2] <= s) ++j2;
cnt += j2 - j1;
}
cout << cnt << '\n';
Có cục vàng, mỗi cục vàng có trọng lượng và giá trị . Bạn có một cái túi có tải trọng tối đa là . Hỏi tổng giá trị vàng lớn nhất có thể thu được mà không làm rách túi.
Giới hạn:
Áp dụng MITM, ta tách cục vàng thành tập và , tập chứa cục vàng đầu tiên và tập chứa phần còn lại.
Bây giờ, quay lui cho với mỗi tập và , ta được tập và chứa các cặp (tổng trọng lượng , tổng giá trị ) của các tập con.
Để kết hợp tập và , ta cần giải quyết bài toán con: Với mỗi cặp của tập , ta cần tìm một cặp trong tập sao cho và là lớn nhất.
Để giải bài toán con này, gợi ý là sắp xếp lại mảng theo thứ tự tăng dần của và đặt (phần này có thể tính nhanh bằng mảng cộng dồn).
#include <bits/stdc++.h>
using namespace std;
const int N = 40 + 2, MaxSize = (1 << 20) + 10;
int n, m;
int w[N], v[N];
long long sumVA[MaxSize];
int sumWA[MaxSize];
int sizeA;
pair<int, long long> B[MaxSize];
int sizeB;
int sumWB[MaxSize];
long long maxSumVB[MaxSize];
void TryX(int i, int sumW, long long sumV) {
if (sumW > m) return;
if (i > n / 2) {
++sizeA;
sumWA[sizeA] = sumW;
sumVA[sizeA] = sumV;
return;
}
TryX(i + 1, sumW, sumV);
TryX(i + 1, sumW + w[i], sumV + v[i]);
}
void TryY(int i, int sumW, long long sumV) {
if (sumW > m) return;
if (i > n) {
++sizeB;
B[sizeB].first = sumW;
B[sizeB].second = sumV;
return;
}
TryY(i + 1, sumW, sumV);
TryY(i + 1, sumW + w[i], sumV + v[i]);
}
int main() {
cin >> n >> m;
for (int i = 1; i <= n; ++i) cin >> w[i] >> v[i];
TryX(1, 0, 0);
TryY(n / 2 + 1, 0, 0);
sort(B + 1, B + sizeB + 1);
for (int i = 1; i <= sizeB; ++i) {
sumWB[i] = B[i].first;
maxSumVB[i] = max(maxSumVB[i - 1], B[i].second);
}
long long maxValue = 0;
for (int i = 1; i <= sizeA; ++i) {
int j = upper_bound(sumWB + 1, sumWB + sizeB + 1, m - sumWA[i]) - sumWB - 1;
maxValue = max(maxValue, sumVA[i] + maxSumVB[j]);
}
cout << maxValue;
}
Cho mảng gồm số nguyên, đếm số lượng dãy con tăng có độ dài .
Giới hạn:
Đặt ứng với một dãy con tăng có độ dài .
Theo cách làm ngây thơ, với mỗi , ta đếm số cặp thoả mãn trong , tổng độ phức tạp thời gian sẽ là .
Ta có thể ứng dụng "middle" như sau: thay vì xét đầu tiên, ta xét đầu tiên.
Với mỗi , ta đếm số lượng thoả và thoả trong , tổng độ phức tạp thời gian lúc này sẽ là .
for (int j = 0; j < n; ++j) {
int smaller = 0, bigger = 0;
for (int i = 0; i < j; ++i) {
if (a[i] < a[j]) ++smaller;
}
for (int k = j + 1; k < n; ++k) {
if (a[k] > a[j]) ++bigger;
}
answer += smaller * bigger;
}
Cho mảng gồm số nguyên và số nguyên . Ta cần tìm vị trí phân biệt sao cho tổng giá trị ở vị trí đó bằng .
Giới hạn:
Đặt là vị trí thoả mãn .
Thuật toán ngây thơ của bài toán này là sử dụng vòng lặp lồng nhau với độ phức tạp .
for (int i = 1; i <= n; ++i)
for (int j = i + 1; j <= n; ++j)
for (int k = j + 1; k <= n; ++k)
for (int l = k + 1; l <= n; ++l)
if (a[i] + a[j] + a[k] + a[l] == x) { ... }
Ta có nhận xét: trong vòng lặp thứ (biến ), ta đang giải bài toán: tìm vị trí phân biệt lớn hơn sao cho tổng giá trị của vị trí đó bằng .
Ta có thể giải bài toán này trước bằng cách:
Sử dụng std::map
để lưu cặp vị trí của mỗi giá trị tổng.
#include <bits/stdc++.h>
using namespace std;
const int N = 1000 + 3;
int n, x;
int a[N];
int main() {
cin >> n >> x;
for (int i = 1; i <= n; ++i) cin >> a[i];
// preprocess
map<int, pair<int, int>> mp;
for (int i = 1; i <= n; ++i)
for (int j = i + 1; j <= n; ++j)
mp[a[i] + a[j]] = make_pair(i, j);
// solve
for (int i = 1; i <= n; ++i)
for (int j = i + 1; j <= n; ++j) {
// thay vì 2 vòng for, bây giờ ta chỉ cần
// truy vấn trên std::map
int X = x - a[i] - a[j];
if (mp.count(X)) {
pair<int, int> arr = mp[X];
if (j < arr.first) {
cout << i << ' ' << j << ' ' << arr.first << ' ' << arr.second;
return 0;
}
}
}
cout << "IMPOSSIBLE";
}
Độ phức tạp tiền xử lý:
Độ phức tạp truy vấn:
Có truy vấn, vì thế, tổng độ phức tạp thời gian là:
Cho đồ thị có hướng đỉnh () và bậc ngoài của mỗi đỉnh không quá . Tất cả đỉnh đều được tô màu. Tìm một đường đi độ dài sao cho đỉnh trong đường đi có màu phân biệt. Nếu có nhiều cách chọn, in ra bất kỳ, ngược lại, in ra "fail".
Giới hạn thời gian là rất lớn (12 giây).
Tương tự Bài toán 2, ta có thể ứng dụng "middle" như sau:
#include <bits/stdc++.h>
using namespace std;
const int N = 100 + 2;
int n;
int c[N];
vector<int> g[2][N];
int cntbit[16];
void init() {
for (int msk = 1; msk < 16; ++msk) {
cntbit[msk] = cntbit[msk >> 1] + (msk & 1);
}
}
void readData() {
cin >> n;
map<string, int> artist;
for (int i = 1; i <= n; ++i) {
string name;
cin >> name;
c[i] = artist.count(name) ? artist[name] : (artist[name] = artist.size() + 1);
int k, to;
cin >> k;
while (k--) {
cin >> to;
g[0][i].push_back(to);
g[1][to].push_back(i);
}
}
}
vector<int> getAns(vector<int> res) {
set<int> s;
for (int u : res) s.insert(c[u]);
for (int v0 : g[0][res.back()]) {
if (s.count(c[v0])) continue;
s.insert(c[v0]);
for (int v1 : g[0][v0]) {
if (s.count(c[v1])) continue;
s.insert(c[v1]);
for (int v2 : g[0][v1]) {
if (s.count(c[v2])) continue;
s.insert(c[v2]);
for (int v3 : g[0][v2]) {
if (s.count(c[v3])) continue;
res.push_back(v0);
res.push_back(v1);
res.push_back(v2);
res.push_back(v3);
return res;
}
s.erase(c[v2]);
}
s.erase(c[v1]);
}
s.erase(c[v0]);
}
return {};
}
int cnt[N * N * N * N];
int getHash(const array<int, 4> &a, int msk) {
int hsh = 0;
for (int i = 0; i < 4; ++i) {
if (msk >> i & 1) hsh = hsh * N + c[a[i]];
}
return hsh;
}
vector<int> solve(int u) {
vector<int> sav(1, 0);
for (int v0 : g[0][u]) {
if (c[v0] == c[u]) continue;
for (int v1 : g[0][v0]) {
if (c[v1] == c[v0] || c[v1] == c[u]) continue;
for (int v2 : g[0][v1]) {
if (c[v2] == c[v1] || c[v2] == c[v0] || c[v2] == c[u]) continue;
for (int v3 : g[0][v2]) {
if (c[v3] == c[v2] || c[v3] == c[v1] || c[v3] == c[v0] || c[v3] == c[u]) continue;
array<int, 4> a = { c[v0], c[v1], c[v2], c[v3] };
sort(a.begin(), a.end());
for (int msk = 0; msk < 16; ++msk) {
int hsh = getHash(a, msk);
++cnt[hsh];
sav.push_back(hsh);
}
}
}
}
}
for (int v0 : g[1][u]) {
if (c[v0] == c[u]) continue;
for (int v1 : g[1][v0]) {
if (c[v1] == c[v0] || c[v1] == c[u]) continue;
for (int v2 : g[1][v1]) {
if (c[v2] == c[v1] || c[v2] == c[v0] || c[v2] == c[u]) continue;
for (int v3 : g[1][v2]) {
if (c[v3] == c[v2] || c[v3] == c[v1] || c[v3] == c[v0] || c[v3] == c[u]) continue;
array<int, 4> a = { c[v0], c[v1], c[v2], c[v3] };
sort(a.begin(), a.end());
int sum = 0;
for (int msk = 0; msk < 16; ++msk) {
int hsh = getHash(a, msk);
sum += cnt[hsh] * (cntbit[msk] & 1 ? -1 : 1);
}
if (sum > 0) {
vector<int> res = { v3, v2, v1, v0, u };
return getAns(res);
}
}
}
}
}
for (int x : sav) cnt[x] = 0;
return vector<int>();
}
void solve() {
for (int i = 1; i <= n; ++i) {
vector<int> vec = solve(i);
if (!vec.empty()) {
for (int x : vec) cout << x << ' ';
return;
}
}
cout << "fail";
}
int main() {
init();
readData();
solve();
}