키타마사법(Kitamasa Method, きたまさ法)

서론

키타마사법(Kitamasa Method, きたまさ法)은 $A_N = c_1A_{N-1} + c_2A_{N-2} + \cdots + c_KA_{N-K} = \sum_{i=1}^{K} c_iA_{N-i}$와 같은 선형 점화식의 $N$번째 항을 $O(K^2 \log N)$, 더 나아가 $O(K \log K \log N)$에 계산하는 알고리즘입니다.
이 글에서는 키타마사법을 최대한 쉽게 설명하는 것에 초점을 두기 때문에 엄밀한 증명이 생략될 수 있습니다. 양해 부탁드립니다.

식의 변형

키타마사법의 목표는 이전 $K$개의 항으로 결정되는 $A_N = \sum_{i=1}^{K} c_iA_{N-i}$ 꼴의 점화식을 최초 $K$개의 항으로 결정되도록 하는 $A_N = \sum_{i=0}^{K-1} d_iA_i$ 꼴의 식을 찾는 것입니다. 이 식을 만족시키는 수열 $d_i$를 $T(N, K)$ 시간에 찾을 수 있다면, 단순히 $\sum_{i=0}^{K-1} d_iA_i$를 계산하면 되므로 점화식의 $N$번째 항을 $O(T(N, K) + K)$ 시간에 계산할 수 있습니다.

이러한 수열 $d_i$는 항상 존재하며, 적절한 식을 대입해 $A_{N-1}, A_{N-2}, \cdots$를 소거하는 방식으로 직접 찾을 수 있습니다. 그 예시로, BOJ 1003번 피보나치 수열 문제는 $c_1 = c_2 = 1$일 때(피보나치 수열) $d_i$를 찾는 문제입니다.

$c_1 = 2, c_2 = 1$이고 $N = 5$일 때의 $d_i$를 직접 구해보며 어떻게 하면 효율적으로 계산할 수 있을지 고민해봅시다.

  • $A_5 = 2A_4 + A_3$
  • $A_5 = 2(2A_3 + A_2) + A_3 = 5A_3 + 2A_2$
  • $A_5 = 5(2A_2 + A_1) + 2A_2 = 12A_2 + 5A_1$
  • $A_5 = 12(2A_1 + A_0) + 5A_1 = 29A_1 + 12A_0$
  • $d_0 = 12, d_1 = 29$가 됩니다.

사실 이 과정은 양변에 $A_x - c_1A_{x-1} - c_2A_{x-2} - \cdots = 0$의 상수배를 빼는 과정이라고 생각해볼 수 있습니다. 위 예시에서는 $A_x - 2A_{x-1} - A_{x-2} = 0$의 상수배를 빼게 되겠지요.

  • $A_5 = 2A_4 + A_3$, 양변에서 $2(A_4 - 2A_3 - A_2) = 0$을 빼면
  • $A_5 = 5A_3 + 2A_2$, 양변에서 $5(A_3 - 2A_2 - A_1) = 0$을 빼면
  • $A_5 = 12A_2 + 5A_1$, 양변에서 $12(A_2 - 2A_1 - A_0) = 0$을 빼면
  • $A_5 = 29A_1 + 12A_0$

이 과정은 $x^N$을 $f(x) = x^K - c_1x^{K-1} - c_2x^{K-2} - \cdots - c_Kx^0 = x^K - \sum_{i=1}^{K}c_ix^{K-i}$로 나눈 나머지를 구하는 과정과 완벽하게 동일합니다. 위 예시는 사실 $x^5$를 $x^2-2x-1$로 나눈 나머지를 구하는 과정이었고, 실제로 $x^5 = (x^3 + 2x^2 + 5x + 12)(x^2 - 2x - 1) + 29x + 12$입니다.

우리는 이제 $x^N \mod f(x)$를 빠르게 계산하는 방법을 알면 됩니다.

분할 정복을 이용한 거듭 제곱

키타마사법을 사용해야하는 문제는 대부분 $N$이 매우 크기 때문에 $x^N$을 그대로 다룰 수 없습니다. 우리는 어떤 수의 거듭 제곱은 $O(\log N)$에 구할 수 있으니, 이 방법을 응용해봅시다.

