// Tree rerooting DP.
//
// Input (stdin):  n k, followed by n-1 edges "u v" of a tree on vertices 1..n.
// Output (stdout): a single integer (no trailing newline).
//
// Each edge (p, c), where c is the child of p when the tree is rooted at 1, is
// weighted by size(c) * (n - size(c)), i.e. the number of unordered vertex
// pairs whose connecting path uses that edge. For a vertex x, the endpoint of
// an edge that is closer to x is its "near endpoint". The program prints
//     max over x of  (sum of weights of edges whose near endpoint lies at
//                     distance 0..k from x).
// NOTE(review): the original problem statement is not visible; the description
// above is derived from what the code computes, not from the task text.

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <numeric>
#include <utility>
#include <vector>

using LL = long long;

// Computes the quantity described in the file header for a tree on vertices
// 1..n given as an edge list. Edges with an endpoint outside [1, n] are
// ignored. Returns 0 when k < 0 or n <= 1 (no edge can be counted).
static LL solve(int n, int k, const std::vector<std::pair<int, int>>& edges) {
    if (n <= 1 || k < 0) return 0;

    // A near endpoint is never farther than n-2 from any vertex, so distances
    // beyond that contribute nothing; clamping bounds the DP table size
    // (the original used a fixed 205-column table and overran it for k > 204).
    const int depth = std::min(k, n - 2);
    const std::size_t width = static_cast<std::size_t>(depth) + 1;

    std::vector<std::vector<int>> adj(n + 1);
    for (const std::pair<int, int>& e : edges) {
        const int u = e.first, v = e.second;
        if (u < 1 || u > n || v < 1 || v > n) continue;  // avoid out-of-range indexing
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    // Iterative BFS from vertex 1 (no recursion, so long paths cannot blow the
    // stack). Parents always precede their children in `order`.
    std::vector<int> parent(n + 1, -1);  // -1 = not yet visited, 0 = root's parent
    std::vector<int> order;
    order.reserve(n);
    parent[1] = 0;
    order.push_back(1);
    for (std::size_t i = 0; i < order.size(); ++i) {
        const int x = order[i];
        for (int y : adj[x]) {
            if (parent[y] != -1) continue;  // already visited (includes x's parent)
            parent[y] = x;
            order.push_back(y);
        }
    }

    // Subtree sizes for the tree rooted at 1, accumulated children-first.
    std::vector<LL> sub(n + 1, 0);
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int x = *it;
        sub[x] += 1;
        if (parent[x] > 0) sub[parent[x]] += sub[x];
    }

    // Weight of the edge between c and its parent: paths crossing that edge.
    auto weight = [&](int c) -> LL { return sub[c] * (n - sub[c]); };

    // Flat (n+1) x width table; row(x)[d] is f[x][d].
    std::vector<LL> f(static_cast<std::size_t>(n + 1) * width, 0);
    auto row = [&](int x) -> LL* { return f.data() + static_cast<std::size_t>(x) * width; };

    // Bottom-up pass: row(x)[d] = total weight of edges inside x's subtree
    // whose upper endpoint is d levels below x.
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int c = *it, p = parent[c];
        if (p <= 0) continue;  // root has no parent edge
        LL* fp = row(p);
        const LL* fc = row(c);
        fp[0] += weight(c);
        for (int d = 1; d <= depth; ++d) fp[d] += fc[d - 1];
    }

    // Top-down re-rooting: when c is reached, row(p) already describes the
    // whole tree seen from p, so c's row can be extended with p's side.
    for (int c : order) {
        const int p = parent[c];
        if (p <= 0) continue;
        const LL* fp = row(p);
        LL* fc = row(c);
        const LL w = weight(c);
        // Descend in d so fc[d - 2] still holds c's subtree-only value when read.
        for (int d = depth; d >= 1; --d) {
            fc[d] += fp[d - 1];
            if (d >= 2) fc[d] -= fc[d - 2];  // drop c's own subtree, counted in fp via c
        }
        fc[0] += w;                   // edge (p, c) has near endpoint c itself
        if (depth >= 1) fc[1] -= w;   // ...but fp[0] counted it at distance 1 above
    }

    LL best = 0;
    for (int x : order) {
        const LL* fx = row(x);
        best = std::max(best, std::accumulate(fx, fx + width, 0LL));
    }
    return best;
}

int main() {
    int n = 0, k = 0;
    // On malformed input n/k stay 0 and the program prints 0, as before.
    std::scanf("%d%d", &n, &k);

    std::vector<std::pair<int, int>> edges;
    if (n > 1) edges.reserve(static_cast<std::size_t>(n - 1));
    for (int i = 1; i < n; ++i) {
        int u = 0, v = 0;
        // Stop on truncated input instead of using uninitialised values.
        if (std::scanf("%d%d", &u, &v) != 2) break;
        edges.emplace_back(u, v);
    }

    std::printf("%lld", solve(n, k, edges));
    return 0;
}