Skip to content

Commit dd593e7

Browse files
authored
Merge pull request #224 from raspersc2/fix/torches-expansion-locations
Fix TorchesAIE expansion locations
2 parents a747fc6 + a6756d9 commit dd593e7

File tree

7 files changed

+192
-36
lines changed

7 files changed

+192
-36
lines changed

sc2/bot_ai.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -171,12 +171,13 @@ def expansion_locations_dict(self) -> dict[Point2, Units]:
171171
expansion_locations: dict[Point2, Units] = {pos: Units([], self) for pos in self._expansion_positions_list}
172172
for resource in self.resources:
173173
# It may be that some resources are not mapped to an expansion location
174-
exp_position: Point2 | None = self._resource_location_to_expansion_position_dict.get(
174+
exp_positions: set[Point2] | None = self._resource_location_to_expansion_position_dict.get(
175175
resource.position, None
176176
)
177-
if exp_position:
178-
assert exp_position in expansion_locations
179-
expansion_locations[exp_position].append(resource)
177+
if exp_positions:
178+
for exp_position in exp_positions:
179+
assert exp_position in expansion_locations
180+
expansion_locations[exp_position].append(resource)
180181
return expansion_locations
181182

182183
@property

sc2/bot_ai_internal.py

Lines changed: 180 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def _initialize_variables(self) -> None:
119119
self._all_units_previous_map: dict[int, Unit] = {}
120120
self._previous_upgrades: set[UpgradeId] = set()
121121
self._expansion_positions_list: list[Point2] = []
122-
self._resource_location_to_expansion_position_dict: dict[Point2, Point2] = {}
122+
self._resource_location_to_expansion_position_dict: dict[Point2, set[Point2]] = {}
123123
self._time_before_step: float = 0
124124
self._time_after_step: float = 0
125125
self._min_step_time: float = math.inf
@@ -177,14 +177,154 @@ def expansion_locations(self) -> dict[Point2, Units]:
177177
)
178178
return self.expansion_locations_dict
179179

180+
def _cluster_center(self, group: list[Unit]) -> Point2:
181+
"""
182+
Calculates the geometric center (centroid) of a given group of units.
183+
184+
Parameters:
185+
group: A list of Unit objects representing the group of units for
186+
which the center is to be calculated.
187+
188+
Raises:
189+
ValueError: If the provided group is empty.
190+
191+
Returns:
192+
Point2: The calculated centroid of the group as a Point2 object.
193+
"""
194+
if not group:
195+
raise ValueError("Cannot calculate center of empty group")
196+
197+
total_x = total_y = 0
198+
for unit in group:
199+
total_x += unit.position.x
200+
total_y += unit.position.y
201+
202+
count = len(group)
203+
return Point2((total_x / count, total_y / count))
204+
205+
def _find_expansion_location(
206+
self, resources: Units | list[Unit], amount: int, offsets: list[tuple[float, float]]
207+
) -> Point2:
208+
"""
209+
Finds the most suitable expansion location for resources.
210+
211+
Parameters:
212+
resources: The list of resource entities or units near which the
213+
expansion location needs to be found.
214+
amount: The total number of resource entities or units to consider.
215+
offsets (list[tuple[float, float]): A list of coordinate pairs denoting position
216+
offsets to consider around the center of resources.
217+
218+
Returns:
219+
The calculated optimal expansion Point2 if a suitable position is found;
220+
otherwise, None.
221+
"""
222+
# Normal single expansion logic for regular bases
223+
# Calculate center, round and add 0.5 because expansion location will have (x.5, y.5)
224+
# coordinates because bases have size 5.
225+
center_x = int(sum(resource.position.x for resource in resources) / amount) + 0.5
226+
center_y = int(sum(resource.position.y for resource in resources) / amount) + 0.5
227+
possible_points = (Point2((offset[0] + center_x, offset[1] + center_y)) for offset in offsets)
228+
# Filter out points that are too near
229+
possible_points = [
230+
point
231+
for point in possible_points
232+
# Check if point can be built on
233+
if self.game_info.placement_grid[point.rounded] == 1
234+
# Check if all resources have enough space to point
235+
and all(
236+
point.distance_to(resource) >= (7 if resource._proto.unit_type in geyser_ids else 6)
237+
for resource in resources
238+
)
239+
]
240+
# Choose best fitting point
241+
result: Point2 = min(
242+
possible_points, key=lambda point: sum(point.distance_to(resource_) for resource_ in resources)
243+
)
244+
return result
245+
246+
def _has_opposite_side_geyser_layout(self, minerals: list[Unit], gas_geysers: list[Unit]) -> bool:
247+
"""
248+
Determines whether the gas geysers have an opposite-side mineral line layout.
249+
250+
The method evaluates if two gas geysers are located on opposite sides of a
251+
mineral line.
252+
If this returns True we consider this location has 2 valid expansion locations
253+
either side of the mineral line.
254+
255+
Parameters:
256+
minerals:
257+
A list of mineral fields at this location.
258+
gas_geysers : list[Unit]
259+
A list of gas geysers at this location.
260+
261+
Returns:
262+
bool
263+
True if the geysers fulfill the opposite-side layout condition with
264+
respect to the mineral line, otherwise False.
265+
"""
266+
# Need exactly 2 geysers and enough minerals for a line
267+
if len(gas_geysers) != 2 or len(minerals) < 6:
268+
return False
269+
270+
# Find the two minerals that are furthest apart
271+
max_distance: float = 0.0
272+
mineral_1: Unit = minerals[0]
273+
mineral_2: Unit = minerals[1]
274+
275+
for i, m1 in enumerate(minerals):
276+
for m2 in minerals[i + 1 :]:
277+
distance = m1.distance_to(m2)
278+
if distance > max_distance:
279+
max_distance = distance
280+
mineral_1 = m1
281+
mineral_2 = m2
282+
283+
# ensure line is long enough
284+
if max_distance < 4:
285+
return False
286+
287+
# Create line from the two furthest minerals
288+
x1, y1 = mineral_1.position.x, mineral_1.position.y
289+
x2, y2 = mineral_2.position.x, mineral_2.position.y
290+
291+
geyser_1, geyser_2 = gas_geysers
292+
293+
# Check if the mineral line is more vertical than horizontal
294+
if abs(x2 - x1) < 0.1:
295+
# Vertical line: use x-coordinate to determine sides
296+
line_x = (x1 + x2) / 2
297+
298+
side_1 = geyser_1.position.x - line_x
299+
side_2 = geyser_2.position.x - line_x
300+
301+
# Must be on opposite sides and far enough from the line
302+
return side_1 * side_2 < 0 and abs(side_1) > 3 and abs(side_2) > 3
303+
304+
# Calculate line equation: y = mx + b
305+
slope = (y2 - y1) / (x2 - x1)
306+
intercept = y1 - slope * x1
307+
308+
# Function to determine which side of the line a point is on
309+
def side_of_line(point: Point2) -> float:
310+
return point.y - slope * point.x - intercept
311+
312+
side_1 = side_of_line(geyser_1.position)
313+
side_2 = side_of_line(geyser_2.position)
314+
315+
# Check if geysers are on opposite sides
316+
opposite_sides = side_1 * side_2 < 0
317+
318+
return opposite_sides
319+
180320
@final
181321
def _find_expansion_locations(self) -> None:
182322
"""Ran once at the start of the game to calculate expansion locations."""
183323
# Idea: create a group for every resource, then merge these groups if
184324
# any resource in a group is closer than a threshold to any resource of another group
185325

