扩展KMP
2022/4/2 Leetcode双周赛最后一题模板题, 虽然之前并不知道有这个玩意儿, 但是没写出来很难受, 这里学一下. 主要参考内容是OI WIKI, KMP的主要参考资料是算法4(algorithm 4).
KMP
先从KMP讲起, 将KMP视为有限状态机(DFA)可以更好理解
KMP的核心就是记录上一个与当前状态具有相同前缀(最大)的状态, 这样使得每次字符串不匹配时都尽可能不从头开始, 构建状态转移的代码如下
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
private:
vector<vector<int>> trans;
string pat;
int maxRange;
public:
KMP(string pat, int m)
{
if (pat == "")
{
trans.resize(1, vector<int>(maxRange, 0));
return;
}
this->maxRange = m;
this->pat = pat;
trans.resize(pat.size(), vector<int>(maxRange, 0));
// 初始状态
trans[0][pat[0] - 'a'] = 1;
// 拥有相同前缀的前一状态
int prevStatus = 0;
for (int curStatus = 1; curStatus < pat.size(); curStatus++)
{
for (int next = 0; next < maxRange; next++)
trans[curStatus][next] = trans[prevStatus][next];
// 当前状态进入下一状态
trans[curStatus][pat[curStatus] - 'a'] = curStatus + 1;
// 前一状态接受字符进入前一状态的下一状态
prevStatus = trans[prevStatus][pat[curStatus] - 'a'];
}
}
搜索过程就是正常的状态机运行过程
1
2
3
4
5
6
7
8
9
10
11
12
13
int search(string txt)
{
if (this->pat == "")
return 0;
int cur = 0;
for (int idx = 0; idx < txt.size(); idx++)
{
cur = trans[cur][txt[idx] - 'a'];
if (cur == pat.size())
return idx - pat.size() + 1;
}
return -1;
}
构造状态机的时间复杂度为$O(maxRange * pat.size())$, 空间复杂度同上, 每个状态只会搜索一次, 因此搜索的时间复杂度为$O(txt.size())$
扩展KMP(Z函数)
我们做出如下定义, 函数$z[i]$表示原字符串$s$与$s[i:end]$的最长公共前缀(LCP)的长度, 且规定$z[0] = 0$, 计算$z[i]$时我们会维护一个端点尽可能靠右的匹配段 $[L, R]$, 计算过程会有以下情况:
$[L, R]$这个匹配段的长度也是$z[i]$(我不知道为什么OI wiki上着重强调右端点靠右?), 总之你维护的右端点就是$1\to i$之间的最长的公共前缀所能到达的点
注意, 下方的运算过程中, 能够保证 $L \leq i$
- $i\leq R$: 由匹配段定义, $s[i:r] = s[(i - L):(R - L)]$, 则有$z[i] \geq \min{z[i - L], R - i + 1}$
- 若 $z[i - L] < R - i + 1$, 则 $z[i] = z[i - L]$
- 否则, $z[i] = R - i + 1$, 接着 R 逐步向右直至不满足条件(暴力枚举扩展)
$z[i]$ 的范围是显然的, 因为享有一段相同的公共前缀, $z[i] \geq z[i - L]$; 又由于R右侧的点还没探知, 因此 $z[i] \geq R - i + 1$ (肯定大于等于当前的最长公共前缀)
其次, 若 $z[i - L] < R - i + 1$, 则说明 $i - L$ 对应的 LCP 长度小于当前点至 $R$. 也就是说, 当前点已经找到 LCP, 右端点 R 无需向右扩展; 否则, 当前对应的 LCP 并未查找完全, 还应该向右暴力扩展判断
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// 代码源自 OI wiki
vector<int> z_function(string s) {
int n = (int)s.length();
vector<int> z(n);
for (int i = 1, l = 0, r = 0; i < n; ++i) {
if (i <= r && z[i - l] < r - i + 1) {
z[i] = z[i - l];
} else {
z[i] = max(0, r - i + 1);
while (i + z[i] < n && s[z[i]] == s[i + z[i]]) ++z[i];
}
if (i + z[i] - 1 > r) l = i, r = i + z[i] - 1;
}
return z;
}