$x^N$은 $x^1, x^2, x^4, x^8, \cdots$들의 곱으로 표현할 수 있고, 총 $O(\log N)$번의 다항식 곱셈을 요구하게 됩니다. $x^N \mod f(x)$를 구하는 것이 목표이기 때문에, $(x^1 \mod f(x)), (x^2 \mod f(x)), (x^4 \mod f(x))$ 들의 곱을 $f(x)$로 나눈 나머지를 구하면 됩니다. $mod\ f(x)$를 취하면 차수가 $K$ 미만이기 때문에, 차수가 $K$ 미만인 두 다항식끼리의 곱셈과 나눗셈을 각각 $O(\log N)$번 수행하게 됩니다.

아래와 같이 구현하면, 키타마사법의 시간 복잡도는 Mul함수와 Div함수의 시간 복잡도에 따라 결정됩니다. 구체적으로, Mul함수와 Div함수의 시간 복잡도가 $P(K)$일 때 키타마사법의 시간 복잡도는 $O(P(K) \log N)$이 됩니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
using ll = long long;
using poly = vector<ll>;
ll kitamasa(poly c, poly a, ll n){
    poly d = {1}; // result
    poly xn = {0, 1}; // xn = x^1, x^2, x^4, ...
    poly f(c.size()+1); // f(x) = x^K - \sum c_{i}x^{i}
    f.back() = 1;
    for(int i=0; i<c.size(); i++) f[i] = -c[i];
    while(n){
        if(n & 1) d = Div(Mul(d, xn), f);
        n >>= 1; xn = Div(Mul(xn, xn), f);
    }

    ll ret = 0;
    for(int i=0; i<a.size(); i++) ret += a[i] * d[i];
    return ret;
}

$O(K^2 \log N)$ 키타마사법

다항식의 곱셈과 나눗셈을 Naive하게 구현하면 $O(K^2)$ 시간이 걸리므로 이때 키타마사법의 시간 복잡도는 $O(K^2 \log N)$이 됩니다. 아래 코드는 BOJ 11444번 피보나치 수 6 문제의 정답 코드입니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#include <bits/stdc++.h>
using namespace std;

using ll = long long;
using poly = vector<ll>;

constexpr int MOD = 1e9+7;

int Mod(ll x){
    return (x %= MOD) < 0 ? x + MOD : x;
}

poly Mul(const poly &a, const poly &b){
    poly ret(a.size() + b.size() - 1);
    for(int i=0; i<a.size(); i++) for(int j=0; j<b.size(); j++) {
        ret[i+j] = (ret[i+j] + a[i] * b[j]) % MOD;
    }
    return ret;
}

poly Div(const poly &a, const poly &b){
    poly ret(all(a));
    for(int i=ret.size()-1; i>=b.size()-1; i--) for(int j=0; j<b.size(); j++) {
        ret[i+j-b.size()+1] = Mod(ret[i+j-b.size()+1] - ret[i] * b[j]);
    }
    ret.resize(b.size()-1);
    return ret;
}

/// kitamasa: A_{n} = \sum c_{i}A_{n-i} = \sum d_{i}A_{i}
/// given A, c, n, get d, A_{n} in O(K^2 \log N)
ll kitamasa(poly c, poly a, ll n){
    poly d = {1}; // result
    poly xn = {0, 1}; // shift = x^1, x^2, x^4, ...
    poly f(c.size()+1); // f(x) = x^K - \sum c_{i}x^{i}
    f.back() = 1;
    for(int i=0; i<c.size(); i++) f[i] = Mod(-c[i]);
    while(n){
        if(n & 1) d = Div(Mul(d, xn), f);
        n >>= 1; xn = Div(Mul(xn, xn), f);
    }

    ll ret = 0;
    for(int i=0; i<a.size(); i++) ret = Mod(ret + a[i] * d[i]);
    return ret;
}

int main(){
    ios_base::sync_with_stdio(false); cin.tie(nullptr);
    poly rec = {1, 1};
    poly dp = {0, 1};
    ll N; cin >> N;

    cout << kitamasa(rec, dp, N);
}

