|
| 1 | +# https://leetcode.com/problems/maximize-sum-of-weights-after-edge-removals/ |
| 2 | +# Tricky tree DP |
| 3 | +# Maximum the sum at the end means minimize the removed sum |
| 4 | + |
| 5 | +def maximizeSumOfWeights(edges: list[list[int]], k: int) -> int: |
| 6 | + inf = 1 << 60 |
| 7 | + n = len(edges) + 1 |
| 8 | + graph = [[] for _ in range(n + 1)] |
| 9 | + for a, b, c in edges: |
| 10 | + graph[a].append((b, c)) |
| 11 | + graph[b].append((a, c)) |
| 12 | + |
| 13 | + def compute_cost(arr, k): |
| 14 | + # we can take at most `k` from `a` |
| 15 | + # greedily minimize cost, we take the `a` with the biggest difference compared to `b` |
| 16 | + arr.sort(key=lambda x: x[1] - x[0], reverse=True) |
| 17 | + return sum(a for a, b in arr[:k]) + sum(b for a, b in arr[k:]) |
| 18 | + |
| 19 | + def solve(cur, prev): |
| 20 | + if len(graph[cur]) == 1 and cur != 0: # base case: leaf |
| 21 | + return 0, graph[cur][0][1] |
| 22 | + |
| 23 | + options = [] |
| 24 | + par_val = inf # cost to remove edge to parent node |
| 25 | + for adj, val in graph[cur]: |
| 26 | + if adj == prev: |
| 27 | + par_val = val |
| 28 | + continue |
| 29 | + c1, c2 = solve(adj, cur) |
| 30 | + options.append((c1, c2)) |
| 31 | + |
| 32 | + # calculate cheapest if we keep the parent edge |
| 33 | + if cur == 0: |
| 34 | + # at root, no parent edge to consider |
| 35 | + cost1 = compute_cost(options, k) |
| 36 | + else: |
| 37 | + # any other node, `k-1` to keep parent edge |
| 38 | + cost1 = compute_cost(options, k - 1) |
| 39 | + |
| 40 | + # calculate cheapest with parent edge removed |
| 41 | + cost2 = par_val + compute_cost(options, k) |
| 42 | + |
| 43 | + # return (overall cheapest, cheapest if we remove the parent edge) |
| 44 | + return min(cost1, cost2), cost2 |
| 45 | + |
| 46 | + return sum(i for _, _, i in edges) - solve(0, -1)[0] |
| 47 | + |
| 48 | + |
| 49 | +print(maximizeSumOfWeights(edges=[[0, 1, 4], [0, 2, 2], [2, 3, 12], [2, 4, 6]], k=2)) # 22 |
| 50 | +print(maximizeSumOfWeights(edges=[[0, 1, 5], [1, 2, 10], [0, 3, 15], [3, 4, 20], [3, 5, 5], [0, 6, 10]], k=3)) # 65 |
| 51 | + |
| 52 | +print(maximizeSumOfWeights([[0, 1, 25], [0, 2, 10], [1, 3, 29]], 1)) # 39 |
| 53 | +print(maximizeSumOfWeights([[0, 1, 34], [0, 2, 17]], 1)) # 34 |
0 commit comments