백준19499 K-transform

문제 링크

  • http://icpc.me/19499

사용 알고리즘

  • FFT
  • Kitamasa
  • Fenwick Tree

시간복잡도

  • $O(K \log K \log M)$

풀이

정확히 $T$번 연산해서 $N$이 되는 경우의 수를 $g(T, N)$라고 정의합시다. $g(T, N)$은 아래와 같이 표현할 수 있습니다.

\[g(T, N) = \begin{cases} 1 \text{, if } T = 0 \\ g(T-1, Nk) \text{, if } k \vert (N+1) \\ g(T-1, N+1) + g(T-1, Nk) \text{, otherwise} \end{cases}\]

입력으로 주어진 $k, m, mod$에 대해 $g(k, m) \mod mod$가 문제의 정답이 됩니다.

여러 $k$에 대해 로컬에서 작은 항을 계산해본 뒤 berlekamp-massey를 돌려보면 선형 점화식을 얻을 수 있습니다. $D(m)$을 문제의 정답이라고 정의하면, 점화식은 아래와 같이 나옵니다.

\[D(m) = \begin{cases}2^m \text{, if } m < k-1 \\ 2^m-1 \text{, if } m = k-1 \\ \displaystyle \sum_{i=1}^{k} D(m-i) \text{, otherwise}\end{cases}\]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
using ll = long long;
using lll = __int128_t;

int K;

int f(int T, lll N){
    if(T == 0) return 1;
    if((N+1) % K == 0) return f(T-1, N*K);
    else return (f(T-1, N+1) + f(T-1, N*K)) % MOD;
}

void Go(int _k){
    K = _k;
    vector<int> V;
    for(int i=0; i<3*K; i++) V.push_back(f(i, 1));
    auto R = berlekamp_massey(V);
    if(K < 10) cout << " ";
    cout << K << " : ";
    for(auto i : R) cout << i << " "; cout << "\n";
}

$k \leq 10^4, m \leq 10^{18}$이므로 Kitamasa와 FFT를 사용하면 $O(k \log k \log m)$에 점화식을 계산할 수 있습니다.
하지만, FFT를 사용한 다항식 나눗셈은 느린 $O(k \log k)$에 속하기 때문에 2초라는 빡빡한 시간 제한을 통과하기 어렵습니다. 최적화를 해야 합니다.

Divisor를 보면, 항상 $C_k(x) = x^k - x^{k-1} - x^{k-2} - x^{k-3} - \cdots - 1$ 꼴이라는 것을 알 수 있습니다. 어떤 다항식을 $C_k(x)$로 나눈 나머지를 구하는 과정을 손으로 직접 계산해보면, 어떤 수 $c$를 적당한 구간 $[i-k, i-1]$에 더하는 연산을 여러 번 수행한다는 것을 알 수 있습니다. 우리는 이런 연산을 Fenwick Tree를 사용해 매우 빠르게 수행할 수 있습니다. 그러므로 다항식 나눗셈을 빠른 $O(k \log k)$에 할 수 있고, AC를 받을 수 있습니다.

전체 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#pragma GCC optimize("O3")
#pragma GCC target("avx,avx2,fma")
#include <bits/stdc++.h>
#include <smmintrin.h>
#include <immintrin.h>
#define all(v) v.begin(), v.end()
#define sz(v) (int)(v.size())
using namespace std;

using ll = int32_t;
using poly = vector<ll>;

ll M;
int K, MOD;

int Norm(ll v){
    if(v < -MOD || v >= MOD) v %= MOD;
    return v < 0 ? v + MOD : v;
}
void Norm(poly &a){ for(auto &i : a) i = Norm(i); while(a.size() && !a.back()) a.pop_back(); }
inline int Add(int a, int b){ return a+b>=MOD ? a+b-MOD : a+b; }
inline int Sub(int a, int b){ return a >= b ? a-b+MOD : a-b; }

