Luckyleaves's Blog

stay hungry,stay foolish greedy and lucky

CF1093F Vasya and Array

序列$dp$计数。

一个长度为n的序列,将序列中的-1替换为$1\sim k$中任意的数后,求不存在任意长度超过$len$的连续的相同字串的合法序列个数

Solution

首先我们定义$f_{i,j}$表示前$i$位,第$i$位填$j$,的合法方案数,又令$sum_i=\sum_{j=1}^kf_{i,j}$,当$i<len$时,此时不可能出现不合法的方案,自然$f_{i,j}=sum_{i-1}$,那么当$i\ge len$的时候,我们考虑每一次都将连续个数刚好等于$len$的贡献剪掉,可以发现由于我们每次都剩余贡献中减去非法贡献,每个非法贡献只会在第一次贡献中减去一次,所以这样做的正确性是有保障的。

接下来考虑怎么减,我们首先先检查当前$i-len+1 \sim i$是否可能出现全是$j$的区间,如果可能那么固定住这段区间,$sum_{i-len}-f_{i-len,j}$就是这段区间所造成的非法贡献(首先$sum_{i-len}$显然就是当前区间造成的总贡献,但是根据每个非法区间只在第一次非法的连续段右端点剪掉的原则,所以这段连续段的长度只能是$len$,所以它的非法贡献也就只有$sum_{i-len}-f_{i-len,j}$),注意$0$要特判。

复杂度$nk$。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#include<iostream>
#include<vector>
#include<algorithm>
#include<cstring>
#include<ctime>
#include<cmath>
#include<cstdio>
#define LL long long
#define uLL unsigned LL
using namespace std;
template <class T>
inline void read(T &res)
{
res = 0; bool flag = 0;
char c = getchar();
while('0' > c || c > '9'){ if(c == '-') flag = 1; c = getchar(); }
while('0' <= c && c <= '9') res = (res << 3) + (res << 1) + (c ^ 48), c = getchar();
if(flag) res = -res;
}
template <class T, class ...ARC>
inline void read(T &res, ARC &...com) { read(res), read(com...); }
inline void write(char c) { putchar(c); }
inline void write(const char *s) { while(*s) putchar(*s ++); }
template <class T>
inline void write(T res)
{
if(res < 0) putchar('-'), res = -res;
if(res > 9) write(res / 10);
putchar(res % 10 + '0');
}
template <class T, class ...ARC>
inline void write(T res, ARC ...com) { write(res), write(com...); }
const int N = 1e5 + 5, mod = 998244353;
int n, k, len;
int a[N], f[N][105];
int cnt[N], tot;
inline void add(int x, int y)
{
if(!cnt[x]) tot ++;
cnt[x] += y;
if(!cnt[x]) tot --;
}
inline bool check(int x)
{
return (tot == 1 && cnt[x] > 0) || tot == 0;
}
int sum[N];
int main()
{
read(n, k, len);
if(len == 1) return puts("0"), 0;
for(int i = 1;i <= n;i ++) read(a[i]);
if(a[1] == -1)
for(int i = 1;i <= k;i ++) f[1][i] = 1;
else f[1][a[1]] = 1;
if(a[1] != -1) add(a[1], 1);
sum[0] = 1;
for(int i = 1;i <= k;i ++) sum[1] += f[1][i];
for(int i = 2, res;i <= n;i ++)
{
if(a[i] != -1) add(a[i], 1);
if(i > len && a[i - len] != -1) add(a[i - len], -1);
if(a[i] == -1)
{
for(int j = 1;j <= k;j ++) f[i][j] = sum[i - 1];
if(i >= len)
{
for(int j = 1;j <= k;j ++)
if(check(j)) f[i][j] = (f[i][j] + mod - sum[i - len]) % mod,
f[i][j] = (f[i][j] + f[i - len][j]) % mod;
}
}
else{
for(int j = 1;j <= k;j ++) f[i][a[i]] = sum[i - 1];
if(i >= len && check(a[i]))
{
f[i][a[i]] = (f[i][a[i]] + mod - sum[i - len]) % mod;
f[i][a[i]] = (f[i][a[i]] + f[i - len][a[i]]) % mod;
}
}
for(int j = 1;j <= k;j ++) sum[i] = (sum[i] + f[i][j]) % mod;
}
write(sum[n]);
return 0;
}