概念
带权并查集(Weighted Union-Find)是并查集的一种扩展,它在标准的并查集操作基础上引入了一个权重或者秩的概念,使得每个节点不仅代表一个集合,还记录了该集合的一些额外信息。通常,这个额外信息是与集合中的元素相关联的权重或者秩。
维护距离
很多带权并查集的题目都是通过维护距离这个额外的元素来求解。
现在设 $d[i]$ 数组为第 $i$ 个点到他父节点的距离,一开始,每个点自己就是自己的父节点,所以初始化为0。
![img](https://pic.imgdb.cn/item/65aa6ff1871b83018a1c0283.png)
从图中可以看到,边上面是带有值的,这也是为什么叫做带权并查集。
从这张图我们可以列出这4个点对应的 $d[i]$
- d[1] = 4
- d[2] = 0
- d[6] = 5
- d[5] = 6
而往往我们需要的是到根节点的距离,所以我们通过路径压缩,让这个点直接指向根节点,在运行的同时,把新的 $d[i]$ 给算出来。经过路径压缩后,$d[5] = d[5] + d[6] = 11$ ,所以说带权并查集是需要用到路径压缩的。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29
| int p[N], d[N];
int find(int x) { if (p[x] != x) { int u = find(p[x]); d[x] += d[p[x]]; p[x] = u; } return p[x]; }
void init(int n){ for (int i = 1; i <= n; i ++ ) { p[i] = i; d[i] = 0; } }
int px = find(x); int py = find(y); p[py] = find(y); d[py] = d[x] - d[y] + s;
|
s即为一个偏移量,而偏移量如何去求出来,需要根据题意。
在合并方面,如果是普通的并查集,我们可以选择任意一点直接合并到另一点。但是在带权并查集,我们还需要维护一个额外的信息,这里就是距离。
普通并查集:
![img](https://pic.imgdb.cn/item/65aa7007871b83018a1c6935.png)
带权并查集:
![img](https://pic.imgdb.cn/item/65aa7019871b83018a1cb7d1.png)
要是想让$x,y$这两个集合连通我们需要处理好$d[py]$ 这是一个未知量,也是我们需要的。
重点就是如何去计算,这个时候偏移量的出现就为之重要了。 通过图可以发现 $y$到$px$的距离应该是恒定的,所以有: $$s + d[x] = d[y] + d[py]$$
所以要想计算$d[py]$就很简单了。
$$d[py] = d[x] - d[y] + s$$
但是我们只是算出并改变了 $d[py]$ 为什么这样就OK了呢?可以查看我们的$find()$ 函数,只要调用了$find(i)$,那么 $i$ 节点的$d[i]$ 就一定是正确的,并且$i$还会指向根节点,因为我们使用了路径压缩。
例题:
食物链
食物链 这道题题目大概是给了三种关系,然后让你判断合法性,最后输出不合法的操作数量。
![img](https://pic.imgdb.cn/item/65aa7034871b83018a1d379e.png)
箭头代表 x->y = x会被y吃。
这题用带权并查集来做的时候,$d[i]$ 数组代表距离的同时,里面的值的不同也有特殊含义,所以我们通过$MOD3$来控制距离的大小。
- 0代表,与根节点同一族
- 1代表,该节点可以吃掉根节点一族
- 2代表,该节点被根节点一族吃
是同类的话,偏移量为0。是不同的话偏移量为1,正负取决于你的x和y谁合并到谁身上。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112
| #include <iostream> #include <cstdio> #include <cstring> #include <algorithm> #include <cmath> #include <queue> #include <deque> #include <stack> #include <unordered_map> #include <unordered_set> #include <numeric> #include <iomanip>
#define _fio \ ios_base::sync_with_stdio(0); \ cin.tie(0); \ cout.tie(0);
#define pb(e) push_back(e) #define all(x) (x).begin(), (x).end() #define allr(x) (x).rbegin(), (x).rend() #define endl '\n'
using namespace std; using i64 = long long; using PII = pair<int, int>; const int INF = 0x3f3f3f3f;
struct DSU { std::vector<int> p, siz, d; int res;
DSU() {} DSU(int n) { init(n); }
void init(int n) { p.resize(n); std::iota(p.begin(), p.end(), 0); siz.assign(n, 1); d.assign(n, 0); }
int find(int x) { if (p[x] != x) { int u = find(p[x]); d[x] += d[p[x]]; d[x] %= 3; p[x] = u; } return p[x]; }
bool same(int x, int y) { return find(x) == find(y); }
bool merge(int x, int y, int s) { int px = find(x); int py = find(y); if (px == py) return false; siz[px] += siz[py]; p[py] = px; d[py] = (d[x] - d[y] + s) % 3; return true; }
int size(int x) { return siz[find(x)]; } };
int main() { int n, m, d, x, y; cin >> n >> m; DSU dsu(n + 1);
while (m--) { cin >> d >> x >> y;
if (x > n || y > n) { dsu.res += 1; continue; }
int px = dsu.find(x); int py = dsu.find(y);
if (d == 1) { if (dsu.same(x, y) && (dsu.d[x] % 3 != dsu.d[y] % 3)) dsu.res++; else if (!dsu.same(x, y)) dsu.merge(x, y, 3); } else { if (dsu.same(x, y) && (dsu.d[x] % 3 != (dsu.d[y] + 1) % 3)) dsu.res++; else if (!dsu.same(x, y)) dsu.merge(x, y, -1 + 3); } }
cout << dsu.res << endl;
return 0; }
|
注意关系的判断和偏移量的计算就行了。
封装模板
改自并查集
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37
| struct DSUV { std::vector<int> p, siz, d; DSUV() {} DSUV(int n) { init(n); } void init(int n) { p.resize(n); std::iota(p.begin(), p.end(), 0); siz.assign(n, 1); d.assign(n, 0); } int find(int x) { if (p[x] != x) { int u = find(p[x]); d[x] += d[p[x]]; p[x] = u; } return p[x]; } bool same(int x, int y) { return find(x) == find(y); } bool merge(int x, int y, int s) { int px = find(x); int py = find(y); if (px == py) return false; siz[px] += siz[py]; p[py] = px; d[py] = (d[x] - d[y] + s); return true; } int size(int x) { return siz[find(x)]; } };
|