프로그래밍/알고리즘

[알고리즘] 제곱근 분할법 (Sqrt decomposition)

riroan 2023. 3. 11. 23:55

Q : N개의 정수 배열에서 a~b까지 합을 구하려면 어떻게 하나요?

A : for문을 돌면서 더해요! $O(N)$

Q : 그럼 Q개의 쿼리형식으로 들어온다면요?

A : 누적합배열을 사용해요! $O(N+Q)$

Q : 그럼 여기에 업데이트도 있다면 있다면 어떻게 할까요?

A : 세그먼트 트리를 쓰면 되죠! $O(Q \log N)$ 

 

위의 모든 답변은 실제로 정답이고 더 어려운 알고리즘이 하위 문제를 해결할 수 있다.

예를들어 누적합으로 1번 문제를 해결할 수 있고 세그먼트 트리로 1, 2번 문제를 해결할 수 있다.

하지만 위 문제들을 제곱근 분할법(a.k.a. 루트질)을 사용해서 해결할 수도 있다!

물론 쉬운 알고리즘을 사용할 수 있다면 그걸 사용하는게 좋겠지만 연습용으로 풀기 좋다.

어떻게 해결할 지 알아보자.

 

버킷

위의 질문에서 나온 수열이 $1,2,3,4,5,6,7,8,9,10$이라고 하자.

이걸 임의의 사이즈의 버킷으로 묶은 후 합을 저장할 것이다.

같은 색은 같은 버킷에 있다는 뜻이고 같은 버킷에 있는 수들의 합은 B값으로 따로 관리해준다.

이제 2~8까지의 합을 구해보면 3~7은 2~8에 속하기때문에 직접 더하는 대신 B2값을 사용하면 된다.

그 외의 범위는 직접 더해주면 2~8까지 덧셈을 6번하는 것 대신 2+B2+8을 함으로써 3번만 해도 된다.

 

자 그럼 중요한건 버킷의 사이즈를 정하는 방법이다.

우선 위에서는 아무렇게나 정했지만 최악의 경우 4~8, 4~9, 5~8, 6~9, 7~9 이런식으로 들어오면 직접 더하는거랑 차이가 없다.

이 경우 버킷 $i$의 크기를 $b_i$라고 하면 $O(b_2+b_3)$만큼의 계산량을 사용하게 된다.

또한 2~9까지라면 $O(b_1+b_3)$, 2~6까지라면 $O(b_1+b_2)$가 될 것이다.

만약 위의 경우처럼 $O(b_2+b_3)$가 걸리는 쿼리가 $A$개, $O(b_1+b_3)$인 쿼리가 $B$개, $O(b_1+b_2)$가 걸리는 쿼리가 $C$개라면 모든 쿼리를 처리했을 때 $O((B+C)b_1+(A+C)b_2+(A+B)b_3)$이 걸리게 되고 우리는 이를 최소화해야한다. (아주 최악의 최악 상황만 모아놓은 것이다.)

 

산술평균 기하평균에 의해 위의 시간복잡도는 $(B+C)b_1+(A+C)b_2+(A+B)b_3 \ge 3\sqrt[3]{(B+C)(A+C)(A+B)b_1b_2b_3}$가 되고 $b_1 = b_2 = b_3$일 때 최소가 됨을 알 수 있다.

버킷의 개수가 3일 때를 예로 들었지만 일반적인 상황에서도 유사하게 귀납적으로 적용할 수 있다.

 

첫 번째 포인트

모든 버킷의 크기는 같아야 한다!

 

제곱근 분할법

 

자 그럼 버킷의 크기가 같아야함을 알았으니 모든 버킷의 크기를 $B$라고 하자. $N$이 $B$로 나누어떨어지지 않을 경우 뒤에 남는 원소들이 있을 수 있는데 이는 맨 뒷 버킷에 추가하거나 버킷을 따로 만들어도 상관 없다. (개인적으로 후자를 선호한다.)

임의로 $B = 4$로 정했고 $N \equiv 2 \mod B$이므로 따로 버킷을 만들었다.

이제 2~9 쿼리가 반복적으로 들어오면 $O(B + B + \frac{N-2B}{B})$를 반복해야 할 것이다.

