SIMD in PS - ICPC Seoul Regional 2021 L. Trio

서론

2021년 서울 리저널 L번 문제로 출제된 Trio는 다양한 풀이가 존재합니다. 의도된 풀이는 포함 배제를 사용하는 $O(N^2\cdot 81)$ 정도의 풀이로 추측되지만 $O(N^3)$에 비례하는 풀이, $O(N^2 \cdot 10\,000)$를 bitset으로 최적화한 등 다양한 풀이가 나왔습니다. 이 글에서는 SIMD를 이용한 두 가지 풀이를 알아보면서 PS에 SIMD를 어떻게 적용할 수 있을지 살펴보겠습니다.

SIMD란?

자세한 내용은 다루지 않고, SIMD가 무엇인지 알 수 있을 정도로만 정리하고 넘어가겠습니다.

SIMD는 Single Instruction Multiple Data의 약자로, 하나의 명령어로 여러 데이터에 대해 동일한 연산을 동시에 수행하는 방식입니다. 예를 들어 아래 코드는 8번의 덧셈 연산을 수행해야 하지만, 뒤에서 설명할 avx를 이용하면 한 번의 명령으로 8번의 덧셈을 동시에 수행할 수 있습니다.

1
2
int a[8] = {0, 1, 2, 3, 4, 5, 6, 7}, b[8] = {1, 2, 3, 4, 4, 3, 2, 1}, c[8];
for(int i=0; i<8; i++) c[i] = a[i] + b[i]; // c = { 1, 3, 5, 7, 8, 8, 8, 8 }

Codeforces에서 고인물들의 코드를 구경하다 보면 #pragma GCC target("sse,sse2,sse3,ssse3,sse4,avx,avx2") 와 같은 코드를 심심치 않게 볼 수 있습니다. sse와 avx 같은 것들은 SIMD 명령어셋으로, sse 시리즈는 128bit 단위, avx 시리즈는 256bit 단위로 병렬 처리를 할 수 있습니다. 즉, avx 명령어셋을 사용하면 int형(32bit) 덧셈 8번을 동시에 할 수 있습니다.
이런 명령어셋의 지원 여부는 CPU마다 다르기 때문에 컴파일러는 기본적으로 SIMD 명령어를 사용하지 않도록 컴파일합니다. 위 코드는 해당 명령어셋을 사용해서 컴파일하도록 컴파일러에게 지시하는 구문입니다. 만약 컴파일 옵션을 직접 입력할 수 있다면 -msse -mavx -mavx2와 같이 옵션을 줘도 됩니다. 아래 코드는 avx를 이용해서 길이가 8인 두 배열의 덧셈을 수행하는 코드입니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
#pragma GCC target ("avx2")
#include <bits/stdc++.h>
#include <x86intrin.h>
using namespace std;

alignas(32) int a[8] = {0, 1, 2, 3, 4, 5, 6, 7}, b[8] = {1, 2, 3, 4, 4, 3, 2, 1}, c[8];

int main(){
    ios_base::sync_with_stdio(false); cin.tie(nullptr);
    __m256i av = _mm256_load_si256((const __m256i*)a);
    __m256i bv = _mm256_load_si256((const __m256i*)b);
    __m256i cv = _mm256_add_epi32(av, bv);
    _mm256_store_si256((__m256i*)c, cv);
    for(int i=0; i<8; i++) cout << c[i] << " ";
}

__m256i는 정수 자료형(char, short, int , long long 등)을 저장하는 256bit 크기의 벡터입니다. __mm256_load_si256은 align 되어 있으면서 메모리에서 연속한 256bit 를 가져오는 함수이고, __m256_add_epi32는 두 __m256i를 32bit 기준으로 더하는 함수입니다. load 명령이 메모리에서 연속한 데이터를 가져온다는 것에 주목해주세요.

사실 요즘 컴파일러는 똑똑해서 위 코드처럼 직접 SIMD 명령을 작성하지 않더라도, target을 지정해주고 O3 최적화 옵션을 주면 적당히 잘 컴파일해줍니다.

SIMD 명령어셋 별로 지원하는 연산은 Intel Intrinsics Guide에서 확인할 수 있고, 자신의 CPU에서 지원하는 명령어셋은 CPU-Z를 이용해 확인할 수 있습니다. 참고로 BOJ는 AVX2까지 지원하고, ICPC Seoul Regional은 AVX2를 지원하는 것은 확인했는데 AVX512는 확인하지 못했습니다.

$O(N^3/96)$