186326
# Distance we group resources by
187-
resource_spread_threshold: float = 8.5
327+
resource_spread_threshold: float = 10.5
188328
# Create a group for every resource
189329
resource_groups: list[list[Unit]] = [
190330
[resource]
@@ -200,22 +340,23 @@ def _find_expansion_locations(self) -> None:
200340
for group_a, group_b in itertools.combinations(resource_groups, 2):
201341
# Check if any pair of resource of these groups is closer than threshold together
202342
# And that they are on the same terrain level
203-
if any(
204-
resource_a.distance_to(resource_b) <= resource_spread_threshold
205-
# check if terrain height measurement at resources is within 10 units
206-
# this is since some older maps have inconsistent terrain height
207-
# tiles at certain expansion locations
208-
and abs(height_grid[resource_a.position.rounded] - height_grid[resource_b.position.rounded]) <= 10
209-
for resource_a, resource_b in itertools.product(group_a, group_b)
343+
center_a = self._cluster_center(group_a)
344+
center_b = self._cluster_center(group_b)
345+
346+
if center_a.distance_to(center_b) <= resource_spread_threshold and all(
347+
abs(height_grid[res_a.position.rounded] - height_grid[res_b.position.rounded]) <= 10
348+
for res_a in group_a
349+
for res_b in group_b
210350
):
211351
# Remove the single groups and add the merged group
212352
resource_groups.remove(group_a)
213353
resource_groups.remove(group_b)
214354
resource_groups.append(group_a + group_b)
215355
merged_group = True
216356
break
357+
217358
# Distance offsets we apply to center of each resource group to find expansion position
218-
offset_range = 7
359+
offset_range: int = 7
219360
offsets = [
220361
(x, y)
221362
for x, y in itertools.product(range(-offset_range, offset_range + 1), repeat=2)
@@ -227,33 +368,41 @@ def _find_expansion_locations(self) -> None:
227368
for resources in resource_groups:
228369
# Possible expansion points
229370
amount = len(resources)
230-
# Calculate center, round and add 0.5 because expansion location will have (x.5, y.5)
231-
# coordinates because bases have size 5.
232-
center_x = int(sum(resource.position.x for resource in resources) / amount) + 0.5
233-
center_y = int(sum(resource.position.y for resource in resources) / amount) + 0.5
234-
possible_points = (Point2((offset[0] + center_x, offset[1] + center_y)) for offset in offsets)
235-
# Filter out points that are too near
236-
possible_points = (
237-
point
238-
for point in possible_points
239-
# Check if point can be built on
240-
if self.game_info.placement_grid[point.rounded] == 1
241-
# Check if all resources have enough space to point
242-
and all(
243-
point.distance_to(resource) >= (7 if resource._proto.unit_type in geyser_ids else 6)
244-
for resource in resources
245-
)
246-
)
371+
# this check is needed for TorchesAIE where the gold mineral wall has a
372+
# unit type of `RichMineralField` so we can only filter out by amount of resources
373+
if amount > 12:
374+
continue
375+
376+
minerals = [r for r in resources if r._proto.unit_type not in geyser_ids]
377+
gas_geysers = [r for r in resources if r._proto.unit_type in geyser_ids]
378+
379+
# Check if we have exactly 2 gas geysers positioned above/below the mineral line
380+
# Needed for TorchesAIE where one gold base has 2 expansion locations
381+
if self._has_opposite_side_geyser_layout(minerals, gas_geysers):
382+
# Create expansion locations for each geyser + minerals
383+
for geyser in gas_geysers:
384+
local_resources = minerals + [geyser]
385+
result: Point2 = self._find_expansion_location(local_resources, len(local_resources), offsets)
386+
centers[result] = local_resources
387+
# Put all expansion locations in a list
388+
self._expansion_positions_list.append(result)
389+
# Maps all resource positions to the expansion position
390+
for resource in local_resources:
391+
if resource.position in self._resource_location_to_expansion_position_dict:
392+
self._resource_location_to_expansion_position_dict[resource.position].add(result)
393+
else:
394+
self._resource_location_to_expansion_position_dict[resource.position] = {result}
395+
396+
continue
397+
247398
# Choose best fitting point
248-
result: Point2 = min(
249-
possible_points, key=lambda point: sum(point.distance_to(resource_) for resource_ in resources)
250-
)
399+
result: Point2 = self._find_expansion_location(resources, amount, offsets)
251400
centers[result] = resources
252401
# Put all expansion locations in a list
253402
self._expansion_positions_list.append(result)
254403
# Maps all resource positions to the expansion position
255404
for resource in resources:
256-
self._resource_location_to_expansion_position_dict[resource.position] = result
405+
self._resource_location_to_expansion_position_dict[resource.position] = {result}
257406

258407
@final
259408
def _correct_zerg_supply(self) -> None:

test/generate_pickle_files_bot.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,11 +175,13 @@ def main():
175175
"OxideAIE",
176176
"PaladinoTerminalLE",
177177
"ParaSiteLE",
178+
"PersephoneAIE",
178179
"PillarsofGold506",
179180
"PillarsofGoldLE",
180181
"PortAleksanderLE",
181182
"PrimusQ9",
182183
"ProximaStationLE",
184+
"PylonAIE",
183185
"RedshiftLE",
184186
"Reminiscence",
185187
"RomanticideAIE",
@@ -193,6 +195,7 @@ def main():
193195
"StasisLE",
194196
"TheTimelessVoid",
195197
"ThunderbirdLE",
198+
"TorchesAIE",
196199
"Treachery",
197200
"Triton",
198201
"Urzagol",

test/pickle_data/PersephoneAIE.xz

53.8 KB
Binary file not shown.

test/pickle_data/PylonAIE.xz

54.2 KB
Binary file not shown.

test/pickle_data/TorchesAIE.xz

53.6 KB
Binary file not shown.

test/test_pickled_ramp.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from loguru import logger
1515

1616
from sc2.game_info import Ramp
17+
from sc2.ids.unit_typeid import UnitTypeId
1718
from sc2.position import Point2
1819
from sc2.unit import Unit
1920
from sc2.units import Units
@@ -36,6 +37,8 @@ class TestClass:
3637
# Load all pickle files and convert them into bot objects from raw data (game_data, game_info, game_state)
3738
scenarios = [(map_path.name, {"map_path": map_path}) for map_path in MAPS]
3839

40+
MAPS_WITH_ODD_EXPANSION_COUNT: set[UnitTypeId] = {"Persephone AIE", "StargazersAIE", "Stasis LE"}
41+
3942
def test_main_base_ramp(self, map_path: Path):
4043
bot = get_map_specific_bot(map_path)
4144
# pyre-ignore[16]
@@ -105,7 +108,7 @@ def test_bot_ai(self, map_path: Path):
105108
# On N player maps, it is expected that there are N*X bases because of symmetry, at least for maps designed for 1vs1
106109
# Those maps in the list have an un-even expansion count
107110
# pyre-ignore[16]
108-
expect_even_expansion_count = 1 if bot.game_info.map_name in ["StargazersAIE", "Stasis LE"] else 0
111+
expect_even_expansion_count = 1 if bot.game_info.map_name in self.MAPS_WITH_ODD_EXPANSION_COUNT else 0
109112
assert (
110113
len(bot.expansion_locations_list) % (len(bot.enemy_start_locations) + 1) == expect_even_expansion_count
111114
), f"{bot.expansion_locations_list}"

0 commit comments

Comments
 (0)