$O(K \log K \log N)$ 키타마사법

$O(K^2 \log N)$을 $O(K \log K \log N)$로 최적화시키는 방법은 당연히 다항식의 곱셈과 나눗셈을 $O(K \log K)$에 수행하는 것입니다.

FFT를 이용해 다항식의 곱셈을 $O(K \log K)$에 수행하는 방법은 제 블로그에 올라와있고, 비슷하게 FFT를 이용해 다항식의 나눗셈을 $O(K \log K)$에 수행하는 방법은 삼성 소프트웨어 멤버십 블로그에 잘 설명되어 있습니다.

NTT를 이용해 키타마사법을 $O(K \log K \log N)$에 수행하는 코드를 첨부하고 글을 마무리하겠습니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
template<int M>
struct MINT{
    int v;
    MINT() : v(0) {}
    MINT(ll val){
        v = (-M <= val && val < M) ? val : val % M;
        if(v < 0) v += M;
    }

    friend istream& operator >> (istream &is, MINT &a) { ll t; is >> t; a = MINT(t); return is; }
    friend ostream& operator << (ostream &os, const MINT &a) { return os << a.v; }
    friend bool operator == (const MINT &a, const MINT &b) { return a.v == b.v; }
    friend bool operator != (const MINT &a, const MINT &b) { return a.v != b.v; }
    friend MINT pw(MINT a, ll b){
        MINT ret= 1;
        while(b){
            if(b & 1) ret *= a;
            b >>= 1; a *= a;
        }
        return ret;
    }
    friend MINT inv(const MINT a) { return pw(a, M-2); }
    MINT operator - () const { return MINT(-v); }
    MINT& operator += (const MINT m) { if((v += m.v) >= M) v -= M; return *this; }
    MINT& operator -= (const MINT m) { if((v -= m.v) < 0) v += M; return *this; }
    MINT& operator *= (const MINT m) { v = (ll)v*m.v%M; return *this; }
    MINT& operator /= (const MINT m) { *this *= inv(m); return *this; }
    friend MINT operator + (MINT a, MINT b) { a += b; return a; }
    friend MINT operator - (MINT a, MINT b) { a -= b; return a; }
    friend MINT operator * (MINT a, MINT b) { a *= b; return a; }
    friend MINT operator / (MINT a, MINT b) { a /= b; return a; }
    operator int32_t() const { return v; }
    operator int64_t() const { return v; }
};

namespace fft{
    template<int W, int M>
    static void NTT(vector<MINT<M>> &f, bool inv_fft = false){
        using T = MINT<M>;
        int N = f.size();
        vector<T> root(N >> 1);
        for(int i=1, j=0; i<N; i++){
            int bit = N >> 1;
            while(j >= bit) j -= bit, bit >>= 1;
            j += bit;
            if(i < j) swap(f[i], f[j]);
        }
        T ang = pw(T(W), (M-1)/N); if(inv_fft) ang = inv(ang);
        root[0] = 1; for(int i=1; i<N>>1; i++) root[i] = root[i-1] * ang;
        for(int i=2; i<=N; i<<=1){
            int step = N / i;
            for(int j=0; j<N; j+=i){
                for(int k=0; k<i/2; k++){
                    T u = f[j+k], v = f[j+k+(i>>1)] * root[k*step];
                    f[j+k] = u + v;
                    f[j+k+(i>>1)] = u - v;
                }
            }
        }
        if(inv_fft){
            T rev = inv(T(N));
            for(int i=0; i<N; i++) f[i] *= rev;
        }
    }
    template<int W, int M>
    vector<MINT<M>> multiply_ntt(vector<MINT<M>> a, vector<MINT<M>> b){
        int N = 2; while(N < a.size() + b.size()) N <<= 1;
        a.resize(N); b.resize(N);
        NTT<W, M>(a); NTT<W, M>(b);
        for(int i=0; i<N; i++) a[i] *= b[i];
        NTT<W, M>(a, true);
        return a;
    }
}