가장 먼저 떠올릴 수 있는 Naive한 풀이는 $i < j < k$를 잡아서 $(A_i, A_j), (A_j, A_k), (A_k, A_i)$의 공통 부분이 모두 같은지 확인하는 $O(N^3)$ 풀이입니다. $C(i, j)$를 $(A_i, A_j)$의 공통 부분이라고 정의하면 $C(i, j) = C(j, k) = C(k, i)$가 모두 동일한지 확인하면 됩니다. $i < j < k$ 조건이 붙어있기 때문에 실제로 확인하는 $(i,j,k)$ 튜플의 개수는 대략 $N^3/6$입니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
#include <bits/stdc++.h>
using namespace std;

int N, C[2000][2020];
string A[2000];
long long R;

int main(){
    ios_base::sync_with_stdio(false); cin.tie(nullptr);
    cin >> N;
    for(int i=0; i<N; i++) cin >> A[i];
    for(int i=0; i<N; i++){
        for(int j=0; j<i; j++){
            int now = 0;
            for(int k=0; k<4; k++){
                now = now * 10 + (A[i][k] == A[j][k] ? A[i][k] - '0' : 0);
            }
            C[i][j] = C[j][i] = now;
        }
    }
    for(int i=0; i<N; i++){
        for(int j=0; j<i; j++){
            int key = C[i][j];
            for(int k=0; k<j; k++) R += key == C[i][k] && key == C[j][k];
        }
    }
    cout << R;
}

이때 23번째 줄을 잘 보시면 key와 비교하는 데이터들이 메모리에서 연속하기 때문에 SIMD를 적용할 수 있습니다. 또한 모든 수는 32768보다 작기 때문에 short형을 써도 무방하고, 16번의 연산을 동시에 처리할 수 있으므로 시간 복잡도는 $O(N^3/6/16) = O(N^3/96)$이 됩니다. 위 코드에 O3 옵션과 AVX2 타겟만 지정해도 됩니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
#pragma GCC optimize ("O3")
#pragma GCC target ("avx2")
#include <bits/stdc++.h>
#include <x86intrin.h>
using namespace std;

alignas(32) short N, C[2000][2000];
string A[2000];
long long R;

void Solve(int u, int v){
    static short tmp[16], i;
    __m256i key = _mm256_set1_epi16(C[u][v]);
    __m256i res = _mm256_setzero_si256();
    for(i=0; i+16<=u; i+=16){
        __m256i now_u = _mm256_load_si256((const __m256i*)(C[u]+i)); // C[u][i] ~ C[u][i+15]
        __m256i now_v = _mm256_load_si256((const __m256i*)(C[v]+i)); // C[v][i] ~ C[v][i+15]
        now_u = _mm256_cmpeq_epi16(now_u, key);                      // C[u][k] == key
        now_v = _mm256_cmpeq_epi16(now_v, key);                      // C[v][k] == key
        now_v = _mm256_and_si256(now_v, now_u);                      // C[u][k] == key && C[v][k] == key
        res = _mm256_sub_epi16(res, now_v);                          // res[0] += now_v[0], res[1] += now_v[1], ...
    }
    while(i < u) R += C[u][v] == C[u][i] && C[u][v] == C[v][i], i++; // 16개씩 처리하고 남은 원소들 처리 (최대 15개)
    _mm256_store_si256((__m256i*)tmp, res);
    for(i=0; i<16; i++) R += tmp[i];
}

int main(){
    ios_base::sync_with_stdio(false); cin.tie(nullptr);
    cin >> N;
    for(int i=0; i<N; i++) cin >> A[i];
    for(int i=0; i<N; i++){
        for(int j=0; j<i; j++){
            int now = 0;
            for(int k=0; k<4; k++){
                now = now * 10 + (A[i][k] == A[j][k] ? A[i][k] - '0' : 0);
            }
            C[i][j] = C[j][i] = now;
        }
    }
    for(int i=0; i<N; i++) for(int j=0; j<i; j++) Solve(j, i);
    cout << R;
}

기본적인 흐름은 다음과 같습니다.

  1. C[u][v] == C[u][i]를 SIMD로 계산 (now_u)
  2. C[u][v] == C[v][i]를 SIMD로 계산 (now_v)
  3. now_vnow_u를 AND 연산
  4. 정답에 반영

_mm256_cmpeq_epi16은 다르면 0000..., 같으면 1111...을 반환하기 때문에 true를 -1로 취급할 수 있으므로 res_mm256_sub_epi16을 취합니다.

$O(N^2\cdot 20000/512)$

