프로그래밍/알고리즘

[알고리즘] smaller to larger를 사용해서 분리집합을 구현하자!

riroan 2023. 4. 13. 00:46

이 글에는 지극히 주관적인 내용이 많이 들어있습니다.

 

분리집합을 구현할 때 보통 가리키는 포인터를 옮기는 방식으로 구현하곤 한다. 하지만 이런 방법은 경로압축도 해야 하고 개인적으로 비직관적이었다. 사실 나는 재귀적인 코드와 작동 과정이 아직 완벽히 이해가 되지 않는다.

그래서 상대적으로 직관적인 smaller to larger(a.k.a 작은거 큰거)를 사용한 분리집합 구현을 소개하고자 한다. (물론 기존 방식보다 느리고 공간도 많이 사용한다.)

smaller to larger라는게 애초에 집합을 합치기 위한 연산을 최적화하기 위한 기법이기 때문에 분리집합에도 사용할 수 있다.

 

기초적인 구조

class DisjointSet
{
public:
    vector<vector<int>> p;
    vector<int> ix;
    int n;
}

우선 멤버변수를 두개 사용한다. p는 합치기 연산을 수행할 벡터이고 ix는 i가 p[ix[i]]에 들어있음을 알려준다.

// 생성자

DisjointSet(int _n) : n(_n)
{
    p.resize(n);
    ix.resize(n);

    iota(ix.begin(), ix.end(), 0);
    for (int i = 0; i < n; i++)
        p[i].push_back(i);
}

p에 각각 자신의 인덱스를 넣어주고 ix도 마찬가지로 초기화한다. ix[0] = 0이므로 0은 p[0]에 들어있다는 의미가 된다.

size 연산

inline int size(int x)
{
    return p[ix[x]].size();
}

x가 들어있는 벡터의 크기를 리턴하면 된다. 사이즈 연산을 위해 p배열을 사용했다. 사실 이 부분도 size배열로 해결할 수 있지만 둘 다 $O(n)$의 공간을 사용하기 때문에 직관성을 높였다.

get, same 연산

inline int get(int x)
{
    return ix[x];
}

inline bool same(int x, int y)
{
    return get(x) == get(y);
}

same함수를 사용하면 같은 p에 있는지 확인할 수 있다. ix가 없었다면 p를 직접 돌며 $O(n)$이 걸렸을 것이다.

merge 연산

void merge(int x, int y)
{
    int from = ix[x], to = ix[y];

    if (from == to)
        return;
    # smaller to larger
    if (p[from].size() > p[to].size())
        swap(from, to);
    int start = p[to].size();
    for (auto i : p[from])
    {
        p[to].push_back(i);
        ix[i] = to;
    }

    p[from].clear();
}

이제 핵심인 merge연산이다. 더 작은 벡터에서 큰 벡터로 합침으로써 시간상 이득을 볼 수 있다. smaller to larger로 인해 많아야 $O(n \log n)$을 사용하게 된다. start는 rollback에서 사용될 변수로 몇번째 인덱스부터 합쳐질 수들이 들어오는지를 기록한다.

rollback 연산

분리집합에서 롤백연산을 하고 싶을 수 있다. 그럴땐 움직임을 기록하는 변수를 하나 정의하고 merge할때마다 기록하면 된다.

vector<tuple<int, int, int>> history;

void merge(int x, int y)
{
    ...
    history.push_back({from, to, start});
}
void rollback()
{
    if (history.empty())
        return;
    auto [from, to, start] = history.back();
    history.pop_back();
    while (p[to].size() > start)
    {
        int x = p[to].back();
        p[to].pop_back();
        ix[x] = from;
        p[from].push_back(x);
    }
}

rollback연산을 사용하면 바로 전 merge연산을 되돌린다. rollback 연산을 하면 원래의 p로 돌아가지는 않지만 merge하기 전 요소로 정확히 나눠준다. (원래의 p로 돌아가게 하려면 deque를 사용한다.)

최종 코드

class DisjointSet
{
public:
    vector<vector<int>> p;
    vector<int> ix;
    vector<tuple<int, int, int>> history;

    int n;
    DisjointSet(int _n) : n(_n)
    {
        p.resize(n);
        ix.resize(n);

        iota(ix.begin(), ix.end(), 0);
        for (int i = 0; i < n; i++)
            p[i].push_back(i);
    }

    inline int get(int x)
    {
        return ix[x];
    }

    inline bool same(int x, int y)
    {
        return get(x) == get(y);
    }

    inline int size(int x)
    {
        return p[ix[x]].size();
    }

    void merge(int x, int y)
    {
        int from = ix[x], to = ix[y];

        if (from == to)
            return;
        if (p[from].size() > p[to].size())
            swap(from, to);
        int start = p[to].size();
        for (auto i : p[from])
        {
            p[to].push_back(i);
            ix[i] = to;
        }
        p[from].clear();

        history.push_back({from, to, start});
    }

    void rollback()
    {
        if (history.empty())
            return;
        auto [from, to, start] = history.back();
        history.pop_back();
        while (p[to].size() > start)
        {
            int x = p[to].back();
            p[to].pop_back();
            ix[x] = from;
            p[from].push_back(x);
        }
    }
};

 

개인적으로 이 방법으로 구현한 분리집합이 훨씬 직관적이었다. 다만 size를 2차원배열로 연산하기때문에 약간의 공간복잡도가 추가되지만 vector에 들어있는 변수만큼의 공간만 추가로 사용해서 큰 문제는 없다. 그리고 기존 분리집합 merge연산의 속도는 $O(\alpha (n))$이지만 위 방식은 $O(n)$이다. 하지만 모든 집합을 연결하는데는 기존 $O(n \alpha (n))$, 위 방식 $O(n \log (n))$이 보장되므로 문제 해결하는데 큰 제약은 없다. (다만 rollback과 merge가 반복적으로 많이 나오면 실패할수도?)

 

템플릿이 없고 분리집합 코드를 외우지 못했을 때 쓰면 좋을 것 같다.