Link

Sol

你感觉前面那个 $len(T_i),len(T_j)$ 这两项肯定是拆开直接算的,不假,这就是算所有后缀的长度乘上 $n-1$,那么前面两项就应该等于 $\frac{(n+1)(n-1)n}{2}$

然后现在要计算后面那一坨,肯定先求出 height

某个后缀与其他后缀相同长度的可看成是以 $i$ 位置最大的,向两边单调不增的序列。而且下降的位置一定是出现了一个比它还要小的 height。

考虑每一个 height 的贡献,那就是前面第一个比它大的那个位置到后面第一个那个比它小的位置这一段,再乘上 height。那么这只需要一个单调栈维护即可。

Code

#include <bits/stdc++.h>
using namespace std;

typedef long long LL;
const int SIZE = 5e5 + 5, SIZE_C = 130;

int n;
int sta[SIZE], l[SIZE], r[SIZE];
char str[SIZE];

namespace sufSort {
    int bak[SIZE], sa[SIZE], rk[SIZE], nsa[SIZE], nrk[SIZE], h[SIZE];

    void sort() {
#define cmp(x, y) (rk[x] != rk[y] || rk[x + p] != rk[y + p])
        for (int i = 0; i <= SIZE_C; ++ i) bak[i] = 0;
        for (int i = 1; i <= n; ++ i) ++ bak[str[i]];
        for (int i = 1; i <= SIZE_C; ++ i) bak[i] += bak[i - 1];
        for (int i = 1; i <= n; ++ i) sa[bak[str[i]] --] = i;
        for (int i = 1; i <= n; ++ i) rk[sa[i]] = rk[sa[i - 1]] + (str[sa[i]] != str[sa[i - 1]]);
        for (int p = 1; p <= n; p <<= 1) {
            for (int i = 1; i <= n; ++ i) bak[rk[sa[i]]] = i;
            for (int i = n; i; -- i) {
                if (sa[i] > p) nsa[bak[rk[sa[i] - p]] --] = sa[i] - p;
            }
            for (int i = n; i > n - p; -- i) nsa[bak[rk[i]] --] = i;
            for (int i = 1; i <= n; ++ i) nrk[nsa[i]] = nrk[nsa[i - 1]] + cmp(nsa[i], nsa[i - 1]);
            for (int i = 1; i <= n; ++ i) rk[i] = nrk[i], sa[i] = nsa[i];
            if (rk[sa[n]] >= n) return;
         }
    }

    void get() {
        for (int i = 1, j = 0; i <= n; ++ i) {
            if (j) -- j;
            while (str[i + j] == str[sa[rk[i] - 1] + j]) ++ j;
            h[rk[i]] = j;
        }
    }
} using namespace sufSort;

int main() {
    scanf("%s", str + 1); n = (int) strlen(str + 1);
    sort(); get();
    int top = 0; sta[top = 1] = 1;
    for (int i = 2; i <= n; ++ i) {
        while (top && h[sta[top]] > h[i]) r[sta[top --]] = i;
        l[i] = sta[top], sta[++ top] = i;
    }
    while (top) r[sta[top --]] = n + 1;
    LL ans = 1ll * (n + 1) * (n - 1) * n / 2;
    for (int i = 2; i <= n; ++ i) ans -= 2ll * (r[i] - i) * (i - l[i]) * h[i];
    printf("%lld\n", ans);
    return 0;
}
最后修改:2021 年 09 月 07 日
如果觉得我的文章对你有用,请随意赞赏