$\frac{N-2B}{B}$은 중간에 버킷들을 더하는 횟수이다. 현재는 B2 한번만 더하지만 $B$의 크기가 달라지면 중간에 있는 버킷의 개수가 달라지므로 그 만큼 더해야 될 것이다.

$\frac{N-2B}{B} = \frac{N}{B}-2$이므로 시간복잡도는 $O(2B+\frac{N}{B}-2)=O(B+\frac{N}{B})$가 되고 이 또한 최적화 해야한다.

 

또 다시 산술평균 기하평균을 쓰면 $B+\frac{N}{B} \ge 2\sqrt{N}$이 되고 $B = \frac{N}{B}$일 때 최소가 된다.

즉 버킷의 사이즈는 $B = \sqrt{N}$일 때 최소가 된다!!

이렇게 "루트개수"로 버킷을 "나누기" 때문에 제곱근 분할법(sqrt decomposition)이라고 불린다.

 

두 번째 포인트

버킷의 크기는 수열의 크기의 제곱근이어야 한다!

 

제곱근 분할법으로 문제풀기

자 그럼 이제 뭔지 알았으니 이것을 이용하여 2042 구간 합 구하기를 풀어보자!

크기가 $B = \lfloor \sqrt{10} \rfloor = 3$인 버킷으로 구성했다.

 

1. 업데이트 쿼리

query) 8번째 인덱스의 수를 11로 변경하시오

이런식으로 매 쿼리마다 $O(1)$에 해결할 수 있다.

 

2. 구간 합 쿼리

query) 2 ~ 8 합을 구하시오

이런식으로 구하게 되는데 하나의 버킷은 많아봐야 $\sqrt{N}$개 만큼 있고 B2같이 버킷전체의 값을 사용하는 경우도 많아봐야 $\sqrt{N}$개이므로 쿼리마다 $O(\sqrt{N})$을 사용하게 된다!

 

문제를 시간복잡도 $O(Q\sqrt{N})$으로 해결할 수 있으므로 충분히 해결할 수 있다.

물론 세그먼트 트리보다 느리니까 제곱근 분할법 연습용으로만 사용하자.

 

정답코드

더보기
#include <bits/stdc++.h>
#define int long long
using namespace std;

signed main()
{
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);
    int n, m, k;
    cin >> n >> m >> k;
    int b = sqrt(n); // bucket size
    vector<int> arr(n);
    for (auto &i : arr)
        cin >> i;
    vector<int> brr; // bucket
    for (int i = 0; i < n; i += b)
    {
        int s = 0;
        for (int j = i; j < min(i + b, n); j++)
            s += arr[j];
        brr.push_back(s);
    }
    // O(qn**0.5)
    for (int query = 0; query < m + k; query++)
    {
        int op, x, y;
        cin >> op >> x >> y;
        --x;
        if (op == 1)
        {
            // 버킷 업데이트 O(1)
            int where = x / b;
            brr[where] -= arr[x];
            arr[x] = y;
            brr[where] += arr[x];
        }
        else if (op == 2)
        {
            --y;
            int left = x / b;
            int right = y / b;
            int ans = 0;
            for (int i = left + 1; i < right; i++)
                ans += brr[i];
            if (left == right)
            {
                // 같은 버킷이면 선형합 O(n**0.5)
                for (int i = x; i <= y; i++)
                    ans += arr[i];
            }
            else
            {
                // 다른 버킷이면 왼쪽 버킷, 오른쪽 버킷 따로 더함 O(n**0.5)
                for (int i = x; i < (left + 1) * b; i++)
                    ans += arr[i];
                for (int i = right * b; i <= y; i++)
                    ans += arr[i];
            }
            cout << ans << endl;
        }
    }
    return 0;
}

 

하지만 출력 쿼리가 최솟값을 출력하는 것이라면?

버킷을 스플레이같은 bbst나 pbds로 관리해서 $O(Q\sqrt{N}\log{N})$에 해결할 수 있긴 하다.

... 그냥 세그먼트 트리를 사용하자...

 

 

연습문제

14427 수열과 쿼리 15

2042 구간 합 구하기

18436 수열과 쿼리 17

 

모두 세그먼트 트리로 쉽게 풀리지만 제곱근 분할법을 연습해보자!