带权并查集

概念

带权并查集(Weighted Union-Find)是并查集的一种扩展,它在标准的并查集操作基础上引入了一个权重或者秩的概念,使得每个节点不仅代表一个集合,还记录了该集合的一些额外信息。通常,这个额外信息是与集合中的元素相关联的权重或者秩。

维护距离

很多带权并查集的题目都是通过维护距离这个额外的元素来求解。
现在设 $d[i]$ 数组为第 $i$ 个点到他父节点的距离,一开始,每个点自己就是自己的父节点,所以初始化为0。
img
从图中可以看到,边上面是带有值的,这也是为什么叫做带权并查集。
从这张图我们可以列出这4个点对应的 $d[i]$

  • d[1] = 4
  • d[2] = 0
  • d[6] = 5
  • d[5] = 6
    而往往我们需要的是到根节点的距离,所以我们通过路径压缩,让这个点直接指向根节点,在运行的同时,把新的 $d[i]$ 给算出来。经过路径压缩后,$d[5] = d[5] + d[6] = 11$ ,所以说带权并查集是需要用到路径压缩的。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
int p[N], d[N]; //p[]存储每个点的祖宗节点, d[x]存储x到p[x]的距离

// 返回x的祖宗节点
int find(int x)
{
if (p[x] != x)
{
int u = find(p[x]);
d[x] += d[p[x]];
p[x] = u;
}
return p[x];
}

void init(int n){
// 初始化,假定节点编号是1~n
for (int i = 1; i <= n; i ++ )
{
p[i] = i;
d[i] = 0;
}
}


int px = find(x);
int py = find(y);
// 合并a和b所在的两个集合:
p[py] = find(y);
d[py] = d[x] - d[y] + s;

s即为一个偏移量,而偏移量如何去求出来,需要根据题意。

在合并方面,如果是普通的并查集,我们可以选择任意一点直接合并到另一点。但是在带权并查集,我们还需要维护一个额外的信息,这里就是距离。

普通并查集:
img

带权并查集:
img

要是想让$x,y$这两个集合连通我们需要处理好$d[py]$ 这是一个未知量,也是我们需要的。

重点就是如何去计算,这个时候偏移量的出现就为之重要了。 通过图可以发现 $y$到$px$的距离应该是恒定的,所以有: $$s + d[x] = d[y] + d[py]$$
所以要想计算$d[py]$就很简单了。
$$d[py] = d[x] - d[y] + s$$
但是我们只是算出并改变了 $d[py]$ 为什么这样就OK了呢?可以查看我们的$find()$ 函数,只要调用了$find(i)$,那么 $i$ 节点的$d[i]$ 就一定是正确的,并且$i$还会指向根节点,因为我们使用了路径压缩。

例题:

食物链

食物链 这道题题目大概是给了三种关系,然后让你判断合法性,最后输出不合法的操作数量。
img
箭头代表 x->y = x会被y吃。

这题用带权并查集来做的时候,$d[i]$ 数组代表距离的同时,里面的值的不同也有特殊含义,所以我们通过$MOD3$来控制距离的大小。

  • 0代表,与根节点同一族
  • 1代表,该节点可以吃掉根节点一族
  • 2代表,该节点被根节点一族吃

是同类的话,偏移量为0。是不同的话偏移量为1,正负取决于你的x和y谁合并到谁身上。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <queue>
#include <deque>
#include <stack>
#include <unordered_map>
#include <unordered_set>
#include <numeric>
#include <iomanip>

#define _fio \
ios_base::sync_with_stdio(0); \
cin.tie(0); \
cout.tie(0);

#define pb(e) push_back(e)
#define all(x) (x).begin(), (x).end()
#define allr(x) (x).rbegin(), (x).rend()
#define endl '\n'

using namespace std;
using i64 = long long;
using PII = pair<int, int>;
const int INF = 0x3f3f3f3f;

struct DSU
{
std::vector<int> p, siz, d;
int res;

DSU() {}
DSU(int n) { init(n); }

void init(int n)
{
p.resize(n);
std::iota(p.begin(), p.end(), 0);
siz.assign(n, 1);
d.assign(n, 0);
}

int find(int x)
{
if (p[x] != x)
{
int u = find(p[x]);
d[x] += d[p[x]]; // 按情况修改
d[x] %= 3;
p[x] = u;
}
return p[x];
}

bool same(int x, int y) { return find(x) == find(y); }

bool merge(int x, int y, int s)
{
int px = find(x);
int py = find(y);
if (px == py) return false;
siz[px] += siz[py];
p[py] = px;
d[py] = (d[x] - d[y] + s) % 3; // s为偏移量,按情况修改
return true;
}

int size(int x) { return siz[find(x)]; }
};

int main()
{
int n, m, d, x, y;
cin >> n >> m;
DSU dsu(n + 1);

while (m--)
{
cin >> d >> x >> y;

if (x > n || y > n)
{
dsu.res += 1;
continue;
}

int px = dsu.find(x);
int py = dsu.find(y);

if (d == 1)
{
if (dsu.same(x, y) && (dsu.d[x] % 3 != dsu.d[y] % 3))
dsu.res++;
else if (!dsu.same(x, y))
dsu.merge(x, y, 3);
}
else
{
if (dsu.same(x, y) && (dsu.d[x] % 3 != (dsu.d[y] + 1) % 3))
dsu.res++;
else if (!dsu.same(x, y))
dsu.merge(x, y, -1 + 3);
}
}

cout << dsu.res << endl;

return 0;
}

注意关系的判断和偏移量的计算就行了。

封装模板

改自并查集

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
struct DSUV
{
std::vector<int> p, siz, d;
DSUV() {}
DSUV(int n) { init(n); }
void init(int n)
{
p.resize(n);
std::iota(p.begin(), p.end(), 0);
siz.assign(n, 1);
d.assign(n, 0);
}
int find(int x)
{
if (p[x] != x)
{
int u = find(p[x]);
d[x] += d[p[x]];//情况而定
p[x] = u;
}
return p[x];
}
bool same(int x, int y) { return find(x) == find(y); }
bool merge(int x, int y, int s)
{
int px = find(x);
int py = find(y);
if (px == py)
return false;
siz[px] += siz[py];
p[py] = px;
d[py] = (d[x] - d[y] + s);//情况而定
return true;
}
int size(int x) { return siz[find(x)]; }
};


带权并查集
http://pikachuxpf.github.io/posts/dbe7e802/
作者
Pikachu_fpx
发布于
2024年1月13日
许可协议