Luckyleaves's Blog

stay hungry,stay foolish greedy and lucky

ARC148E

插入$dp$的神仙题。

想了一晚上无果,mydcwfy讲了10min+才给我讲懂,在此orz

Solution

首先可以糊一个$n^2$的插入$dp$,显然卡常也救不活了,所以我们模拟插入$dp$的思路,可以发现将数列排序之后,从数列的两端开始插入,按照以下方式插入:

  • $a_l+a_r\ge k$将$a_r$插入,并增加段。
  • $a_l+a_r< k$将$a_l$插入,并减少段。

为什么这样构造呢,因为这样就会发现如果当前的$a_l+a_r<k$那么意味着以后的$a_r$不可能再和$a_l$匹配所以$a_l$必须先和能和自己合并的段合并了,所以段会减少$1$,同理,对于$a_l+a_r\ge k$的情况,能和$a_r$匹配的数还有很多,所以我们给他新建一个段,这样构造完了之后会发现错了,而且都是错在有重复元素的位置,现在再重新考虑我们构造数列的方式,是每次往这个数列中加入一个数,那么相同数的可插入的位置肯定会重,枚举排列的话就是多乘了相同数个数的阶乘,除掉就好了。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#include <bits/stdc++.h>
#define LL long long
#define PII pair<LL, LL>
using namespace std;
template <class T>
inline void read(T &res)
{
res = 0;
bool flag = 0;
char c = getchar();
while (c < '0' || '9' < c) { if (c == '-') flag = 1; c = getchar(); }
while ('0' <= c && c <= '9') res = (res << 3) + (res << 1) + (c ^ 48), c = getchar();
if (flag) res = -res;
}
template <class T, class... ARC>
inline void read(T &res, ARC &...com) { read(res), read(com...); }
template <class T>
void write(T res)
{
if (res < 0) putchar('-'), res = -res;
if (res > 9) write(res / 10);
putchar(res % 10 + '0');
}
template <>
inline void write(char c) { putchar(c); }
template <>
inline void write(char *s) { while (*s) putchar(*s++); }
template <class T, class... ARC>
inline void write(T res, ARC... com) { write(res), write(com...); }
const int N = 2e5 + 5, mod = 998244353;
int n, m;
int a[N], ans;
int fac[N << 1], inv[N << 1];
inline void init()
{
fac[0] = 1;
for(int i = 1;i < (N << 1);i ++) fac[i] = 1ll * i * fac[i - 1] % mod;
inv[0] = inv[1] = 1;
for(int i = 2;i < (N << 1);i ++) inv[i] = 1ll * (mod - mod / i) * inv[mod % i] % mod;
for(int i = 2;i < (N << 1);i ++) inv[i] = 1ll * inv[i - 1] * inv[i] % mod;
}
map<int, int>mp;
int main()
{
init();
read(n, m);
for(int i = 1;i <= n;i ++) read(a[i]), mp[a[i]] ++;
sort(a + 1, a + n + 1);
int l = 1, r = n, s = 1;
ans = 1;
while(l <= r)
{
if(a[l] + a[r] < m)
{
ans = 1ll * s * ans % mod;
s --;
l ++;
}
else{
ans = 1ll * s * ans % mod;
s ++;
r --;
}
}
for(int i = 1;i <= n;i ++) ans = 1ll * ans * inv[mp[a[i]]] % mod, mp[a[i]] = 0;
write(ans);
return 0;
}