Luckyleaves's Blog

stay hungry,stay foolish greedy and lucky

P3177 [HAOI2015] 树上染色

感觉比较神秘,怎么均摊下来就剩$O(n^2)$了。

感觉最近博客越欠越多了。

Solution

首先用$f_{i,j}$表示在$i$的子树中选了多少个黑点可获得的最长距离和,转移方程显然应该是:
$$
f_{x,k} = max(f_{x,k}, f_{x,k-cnt}+(f_{j,cnt}+(m-cnt)cnt+(n-m-sz[j]+cnt)(sz[j]-cnt))*w[i])
$$
贡献拆的比较显然,注意这个问题:

枚举顺序,这样的贡献方式类似于01背包,所以$k$一定是倒序枚举的,但是$cnt$如果是倒序枚举的就会出现自调用的的问题,导致自己和自己合并,要特判0,正序枚举就不存在这个问题。

关于复杂度:可以发现每一对点只会在其$lca$处被贡献一次,所以是$O(n^2)$。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#include <bits/stdc++.h>
#define LL long long
#define PII pair<int, int>
#define PLL pair<LL, LL>
using namespace std;
template <class T>
inline void read(T &res)
{
res = 0;
bool flag = 0;
char c = getchar();
while (c < '0' || '9' < c) { if (c == '-') flag = 1; c = getchar(); }
while ('0' <= c && c <= '9') res = (res << 3) + (res << 1) + (c ^ 48), c = getchar();
if (flag) res = -res;
}
template <class T, class... ARC>
inline void read(T &res, ARC &...com) { read(res), read(com...); }
template <class T>
void write(T res)
{
if (res < 0) putchar('-'), res = -res;
if (res > 9) write(res / 10);
putchar(res % 10 + '0');
}
template <>
inline void write(char c) { putchar(c); }
template <>
inline void write(const char *s) { while (*s) putchar(*s++); }
template <class T, class... ARC>
inline void write(T res, ARC... com) { write(res), write(com...); }
const int N = 2005;
int n, m;
int idx;
int e[N << 1], ne[N << 1], h[N], w[N << 1];
inline void add(int x, int y, int z)
{
idx ++;
e[idx] = y, ne[idx] = h[x], h[x] = idx, w[idx] = z;
}
int sz[N];
LL f[N][N]; // 用了多少黑点
inline void dfs(int x, int y)
{
sz[x] = 1;
f[x][0] = f[x][1] = 0;
for(int i = h[x], j; ~i;i = ne[i])
{
j = e[i];
if(j == y) continue;
dfs(j, x);
sz[x] += sz[j];
for(int k = min(sz[x], m);k >= 0;k --)
{
for(int cnt = 0;cnt <= min(k, sz[j]);cnt ++)
{
if(f[x][k - cnt] == -1 || f[j][cnt] == -1) continue;
f[x][k] = max(f[x][k], f[x][k - cnt] + f[j][cnt] + 1ll * (m - cnt) * cnt * w[i] + 1ll * (n - m - sz[j] + cnt) * (sz[j] - cnt) * w[i]);
}
}
}
}
int main()
{
memset(h, -1, sizeof(h));
memset(f, -1, sizeof(f));
read(n, m);
m = min(n, n - m);
for(int i = 1, o, u, p;i < n;i ++) read(o, u, p), add(o, u, p), add(u, o, p);
dfs(1, 1);
write(f[1][m]);
return 0;
}

后记

然后尼玛这个做法是假的,树上背包必须先计算,后再加$sz$。(早年的题就是水啊)

New Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#include <bits/stdc++.h>
#define LL long long
#define PII pair<int, int>
#define PLL pair<LL, LL>
using namespace std;
template <class T>
inline void read(T &res)
{
res = 0;
bool flag = 0;
char c = getchar();
while (c < '0' || '9' < c) { if (c == '-') flag = 1; c = getchar(); }
while ('0' <= c && c <= '9') res = (res << 3) + (res << 1) + (c ^ 48), c = getchar();
if (flag) res = -res;
}
template <class T, class... ARC>
inline void read(T &res, ARC &...com) { read(res), read(com...); }
template <class T>
void write(T res)
{
if (res < 0) putchar('-'), res = -res;
if (res > 9) write(res / 10);
putchar(res % 10 + '0');
}
template <>
inline void write(char c) { putchar(c); }
template <>
inline void write(const char *s) { while (*s) putchar(*s++); }
template <class T, class... ARC>
inline void write(T res, ARC... com) { write(res), write(com...); }
const int N = 2005;
int n, m;
int idx;
int e[N << 1], ne[N << 1], h[N], w[N << 1];
inline void add(int x, int y, int z)
{
idx ++;
e[idx] = y, ne[idx] = h[x], h[x] = idx, w[idx] = z;
}
int sz[N];
LL f[N][N]; // 用了多少黑点
inline void dfs(int x, int y)
{
sz[x] = 1;
f[x][0] = f[x][1] = 0;
static LL d[N];
for(int i = h[x], j; ~i;i = ne[i])
{
j = e[i];
if(j == y) continue;
dfs(j, x);
for(int k = 0;k <= sz[j] + sz[x];k ++) d[k] = 0;
for(int k = 0;k <= sz[x];k ++)
for(int cnt = 0;cnt <= sz[j];cnt ++)
d[k + cnt] = max(d[k + cnt], f[x][k] + f[j][cnt] + 1ll * (m - cnt) * cnt * w[i] + 1ll * (n - m - sz[j] + cnt) * (sz[j] - cnt) * w[i]);
sz[x] += sz[j];
for(int k = 0;k <= sz[x];k ++) f[x][k] = d[k];
}
}
int main()
{
memset(h, -1, sizeof(h));
memset(f, 0x3f, sizeof(f));
read(n, m);
m = min(n, n - m);
for(int i = 1, o, u, p;i < n;i ++) read(o, u, p), add(o, u, p), add(u, o, p);
dfs(1, 1);
write(f[1][m]);
return 0;
}