Description
在字符串中恰好出现了 $k$ 次的子串中,按照字串的长度分类,求出现数量最多的那一类的长度。
Sol
一个串所有后缀的前缀可以表示所有的子串,一个串的后缀和后缀的 lcp 可以表示相同的子串。
用 SA 求出 height 了以后,一段 height 相等的就可以表示一些相同的子串。
那怎么算是不是恰好有 $k$ 个呢?现在用一个长度为 $k$ 的滑动窗口在整个串上滑,假如当前滑块内的 height 的最小值 $x \neq 0$ ,说明至少有 $k$ 个,否则少于 $k$ 个。
只需要判断, $height[i-k+1]$ 和 $height[i+1]$,如果这两个值中有值大于等于 $x$,这说明此时长度为 $x$ 的前缀超过了 $k$ 个。否则这样长度为 $x$ 的前缀恰好等于 $k$ 个。
因为长度在这两个端点的 $height$ 范围内的子串都是恰好出现 $k$ 次的。
发现出现的长度是一段区间,可以用差分数组维护。
Code
#include <bits/stdc++.h>
using namespace std;
const int SIZE = 2e5 + 5, SIZE_C = 130;
int t, n, k;
int q[SIZE], dif[SIZE];
char str[SIZE];
namespace sufSort {
int bak[SIZE], rk[SIZE], sa[SIZE], nsa[SIZE], nrk[SIZE], h[SIZE];
void init() {
for (int i = 0; i < SIZE; ++ i)
bak[i] = sa[i] = rk[i] = nrk[i] = nsa[i] = h[i] = 0;
}
void sort() {
init();
#define cmp(x, y) (rk[x] != rk[y] || rk[x + p] != rk[y + p])
for (int i = 0; i <= SIZE_C; ++ i) bak[i] = 0;
for (int i = 1; i <= n; ++ i) ++ bak[str[i]];
for (int i = 1; i <= SIZE_C; ++ i) bak[i] += bak[i - 1];
for (int i = 1; i <= n; ++ i) sa[bak[str[i]] --] = i;
for (int i = 1; i <= n; ++ i) rk[sa[i]] = rk[sa[i - 1]] + (str[sa[i]] != str[sa[i - 1]]);
for (int p = 1; p <= n; p <<= 1) {
for (int i = 1; i <= n; ++ i) bak[rk[sa[i]]] = i;
for (int i = n; i; -- i) {
if (sa[i] > p) nsa[bak[rk[sa[i] - p]] --] = sa[i] - p;
}
for (int i = n; i > n - p; -- i) nsa[bak[rk[i]] --] = i;
for (int i = 1; i <= n; ++ i) nrk[nsa[i]] = nrk[nsa[i - 1]] + cmp(nsa[i], nsa[i - 1]);
for (int i = 1; i <= n; ++ i) rk[i] = nrk[i], sa[i] = nsa[i];
if (rk[sa[n]] >= n) return;
}
}
void getH() {
for (int i = 1, j = 0; i <= n; ++ i) {
if (j) -- j;
while (str[i + j] == str[sa[rk[i] - 1] + j]) ++ j;
h[rk[i]] = j;
}
}
} using sufSort::sort; using sufSort::h; using sufSort::sa; using sufSort::getH;
int main() {
scanf("%d", &t);
while (t --) {
scanf("%s", str + 1); scanf("%d", &k);
n = (int) strlen(str + 1);
sort(); getH();
int head = 1, tail = 0;
for (int i = 0; i <= n + 5; ++ i) dif[i] = q[i] = 0;
for (int i = 1; i <= k; ++ i) {
while (head <= tail && h[q[tail]] >= h[i]) -- tail;
q[++ tail] = i;
}
for (int i = k; i <= n; ++ i) {
if (i - q[head] + 1 >= k) ++ head;
while (head <= tail && h[q[tail]] >= h[i]) -- tail;
q[++ tail] = i;
int l = std::max(h[i - k + 1], h[i + 1]), r = (k == 1 ? n - sa[i + k - 1] + 1 : h[q[head]]);
if (l < r) ++ dif[l + 1], -- dif[r + 1];
}
int sum = dif[0], mx = 1, ans = -1;
for (int i = 1; i <= n; ++ i) {
sum += dif[i];
if (sum >= mx) mx = sum, ans = i;
}
printf("%d\n", ans);
}
return 0;
}