https://atcoder.jp/contests/abc272/tasks/abc272_f
Construct a new string:
$$X=S+S+a\times n+T+T+z\times n$$
Run suffix array algorithm on $X$, get $sa()$ and $rk()$.
$$g(S,i)=rk(i)$$
$$g(T,i)=rk(3n+i)$$
Claim: $f(S,i)\leq f(T,j)\iff g(S,i)< g(T,j)$
Proof:
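Let $S',T'$ be $S,T$ after rotating left by $i,j$ characters, and write $S=preS(i)+sufS(i)$, $T=preT(j)+sufT(j)$ with $|preS(i)|=i$ and $|preT(j)|=j$, so that $S'=f(S,i)=sufS(i)+preS(i)$ and $T'=f(T,j)=sufT(j)+preT(j)$.
Every suffix we compare has length $>n$: the suffix of $X$ starting at $i<n$ begins with $S'$ (because of $S+S$), and the suffix starting at $3n+j$ begins with $T'$ (because of $T+T$ followed by $z\times n$).
Therefore if $f(S,i)<f(T,j)$ then $\exists k<n$ such that $S'[0..k)=T'[0..k)$ and $S'[k]<T'[k]$, so the comparison of the two suffixes is already decided within their first $n$ characters.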
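If $f(S,i)=f(T,j)$ then $sufS(i)+preS(i)=sufT(j)+preT(j)$, and
$$suffix(i)=sufS(i)+S+a\times n+T+T+z\times n$$
$$suffix(3n+j)=sufT(j)+T+z\times n$$
Both suffixes start with the same $n$ characters; after them, $suffix(i)$ continues with $sufS(i)+a\times n$ and $suffix(3n+j)$ continues with $sufT(j)+z\times n$, where $sufS(i)$ and $sufT(j)$ are both prefixes of the same rotated string. If $i>j$ then $sufS(i)$ is the shorter one, so the run of $a$ after $S$ is compared against the rest of $sufT(j)$ (fewer than $n$ characters, each $\geq a$) and then against $z$; if $i<j$ then the run of $z$ after $T$ is compared against the rest of $sufS(i)$ (fewer than $n$ characters, each $\leq z$) and then against $a$; if $i=j$ then the first differing characters are $a$ and $z$ directly.
In every case, $suffix(i)<suffix(3n+j)$, i.e. $g(S,i)<g(T,j)$.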
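The converse follows likewise: if $f(S,i)>f(T,j)$, the first $n$ characters of the two suffixes already differ with the larger character on the $S$ side, hence $g(S,i)>g(T,j)$.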
// compile: g++ -o data data.cpp -O3 -std=gnu++20 -Wall -Wextra -Wshadow -D_GLIBCXX_ASSERTIONS -ggdb3 -fmax-errors=2 -DLOCAL
// run: ./data < data.in
#include <bits/stdc++.h>
using namespace std;
#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#ifdef LOCAL
#include <debug/codeforces.h>
#define debug(x...) {_variables(#x);_print(x);}
#else
#define debug(x...)
#endif
// Writes every argument to std::cout followed by a single space, then ends
// the line with std::endl (which also flushes).
// The newline is emitted in a separate statement so that a call with no
// arguments is valid and prints just the newline; chaining `<< endl` onto an
// empty comma-fold does not compile because the fold collapses to `void()`.
// Arguments are taken by const reference to avoid copying large values.
template<typename... Args>
void print_(const Args&... args) {
    ((std::cout << args << " "), ...);
    std::cout << std::endl;
}
#define rep(i,a,b) for(int i=(a);i<(int)(b);++i)
#define sz(v) ((int)(v).size())
#define print(...) print_(__VA_ARGS__);
#define INTINF (int)(9223372036854775807)
#define int long long
#define MAXN 200010
// Suffix array by direct comparison: sorts all suffix start positions of `s`
// using a full lexicographic comparison of the two suffixes.
// A suffix that is a proper prefix of another one sorts first.
// Returns the start positions in ascending suffix order.
std::vector<int> sa_naive(const std::vector<int>& s) {
    std::vector<int> order(s.size());
    std::iota(order.begin(), order.end(), 0);

    const auto suffix_less = [&s](int a, int b) {
        // Identical start positions compare equal; for distinct positions the
        // suffixes have different lengths, so the shorter one wins a tie.
        return std::lexicographical_compare(s.begin() + a, s.end(),
                                            s.begin() + b, s.end());
    };
    std::sort(order.begin(), order.end(), suffix_less);
    return order;
}
// Suffix array by prefix doubling: in round k the suffixes are sorted by the
// pair (class of position p, class of position p + k), where a position past
// the end counts as class -1, and the classes are then recomputed from the
// sorted order. The initial classes are the values of `s` themselves.
// Returns the suffix start positions in the order produced by the last round.
std::vector<int> sa_doubling(const std::vector<int>& s) {
    const int n = static_cast<int>(s.size());
    std::vector<int> order(n);
    std::iota(order.begin(), order.end(), 0);

    std::vector<int> cls(s);
    std::vector<int> next_cls(n);

    int k = 1;
    while (k < n) {
        // Sort key of position p for the current step width k.
        const auto key = [&](int p) {
            const int tail = (p + k < n) ? cls[p + k] : -1;
            return std::make_pair(cls[p], tail);
        };
        const auto by_key = [&](int a, int b) { return key(a) < key(b); };

        std::sort(order.begin(), order.end(), by_key);

        // Neighbours with equal keys share a class; a strictly larger key
        // opens the next class.
        next_cls[order[0]] = 0;
        for (int i = 1; i < n; ++i) {
            const int prev = order[i - 1];
            const int cur = order[i];
            next_cls[cur] = next_cls[prev] + (by_key(prev, cur) ? 1 : 0);
        }
        cls.swap(next_cls);
        k *= 2;
    }
    return order;
}
// Builds the suffix array of `s` by induced sorting (SA-IS), delegating short
// inputs to sa_naive / sa_doubling.
//
// Template parameters:
//   THRESHOLD_NAIVE    - inputs with fewer elements than this use sa_naive.
//   THRESHOLD_DOUBLING - inputs with fewer elements than this (and at least
//                        THRESHOLD_NAIVE) use sa_doubling.
// Parameters:
//   s     - input sequence. Its values are used directly as indices into
//           bucket arrays of size upper + 1, so this assumes
//           0 <= s[i] <= upper; nothing here validates that.
//   upper - largest value that may occur in `s`.
// Returns: the start positions of the suffixes of `s` in sorted order.
template <int THRESHOLD_NAIVE = 10, int THRESHOLD_DOUBLING = 40>
std::vector<int> sa_is(const std::vector<int>& s, int upper) {
int n = (int)(s.size());
// Tiny inputs are answered directly.
if (n == 0) return {};
if (n == 1) return {0};
if (n == 2) {
// On a tie the one-element suffix (position 1) is placed first.
if (s[0] < s[1]) {
return {0, 1};
} else {
return {1, 0};
}
}
if (n < THRESHOLD_NAIVE) {
return sa_naive(s);
}
if (n < THRESHOLD_DOUBLING) {
return sa_doubling(s);
}
std::vector<int> sa(n);
// ls[i] == true  : suffix i is smaller than suffix i + 1 ("S-type").
// ls[i] == false : suffix i is larger than suffix i + 1 ("L-type").
// ls[n - 1] keeps its default value false; equal neighbouring values inherit
// the type of the position to their right.
std::vector<bool> ls(n);
for (int i = n - 2; i >= 0; i--) {
ls[i] = (s[i] == s[i + 1]) ? ls[i + 1] : (s[i] < s[i + 1]);
}
// Bucket boundaries per value c, filled by the two loops below:
//   sum_l[c] - number of elements with a value smaller than c, i.e. where
//              the bucket of c (and its L-type entries) begins;
//   sum_s[c] - sum_l[c] plus the number of L-type positions with value c,
//              i.e. where the S-type entries of bucket c begin.
std::vector<int> sum_l(upper + 1), sum_s(upper + 1);
for (int i = 0; i < n; i++) {
if (!ls[i]) {
sum_s[s[i]]++;
} else {
// An S-type position has a strictly larger value somewhere to its right,
// so s[i] < upper is expected and s[i] + 1 stays in range.
sum_l[s[i] + 1]++;
}
}
for (int i = 0; i <= upper; i++) {
sum_s[i] += sum_l[i];
if (i < upper) sum_l[i + 1] += sum_s[i];
}
// Fills `sa` starting from the given list of LMS positions: seed them into
// the S-region of their buckets in list order, then derive the L-type
// entries with a left-to-right scan and the S-type entries with a
// right-to-left scan.
auto induce = [&](const std::vector<int>& lms) {
std::fill(sa.begin(), sa.end(), -1);
std::vector<int> buf(upper + 1);
// Step 1: seed the LMS positions at the start of each bucket's S-region.
std::copy(sum_s.begin(), sum_s.end(), buf.begin());
for (auto d : lms) {
if (d == n) continue;
sa[buf[s[d]]++] = d;
}
// Step 2: forward scan. Position n - 1 is L-type (ls[n - 1] is false) and
// is inserted explicitly; every entry v already present then places v - 1
// at the next free front slot of its bucket when v - 1 is L-type.
std::copy(sum_l.begin(), sum_l.end(), buf.begin());
sa[buf[s[n - 1]]++] = n - 1;
for (int i = 0; i < n; i++) {
int v = sa[i];
if (v >= 1 && !ls[v - 1]) {
sa[buf[s[v - 1]]++] = v - 1;
}
}
// Step 3: backward scan. buf[c + 1] is the start of the next bucket, so
// pre-decrementing it fills bucket c from its end with S-type positions.
std::copy(sum_l.begin(), sum_l.end(), buf.begin());
for (int i = n - 1; i >= 0; i--) {
int v = sa[i];
if (v >= 1 && ls[v - 1]) {
sa[--buf[s[v - 1] + 1]] = v - 1;
}
}
};
// LMS positions: i >= 1 with position i - 1 L-type and position i S-type.
// lms_map[i] is the index of position i among the LMS positions (in text
// order), or -1 if i is not one; `m` counts them.
std::vector<int> lms_map(n + 1, -1);
int m = 0;
for (int i = 1; i < n; i++) {
if (!ls[i - 1] && ls[i]) {
lms_map[i] = m++;
}
}
// The same positions as a list, in text order.
std::vector<int> lms;
lms.reserve(m);
for (int i = 1; i < n; i++) {
if (!ls[i - 1] && ls[i]) {
lms.push_back(i);
}
}
// First pass: induce from the LMS positions in text order.
induce(lms);
if (m) {
// Collect the LMS positions in the order in which they now appear in `sa`.
std::vector<int> sorted_lms;
sorted_lms.reserve(m);
for (int v : sa) {
if (lms_map[v] != -1) sorted_lms.push_back(v);
}
// Give every LMS substring a name in [0, rec_upper]; neighbours in
// sorted_lms whose substrings are equal share a name. rec_s lists the names
// in text order of the LMS positions and forms the reduced problem.
std::vector<int> rec_s(m);
int rec_upper = 0;
rec_s[lms_map[sorted_lms[0]]] = 0;
for (int i = 1; i < m; i++) {
int l = sorted_lms[i - 1], r = sorted_lms[i];
// Each LMS substring runs up to the next LMS position, or to n for the last.
int end_l = (lms_map[l] + 1 < m) ? lms[lms_map[l] + 1] : n;
int end_r = (lms_map[r] + 1 < m) ? lms[lms_map[r] + 1] : n;
bool same = true;
if (end_l - l != end_r - r) {
same = false;
} else {
while (l < end_l) {
if (s[l] != s[r]) {
break;
}
l++;
r++;
}
// Reached either by a mismatch or with l == end_l; in the latter case the
// closing position is compared as well. The l == n test avoids reading
// s[n] when the substring runs to the end of the input.
if (l == n || s[l] != s[r]) same = false;
}
if (!same) rec_upper++;
rec_s[lms_map[sorted_lms[i]]] = rec_upper;
}
// Solve the reduced problem recursively and translate its result back to
// LMS positions of `s`.
auto rec_sa =
sa_is<THRESHOLD_NAIVE, THRESHOLD_DOUBLING>(rec_s, rec_upper);
for (int i = 0; i < m; i++) {
sorted_lms[i] = lms[rec_sa[i]];
}
// Second pass: induce again, now seeded with the reordered LMS positions.
induce(sorted_lms);
}
return sa;
}
// Suffix array of a byte string.
// Converts `s` to a sequence of integer codes in [0, 255] and hands it to
// sa_is with upper = 255.
// Returns the suffix start positions in ascending suffix order.
std::vector<int> suffix_array(const std::string& s) {
    int n = (int)(s.size());
    std::vector<int> s2(n);
    for (int i = 0; i < n; i++) {
        // Go through unsigned char: where plain `char` is signed, a byte
        // >= 0x80 would otherwise become a negative code, and sa_is uses the
        // codes as array indices.
        s2[i] = static_cast<unsigned char>(s[i]);
    }
    return sa_is(s2, 255);
}
int32_t main() {
ios::sync_with_stdio(false); cin.tie(nullptr); cout.tie(nullptr);
int n; string S, T;
cin >> n >> S >> T;
string as(n, 'a'), zs(n, 'z');
string str = S + S + as + T + T + zs;
vector<int> sa = suffix_array(str);
int m = sz(sa);
vector<int> rk(m);
rep(i, 0, m) rk[sa[i]] = i;
vector<int> num(m), sum(m);
rep(i, 0, n) ++num[rk[i]];
rep(i, 0, m) sum[i] = (i ? sum[i-1] + num[i] : num[i]);
int ans = 0;
rep(i, 3*n, 4*n) ans += sum[rk[i]];
cout << ans << endl;
return 0;
}
Consider a suffix array when keywords such as lexicographic order, prefix, or suffix appear.
Consider concatenating a string with itself ($S+S$) when keywords such as rotation, cut, … appear.