이 풀이는 제가 대회 중에 작성한 풀이로, $N^3$풀이보다 조금 복잡합니다. bitset을 알고 있다고 가정하고 제 풀이를 먼저 설명합니다.

수 2개를 고정하면 각 자리는 “무조건 x가 와야 하는자리” / “y z 빼고 아무거나 와도 되는 자리” 중 하나에 해당합니다. 예를 들어 고정된 수가 1234, 1243이라면 1/2번째 자리에는 무조건 1/2가 와야 하고, 3/4번째 자리에는 3/4 빼고 아무거나 와도 됩니다.
즉, 12?? 꼴을 만족하는 수와 { ??3?, ???4, ??4?, ???3 } 꼴을 만족하지 않는 수의 교집합의 크기를 구하면 됩니다. 구현의 편의를 위해 0을 와일드 카드 문자로 취급합시다.

B[i][j]를 j가 패턴 i와 매칭되면 1, 안 되면 0으로 정의합시다. bitset<10000> B[10000]처럼 bitset을 이용하면 위 예시는 B[1200] & ~B[30] & ~B[4] & ~B[40] & ~B[3]으로 구할 수 있고, 최대 8번의 AND 연산을 사용하므로 시간 복잡도는 $O(N^2 \cdot 10000 \cdot 8 / 2w)$입니다. 이때 $w$는 word size(보통 64)이고, $N \choose 2$개의 pair만 보면 되므로 $/2$가 추가로 붙게 됩니다.
이때, ~B[x]꼴에서 x에 와일드 카드가 아닌 인덱스가 하나 밖에 없다는 점을 이용해서 최적화를 할 수 있습니다.

D[i][j]를 j와 i가 한 자리 이상 겹치면 0, 안 겹치면 1으로 정의합시다. 예를 들어 1234라면 1xxx, x2xx, xx3x, xxx4, 총 9000개의 칸을 0으로 세팅합니다.
D 배열도 전처리해두면 위 예시는 B[1200] & D[34] & D[43]으로 구할 수 있고, AND 연산을 2번만 사용하므로 시간 복잡도는 $O(N^2 \cdot 10000\cdot 2 / 2w)$가 됩니다.

bitset을 그냥 사용하면 $w=64$라서 $O(N^2\cdot 20000/128)$이지만, AVX를 사용하면 $w=256$이므로 시간 복잡도는 $O(N^2\cdot 20000/512)$가 됩니다. 아래 코드는 단순이 O3 옵션과 AVX2 타겟을 지정한 코드입니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#pragma GCC optimize ("O3")
#pragma GCC target ("avx2")
#include <bits/stdc++.h>
using namespace std;

alignas(32) int N, A[2000];
bitset<10000> B[10000], D[10000];
long long R;

array<int,4> Decomp(int n){
    array<int,4> arr;
    arr[0] = n / 1000 * 1000; n %= 1000;
    arr[1] = n / 100 * 100; n %= 100;
    arr[2] = n / 10 * 10; n %= 10;
    arr[3] = n / 1 * 1; n %= 1;
    return arr;
}

void Init(int n){
    auto arr = Decomp(n);
    for(int bit=0; bit<16; bit++){
        int now = 0;
        for(int i=0; i<4; i++) if(bit & (1 << i)) now += arr[i];
        B[now].set(n);
    }
    for(int p=0; p<4; p++){
        vector<int> p10 = {1000, 100, 10, 1};
        p10.erase(p10.begin() + p);
        for(int i=0; i<10; i++){
            for(int j=0; j<10; j++){
                for(int k=0; k<10; k++){
                    int now = arr[p] + i * p10[0] + j * p10[1] + k * p10[2];
                    D[now].reset(n);
                }
            }
        }
    }
}

void Solve(int a, int b){
    auto a_v = Decomp(a), b_v = Decomp(b);
    int mask = 0, am = a, bm = b;
    for(int i=0; i<4; i++) if(a_v[i] == b_v[i]) mask += a_v[i], am -= a_v[i], bm -= b_v[i];
    auto res = B[mask] & D[am] & D[bm];
    R += res.count() - res[a] - res[b];
}

int main(){
    ios_base::sync_with_stdio(false); cin.tie(nullptr);
    cin >> N;
    for(int i=0; i<N; i++) cin >> A[i];
    for(int i=0; i<10000; i++) D[i].set();
    for(int i=0; i<N; i++) Init(A[i]);
    for(int i=0; i<N; i++) for(int j=0; j<i; j++) Solve(A[i], A[j]);
    cout << R / 3;
}

AVX2로 직접 구현한 코드는 여기에서 확인하실 수 있습니다.