Skip to content

Commit ef74fa1

Browse files
committed
upsolve
1 parent f0c8885 commit ef74fa1

File tree

1 file changed

+53
-0
lines changed

1 file changed

+53
-0
lines changed
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# https://leetcode.com/problems/maximize-sum-of-weights-after-edge-removals/
2+
# Tricky tree DP
3+
# Maximum the sum at the end means minimize the removed sum
4+
5+
def maximizeSumOfWeights(edges: list[list[int]], k: int) -> int:
6+
inf = 1 << 60
7+
n = len(edges) + 1
8+
graph = [[] for _ in range(n + 1)]
9+
for a, b, c in edges:
10+
graph[a].append((b, c))
11+
graph[b].append((a, c))
12+
13+
def compute_cost(arr, k):
14+
# we can take at most `k` from `a`
15+
# greedily minimize cost, we take the `a` with the biggest difference compared to `b`
16+
arr.sort(key=lambda x: x[1] - x[0], reverse=True)
17+
return sum(a for a, b in arr[:k]) + sum(b for a, b in arr[k:])
18+
19+
def solve(cur, prev):
20+
if len(graph[cur]) == 1 and cur != 0: # base case: leaf
21+
return 0, graph[cur][0][1]
22+
23+
options = []
24+
par_val = inf # cost to remove edge to parent node
25+
for adj, val in graph[cur]:
26+
if adj == prev:
27+
par_val = val
28+
continue
29+
c1, c2 = solve(adj, cur)
30+
options.append((c1, c2))
31+
32+
# calculate cheapest if we keep the parent edge
33+
if cur == 0:
34+
# at root, no parent edge to consider
35+
cost1 = compute_cost(options, k)
36+
else:
37+
# any other node, `k-1` to keep parent edge
38+
cost1 = compute_cost(options, k - 1)
39+
40+
# calculate cheapest with parent edge removed
41+
cost2 = par_val + compute_cost(options, k)
42+
43+
# return (overall cheapest, cheapest if we remove the parent edge)
44+
return min(cost1, cost2), cost2
45+
46+
return sum(i for _, _, i in edges) - solve(0, -1)[0]
47+
48+
49+
print(maximizeSumOfWeights(edges=[[0, 1, 4], [0, 2, 2], [2, 3, 12], [2, 4, 6]], k=2)) # 22
50+
print(maximizeSumOfWeights(edges=[[0, 1, 5], [1, 2, 10], [0, 3, 15], [3, 4, 20], [3, 5, 5], [0, 6, 10]], k=3)) # 65
51+
52+
print(maximizeSumOfWeights([[0, 1, 25], [0, 2, 10], [1, 3, 29]], 1)) # 39
53+
print(maximizeSumOfWeights([[0, 1, 34], [0, 2, 17]], 1)) # 34

0 commit comments

Comments
 (0)