template<int W, int M>
struct PolyMod{
    using T = MINT<M>;
    vector<T> a;

    // constructor
    PolyMod(){}
    PolyMod(T a0) : a(1, a0) { normalize(); }
    PolyMod(const vector<T> a) : a(a) { normalize(); }

    // method from vector<T>
    int size() const { return a.size(); }
    int deg() const { return a.size() - 1; }
    void normalize(){ while(a.size() && a.back() == T(0)) a.pop_back(); }
    T operator [] (int idx) const { return a[idx]; }
    typename vector<T>::const_iterator begin() const { return a.begin(); }
    typename vector<T>::const_iterator end() const { return a.end(); }
    void push_back(const T val) { a.push_back(val); }
    void pop_back() { a.pop_back(); }

    // basic manipulation
    PolyMod reversed() const {
        vector<T> b = a;
        reverse(b.begin(), b.end());
        return b;
    }
    PolyMod trim(int n) const {
        return vector<T>(a.begin(), a.begin() + min(n, size()));
    }
    PolyMod inv(int n){
        PolyMod q(T(1) / a[0]);
        for(int i=1; i<n; i<<=1){
            PolyMod p = PolyMod(2) - q * trim(i * 2);
            q = (p * q).trim(i * 2);
        }
        return q.trim(n);
    }

    // operation with scala value
    PolyMod operator *= (const T x){
        for(auto &i : a) i *= x;
        normalize();
        return *this;
    }
    PolyMod operator /= (const T x){
        return *this *= (T(1) / T(x));
    }

    // operation with poly
    PolyMod operator += (const PolyMod &b){
        a.resize(max(size(), b.size()));
        for(int i=0; i<b.size(); i++) a[i] += b.a[i];
        normalize();
        return *this;
    }
    PolyMod operator -= (const PolyMod &b){
        a.resize(max(size(), b.size()));
        for(int i=0; i<b.size(); i++) a[i] -= b.a[i];
        normalize();
        return *this;
    }
    PolyMod operator *= (const PolyMod &b){
        *this = fft::multiply_ntt<W, M>(a, b.a);
        normalize();
        return *this;
    }
    PolyMod operator /= (const PolyMod &b){
        if(deg() < b.deg()) return *this = PolyMod();
        int sz = deg() - b.deg() + 1;
        PolyMod ra = reversed().trim(sz), rb = b.reversed().trim(sz).inv(sz);
        *this = (ra * rb).trim(sz);
        for(int i=sz-size(); i; i--) push_back(T(0));
        reverse(all(a));
        normalize();
        return *this;
    }
    PolyMod operator %= (const PolyMod &b){
        if(deg() < b.deg()) return *this;
        PolyMod tmp = *this; tmp /= b; tmp *= b;
        *this -= tmp;
        normalize();
        return *this;
    }

    // operator
    PolyMod operator * (const T x) const { return PolyMod(*this) *= x; }
    PolyMod operator / (const T x) const { return PolyMod(*this) /= x; }
    PolyMod operator + (const PolyMod &b) const { return PolyMod(*this) += b; }
    PolyMod operator - (const PolyMod &b) const { return PolyMod(*this) -= b; }
    PolyMod operator * (const PolyMod &b) const { return PolyMod(*this) *= b; }
    PolyMod operator / (const PolyMod &b) const { return PolyMod(*this) /= b; }
    PolyMod operator % (const PolyMod &b) const { return PolyMod(*this) %= b; }
};

constexpr int W = 3, MOD = 104857601;
using mint = MINT<MOD>;
using poly = PolyMod<W, MOD>;

mint kitamasa(poly c, poly a, ll n){
    poly d = vector<mint>{1};
    poly xn = vector<mint>{0, 1};
    poly f;
    for(int i=0; i<c.size(); i++) f.push_back(-c[i]);
    f.push_back(1);
    while(n){
        if(n & 1) d = d * xn % f;
        n >>= 1; xn = xn * xn % f;
    }
    mint ret = 0;
    for(int i=0; i<=a.deg(); i++) ret += a[i] * d[i];
    return ret;
}

참고 자료