并查集与最小生成树

 

并查集

并查集

并查集是用于管理元素所属集合的数据结构, 实现为一个森林, 其中每棵树表示一个集合, 树中节点表示对应集合中元素

节点

对于独立两个节点$A、B$

graph BT;
    A(A)
    B(B)

定义 $parent[A] = B$, 表示节点$A$父节点是节点$B$

graph BT;
    a(A)-->b(B)
// 使用map存储节点间关系
template<class T, class T>
std::map<T, T> parent;

初始化

初始时每个元素都位于一个单独集合, 其父节点均为自身

graph BT;
    subgraph 集合2
    direction BT
        B(B)
    end

    subgraph 集合1
    direction BT
        A(A)
    end
template<class T>
void Init(T x) {
    parent[x] = x;
}

查询

查询操作用于查询某个元素所属集合, 用于判断两个元素是否属于同一集合(同根节点)

graph BT;
    subgraph 集合2
    direction BT
        d(D)-->y(Y)
    end

    subgraph 集合1
    direction BT
        a(A)-->c(C)
        b(B)-->x(X)
        c(C)-->x(X)
    end

节点$A、B、C$根节点均为$X$, 故节点$A、B、C$属于同一集合

节点$D$根节点为节点$Y$, 故节点$A$节点$D$不属于同一集合

template<class T>
T Find(T x) {
    // 若x父节点非它本身, 则继续查找
    while (parent[x] != x) {
        x = parent[x];
    }
    return x;
}

路径压缩

查询过程中所经过每个元素都属于该集合, 可将每个元直接连到根节点以加快后续查询

graph BT;
    subgraph 路径压缩后
    direction BT
        X1(X)

        a1(A)-->X1
        b1(B)-->X1
        c1(C)-->X1
        d1(D)-->X1
    end

    subgraph 路径压缩前
    direction BT
        X(X)

        a(A)-->c(C)
        b(B)-->c(C)
        c(C)-->X
        d(D)-->X
    end
template<class T>
T Find(T x) {
    if (parent[x] != x) {
        parent[x] = Find(parent[x]);
    }
    return parent[x];
}

合并

合并两个元素所属集合, 即将一个集合根节点连到另一集合根节点

graph BT;
    subgraph 集合2
    direction TB
        e(E)-->y(Y)
        f(F)-->y(Y)
    end

    subgraph 集合1
    direction TB
        a(A)-->c(C)
        b(B)-->c(C)
        c(C)-->x(X)
        d(D)-->x(X)
    end

    x(X)==>y(Y)
graph LR;
    subgraph 集合3
    direction BT
        a(A)-->c(C)
        b(B)-->c(C)
        c(C)-->x(X)
        d(D)-->x(X)
        e(E)-->y(Y)
        f(F)-->y(Y)

        x -->y
    end

将节点$X$的父节点设为节点$Y$, 合并两个集合

template<class T>
void Unions(const T x, const T y) {
    T fx = Find(x);
    T fy = Find(y);
    if (fx != fy) {
        parent[fx] = fy;
    }
}

按秩合并(Union by Rank)

并查集树结构中, 树高度会影响查找操作效率

秩(rank)通常是树高度估计, 按秩合并时, 将秩小树连接到秩大树上, 从而避免较大树高度增加

若两棵树秩相同, 那么任选其中一个树根节点作为新根节点, 并将其秩加1

template<class T>
void Unions(const T x, const T y) {
    T fx = Find(x);
    T fy = Find(y);
    if (fx != fy) {
        if (rank[fx] > rank[fy]) {
            parent[fy] = fx;
        }
        else if(rank[fx] > rank[fy]) {
            parent[fx] = fy;
        }
        else {
            parent[fx] = fy;
            rank[fx]++;
        }
    }
}
graph BT;
    X(X rank:0)
    Y(Y rank:0)
    Z(Z rank:0)

合并节点$X、Y$, 其秩一致, 任选节点$Y$作为新根节点

graph BT;
    X(X rank:0)
    Y(Y rank:1)
    Z(Z rank:0)
    X --> Y

合并节点$Z、Y$, 其节点$Z$秩小于节点$Y$秩, 即将节点$Z$合并到节点$Y$

graph BT;
    X(X rank:0)
    Y(Y rank:2)
    Z(Z rank:0)
    X --> Y
    Z --> Y

代码

#include <iostream>
#include <algorithm>
#include <utility>
#include <vector>
#include <map>
#include <set>

template<class NodeType = std::string>
struct Line {
    NodeType mStartNode;
    NodeType mEndNode;
    double   mWeight;
    bool     mIsSelect;

