CodeForces 1015F - Bracket Substring (KMP + dp)


题目链接

题意:
给你一个括号序列 $s$ 和一个数 $n$ ,让你求出长度为 $2 * n$ 并且 $s$ 是最后串里面的一个子串的合法括号序列有多少种。$mod (1e9 + 7)$。

参考 $blog$

分析:

先确定 $dp[i][j][k]$ 表示最后的串到第 $i$ 位,后缀匹配到长度为 $j$ 的原串,有 $k$ 个未匹配的左括号的方案数有多少个。

为什么要这样定义状态呢,如果只是单纯地定义 $dp[i][k]$ 的话,你不能够确定这些串里面是否含有 $s$(感觉是废话)。

那么转移就是:
$dp[i][j][k]$ $+=$ $\sum_{x}^{所有可能的后缀情况} dp[i - 1][k][x][看当前为是 ‘(‘ 还是 ‘)’,来决定]$

但是这样的转移好像很麻烦,我们换成

$dp[i][当前位填’(‘后对应的原串后缀长度][k + 1]$ $+=$ $dp[i - 1][j][k]$
$dp[i][当前位填’)’后对应的原串后缀长度][k]$ $+=$ $dp[i - 1][j][k + 1]$

是不是感觉很简单,然后这个 当前位填 ‘(‘ or ‘)’ 后对应的原串后缀长度 可以用 $KMP$ 的 $next$ 数组来优化,我直接预处理出来一个数组 $to[x][i]$ 表示当前串后缀已经和原串匹配了长度 $x$ ,接下来填 $‘(’$ $(i = 0)$,$‘)’$ $(i = 1)$ 后,后缀和原串对应的匹配长度。然后就可以愉快的 $dp$ 了。

当匹配到长度和原串相同时,就直接转移到 $len$ 统计答案。

最后答案就是 $dp[2 * n][len][0]$ 。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#include <bits/stdc++.h>
#include <ext/rope>
using namespace __gnu_cxx;
using namespace std;
#define mst(a,b) memset(a,b,sizeof(a))
#define ALL(x) x.begin(),x.end()
#define pii pair<int, int>
#define debug(a) cout << #a": " << a << endl;
#define eularMod(a, b) a < b ? a : a % b + b
inline int lowbit(int x){ return x & -x; }
typedef long long LL;
typedef unsigned long long ULL;
const int N = 1e5 + 10;
const long long mod = 1000000007;
const int INF = 0x3f3f3f3f;
const long long LINF = 0x3f3f3f3f3f3f3f3fLL;
const double PI = acos(-1.0);
const double eps = 1e-6;

char s[205];
int n;
int nex[205];
int to[205][2];
int dp[205][205][205];

int main() {
#ifdef purple_bro
freopen("in.txt", "r", stdin);
// freopen("out.txt","w",stdout);
#endif
scanf("%d%s", &n, s);

n <<= 1;

int len = strlen(s);

nex[0] = 0;
nex[1] = 0;

for (int i = 1; i < len; i++) {
int tmp = nex[i];

for (;tmp && s[tmp] != '(';)
tmp = nex[tmp];

to[i][0] = (s[tmp] == '(') ? tmp + 1 : 0;

tmp = nex[i];

for (;tmp && s[tmp] != ')';)
tmp = nex[tmp];

to[i][1] = (s[tmp] == ')') ? tmp + 1 : 0;

int j = nex[i];

for (;j && s[j] != s[i];)
j = nex[j];

nex[i + 1] = (s[i] == s[j]) ? j + 1 : 0;
}

for (int i = 1; i <= n; i++) {
if (i == 1) {
if (s[0] == '(')
dp[i][1][1] = 1;
else
dp[i][0][1] = 1;
continue;
}

for (int j = 0; j <= len && j < i; j++) {
for (int k = 0; k < i; k++) {
if (j < len) {
if (s[j] == '(')
(dp[i][j + 1][k + 1] += dp[i - 1][j][k]) %= mod;
else
(dp[i][to[j][0]][k + 1] += dp[i - 1][j][k]) %= mod;

if (s[j] == ')')
(dp[i][j + 1][k] += dp[i - 1][j][k + 1]) %= mod;
else
(dp[i][to[j][1]][k] += dp[i - 1][j][k + 1]) %= mod;
} else {
(dp[i][j][k + 1] += dp[i - 1][j][k]) %= mod;
(dp[i][j][k] += dp[i - 1][j][k + 1]) %= mod;
//只转移到 len
}
}
}
}

printf("%d\n", dp[n][len][0]);

return 0;
}