__m256d mult(__m256d a, __m256d b){
    __m256d c = _mm256_movedup_pd(a);
    __m256d d = _mm256_shuffle_pd(a, a, 15);
    __m256d cb = _mm256_mul_pd(c, b);
    __m256d db = _mm256_mul_pd(d, b);
    __m256d e = _mm256_shuffle_pd(db, db, 5);
    __m256d r = _mm256_addsub_pd(cb, e);
    return r;
}
void fft(int n, __m128d a[], bool invert){
    for(int i=1, j=0; i<n; ++i){
        int bit = n>>1;
        for(;j>=bit;bit>>=1) j -= bit;
        j += bit;
        if(i<j) swap(a[i], a[j]);
    }
    for(int len=2, lv=1; len<=n; len<<=1, lv++){
        double ang = 2*M_PI/len*(invert?-1:1);
        __m256d wlen; wlen[0] = cos(ang), wlen[1] = sin(ang);
        for(int i=0; i<n; i += len){
            __m256d w; w[0] = 1; w[1] = 0;
            for(int j=0; j<len/2; ++j){
                w = _mm256_permute2f128_pd(w, w, 0);
                wlen = _mm256_insertf128_pd(wlen, a[i+j+len/2], 1);
                w = mult(w, wlen);
                __m128d vw = _mm256_extractf128_pd(w, 1);
                __m128d u = a[i+j];
                a[i+j] = _mm_add_pd(u, vw);
                a[i+j+len/2] = _mm_sub_pd(u, vw);
            }
        }
    }
    if(invert){
        __m128d inv; inv[0] = inv[1] = 1.0/n;
        for(int i=0; i<n; ++i) a[i] = _mm_mul_pd(a[i], inv);
    }
}
vector<int64_t> multiply(const vector<int64_t>& v, const vector<int64_t>& w){
    int n = 2; while(n < v.size()+w.size()) n<<=1;
    __m128d* fv = new __m128d[n];
    for(int i=0; i<n; ++i) fv[i][0] = fv[i][1] = 0;
    for(int i=0; i<v.size(); ++i) fv[i][0] = v[i];
    for(int i=0; i<w.size(); ++i) fv[i][1] = w[i];
    fft(n, fv, 0);
    for(int i=0; i<n; i += 2){
        __m256d a;
        a = _mm256_insertf128_pd(a, fv[i], 0);
        a = _mm256_insertf128_pd(a, fv[i+1], 1);
        a = mult(a, a);
        fv[i] = _mm256_extractf128_pd(a, 0);
        fv[i+1] = _mm256_extractf128_pd(a, 1);
    }
    fft(n, fv, 1);
    vector<int64_t> ret(n);
    for(int i=0; i<n; ++i) ret[i] = (int64_t)llround(fv[i][1]/2);
    delete[] fv;
    return ret;
}

int tree[1<<13];
void update(int x, int v){ for(x+=2; x<(1<<15); x+=x&-x) tree[x] += v; }
int query(int x){ int ret = 0; for(x+=2; x>0; x-=x&-x) ret += tree[x]; return ret; }

poly mul(const poly &a, const poly &b){
    auto res = multiply(a, b);
    Norm(res);
    return res;
}

poly rem(const poly &a){
    if(K+1 > a.size()) return a;
    memset(tree, 0, sizeof tree);
    for(int i=0; i<a.size(); i++) update(i, a[i]), update(i+1, -a[i]);

    for(int i=a.size()-1; i>=K; i--){
        int t = Norm(query(i));
        update(i-K, t); update(i, -t);
    }

    poly ret(K);
    for(int i=0; i<K; i++) ret[i] = query(i);
    Norm(ret);
    return ret;
}

poly get(ll k){
    poly d = {1}, xn = {0, 1};
    while(k){
        if(k & 1) d = rem(mul(d, xn));
        k >>= 1; xn = rem(mul(xn, xn));
    }
    return d;
}

int main(){
    ios_base::sync_with_stdio(false); cin.tie(nullptr);
    cin >> K >> M >> MOD;

    poly D;
    for(int i=0, pw=1; i<K; i++, pw=Add(pw, pw)){
        if(i <= K-2) D.push_back(pw);
        else D.push_back(Sub(pw, 1));
    }

    auto C = get(M);
    ll ans = 0;
    for(int i=0; i<K; i++) ans += C[i] * D[i];
    if(M < K) cout << D[M];
    else cout << Norm(ans);
}