    Line(NodeType s, NodeType e, double w) :
        mStartNode(std::move(s)), mEndNode(std::move(e)), mWeight(w), mIsSelect(false) {}
};

template<class NodeType = std::string>
class DisjointSetUnion {
public:
    DisjointSetUnion() = default;
    DisjointSetUnion(std::vector<Line<NodeType>>& lines) {
        for (const auto& line : lines) {
            mNodes.insert(line.mStartNode);
            mNodes.insert(line.mEndNode);
        }

        for (const auto& node : mNodes) {
            mParent[node] = node;
            mRank[node] = 0;
        }
    };

    NodeType Find(NodeType x) {
        if (mParent[x] != x) {
            mParent[x] = Find(mParent[x]);
        }
        return mParent[x];
    }

    void Unions(NodeType x, NodeType y) {
        NodeType fx = Find(x);
        NodeType fy = Find(y);
        if (fx != fy) {
            if (mRank[fx] > mRank[fy]) {
                mParent[fy] = fx;
            }
            else if (mRank[fx] > mRank[fy]) {
                mParent[fx] = fy;
            }
            else {
                mParent[fx] = fy;
                mRank[fx]++;
            }
        }
    }

private:
    std::set<NodeType>           mNodes;
    std::map<NodeType, NodeType> mParent;
    std::map<NodeType, int>      mRank;
};

最小生成树

kruskal法

完整代码路径

  • 将所有边按权值大小顺序排列

  • 对于任意两个节点,若不在同个并查集内(不会形成闭环), 选择该边, 并和合并两个节点

graph LR;
    a(A)<--1--->b(B)
    a(A)<--7--->d(d)
    a(A)<--9--->c(C)
    b(B)<--2--->e(E)
    b(B)<--5--->d(D)
    c(C)<--4--->d(D)
    c(C)<--2--->g(G)
    d(D)<--4--->e(E)
    d(D)<--1--->f(F)
    d(D)<--6--->g(G)
    e(E)<--7--->f(F)
    f(F)<--3--->g(G)
graph LR;
    a(A)<==1===>b(B)
    a(A)<--7--->d(d)
    a(A)<--9--->c(C)
    b(B)<==2===>e(E)
    b(B)<--5--->d(D)
    c(C)<--4--->d(D)
    c(C)<==2===>g(G)
    d(D)<==4===>e(E)
    d(D)<==1===>f(F)
    d(D)<--6--->g(G)
    e(E)<--7--->f(F)
    f(F)<==3===>g(G)
template<class NodeType = std::string>
class MinimumSpanningTree {
public:
    MinimumSpanningTree(std::vector<Line<NodeType>>& lines) {
        mUnions = DisjointSetUnion<NodeType>(lines);

        std::sort(lines.begin(), lines.end(), [=](const Line<NodeType>& e1, const Line<NodeType>& e2) { return e1.mWeight < e2.mWeight; });
        mLines = std::move(lines);
        mMSTValue = GetKruskal();
    }

    double GetKruskal() {
        double sum = 0;
        for (auto& line : mLines) {
            if (mUnions.Find(line.mStartNode) != mUnions.Find(line.mEndNode)) {
                sum += line.mWeight;
                line.mIsSelect = true;
                mUnions.Unions(line.mStartNode, line.mEndNode);
            }
        }
        return sum;
    }

    void PrintMSTResult() const {
        std::cout << "The minimum spanning tree mWeight = " << mMSTValue << std::endl;

        for (const auto& line : mLines) {
            if (line.mIsSelect) {
                std::cout << "Select Line: " << line.mStartNode << "-" << line.mEndNode << std::endl;
            }
        }
    }

private:
    DisjointSetUnion<NodeType>  mUnions;
    std::vector<Line<NodeType>> mLines;
    double                      mMSTValue;
};
int main() {
    std::vector<Line<>> lines = {
        Line<>("A", "B", 1), Line<>("A", "C", 9), Line<>("A", "D", 7),
        Line<>("B", "D", 5), Line<>("B", "E", 2), Line<>("E", "D", 4),
        Line<>("E", "F", 7), Line<>("F", "D", 1), Line<>("F", "G", 3),
        Line<>("G", "D", 6), Line<>("G", "C", 2), Line<>("C", "D", 4),
    };

    MinimumSpanningTree<> mst(lines);
    mst.PrintMSTResult();
    return 0;
}

运行结果

The minimum spanning tree value = 13
Select Line: A-B
Select Line: F-D
Select Line: B-E
Select Line: G-C
Select Line: F-G
Select Line: E-D