백준10066 팰린드롬

문제 링크

  • http://icpc.me/10066

문제 출처

  • 2014 APIO 1번

사용 알고리즘

  • manacher
  • tree
  • hashing

시간복잡도

  • $O(N log N)$
  • unordered map 이용시 O(N)

풀이

manacher를 돌려주고 시작합시다. manacher를 이용하면 팰린드롬인 구간을 모두 알아낼 수 있습니다. 팰린드롬인 구간은 최대 $O(N^2)$개입니다.
해싱을 하고 map을 이용해서 각 팰린드롬이 나오는 횟수를 계산해주면 $O(N^2 log N)$에 풀 수 있습니다. unordered_map을 이용하면 $O(N^2)$에 풀 수 있습니다.

어떤 문자열에서 서로 다른 팰린드롬의 개수는 최대 N개라는 것을 이용해서 이 풀이를 $O(N log N)$으로 줄일 수 있습니다.Br> [s, e]가 팰린드롬이면 [s+1, e-1]도 팰린드롬입니다. [s, e]의 해시값을 [s+1, e-1]의 해시값과 연결해줍시다. [s+1, e-1]은 [s+2, e-2]와, [s+2, e-2]는 [s+3, e-3]…을 반복해서 이미 존재하는 해시값이 나올 때까지 반복해서 간선을 만들어줍시다. 이렇게 트리(포레스트)를 만들면 최대 N개의 정점이 생성됩니다.

팰린드롬인 구간 [s, e]가 등장할 때마다 [s, e]의 조상들의 등장 횟수를 각각 1씩 증가시키면 당연히 $O(N^2)$입니다.
하지만 그럴 필요 없이, [s, e]에만 체크를 해줘도 dfs를 이용해 서브트리에 속한 각 정점의 등장 횟수를 더해줘도 됩니다.

map을 사용하면 $O(N log N)$, unordered_map을 사용하면 $O(N)$이 됩니다.

전체 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
const ll p1 = 179, mod1 = 1e9-63;
const ll p2 = 917, mod2 = 1e9+7;

int n;
char s[606060];
int dp[606060];
ll ha1[606060], pw1[606060];
ll ha2[606060], pw2[606060];

int pv = 0;
map<pair<ll, ll>, int> mp;
int cnt[303030];
int len[303030];
int lnk[303030];
vector<int> g[303030];

void manacher(){
    for(int i=n-1; i>=0; i--){
        s[i << 1 | 1] = s[i];
        s[i << 1] = '#';
    }
    n <<= 1; s[n++] = '#';
    int r = -1, p = -1;
    for(int i=0; i<n; i++){
        dp[i] = (i <= r ? min(dp[p*2-i], r-i) : 0);
        while(i-dp[i]-1 >= 0 && i+dp[i]+1 < n && s[i-dp[i]-1] == s[i+dp[i]+1]) dp[i]++;
        if(i+dp[i] > r) r = i+dp[i], p = i;
    }
}
pair<ll, ll> substr(int l, int r){
    ll r1 = ha1[l] - ha1[r+1] * pw1[r-l+1]; r1 %= mod1; if(r1 < 0) r1 += mod1;
    ll r2 = ha2[l] - ha2[r+1] * pw2[r-l+1]; r2 %= mod2; if(r2 < 0) r2 += mod2;
    return {r1, r2};
}

void go(int s, int e, pair<ll, ll> now){
    if(e-s+1 <= 2) return;
    auto par = substr(s+1, e-1);
    int v = mp[now], u;
    if(!mp.count(par)){
        u = mp[par] = ++pv; len[u] = e-s-1;
    }else u = mp[par];
    if(lnk[v]) return; lnk[v] = 1;
    g[u].push_back(v);
    go(s+1, e-1, par);
}

int sz[303030];
int chk[303030];
void dfs(int v){
    sz[v] = cnt[v];
    chk[v] = 1;
    for(auto i : g[v]){
        if(chk[i]) continue;
        dfs(i); sz[v] += sz[i];
    }
}

int main(){
    ios_base::sync_with_stdio(0); cin.tie(0);
    cin >> s; n = strlen(s);
    pw1[0] = 1, pw1[1] = p1;
    pw2[0] = 1, pw2[1] = p2;
    for(int i=n-1; i>=0; i--) ha1[i] = (ha1[i+1] * p1 + s[i]) % mod1;
    for(int i=2; i<=n; i++) pw1[i] = pw1[i-1] * p1 % mod1;
    for(int i=n-1; i>=0; i--) ha2[i] = (ha2[i+1] * p2 + s[i]) % mod2;
    for(int i=2; i<=n; i++) pw2[i] = pw2[i-1] * p2 % mod2;
    manacher();

    for(int i=0; i<n; i++){
        if(!dp[i]) continue;
        int s, e;
        if(i & 1) s = i/2-dp[i]/2, e = i/2+dp[i]/2;
        else{
            s = i-1, e = i+1; s /= 2, e /= 2;
            s -= dp[i]/2-1;
            e += dp[i]/2-1;
        }
        auto now = substr(s, e);
        if(!mp.count(now)){
            mp[now] = ++pv; len[mp[now]] = e-s+1;
        }
        cnt[mp[now]]++;
        go(s, e, now);
    }

    ll ans = 0;
    for(int i=1; i<=pv; i++) if(!sz[i]) dfs(i);
    for(int i=1; i<=pv; i++){
        ans = max(ans, (ll)len[i] * sz[i]);
    }
    cout << ans;
};

// banana