Skip to content

Commit a7ab5de

Browse files
committed
typing
1 parent d1ded36 commit a7ab5de

File tree

4 files changed

+13
-21
lines changed

4 files changed

+13
-21
lines changed

malsim/mal_simulator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def from_scenario(
230230
) -> MalSimulator:
231231
"""Create a MalSimulator object from a Scenario"""
232232

233-
def register_agent_dict(agent_config: dict):
233+
def register_agent_dict(agent_config: dict[str, Any]) -> None:
234234
"""Register an agent specified in a dictionary"""
235235
logger.warning(
236236
"Having agent configs in dictionaries will be deprecated in "
@@ -245,7 +245,7 @@ def register_agent_dict(agent_config: dict):
245245
elif agent_config['type'] == AgentType.DEFENDER:
246246
sim.register_defender(agent_config['name'])
247247

248-
def register_agent_config(agent_config: AgentConfig):
248+
def register_agent_config(agent_config: AgentConfig) -> None:
249249
"""Register an agent config in simulator"""
250250
if isinstance(agent_config, AttackerAgentConfig):
251251
sim.register_attacker(

malsim/scenario.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
- attacker_class
1111
- defender_class
1212
"""
13-
13+
from __future__ import annotations
1414
import os
1515
from dataclasses import dataclass, asdict
1616
from typing import Any, Optional, TextIO
@@ -110,7 +110,7 @@ def __init__(
110110
self,
111111
lang_file: str,
112112
agents: dict[str, Any],
113-
model_dict: Optional[dict] = None,
113+
model_dict: Optional[dict[str, Any]] = None,
114114
model_file: Optional[str] = None,
115115
rewards: Optional[dict[str, Any]] = None,
116116
false_positive_rates: Optional[dict[str, Any]] = None,
@@ -157,7 +157,7 @@ def __init__(
157157
self.attack_graph, is_actionable or {}
158158
)
159159

160-
def to_dict(self):
160+
def to_dict(self) -> dict[str, Any]:
161161
assert self._lang_file, "Can not save scenario to file if lang file was not given"
162162
scenario_dict = {
163163
# 'version': ?
@@ -179,14 +179,14 @@ def to_dict(self):
179179

180180
return scenario_dict
181181

182-
def save_to_file(self, file_path):
182+
def save_to_file(self, file_path: str) -> None:
183183
save_scenario_dict(self.to_dict(), file_path)
184184

185185
@classmethod
186-
def from_dict(cls, scenario_dict):
186+
def from_dict(cls, scenario_dict: dict[str, Any]) -> Scenario:
187187
return Scenario(
188-
lang_file=scenario_dict.get('lang_file'),
189-
agents=scenario_dict.get('agents'),
188+
lang_file=scenario_dict['lang_file'],
189+
agents=scenario_dict['agents'],
190190
model_dict=scenario_dict.get('model'),
191191
model_file=scenario_dict.get('model_file'),
192192
rewards=scenario_dict.get('rewards'),
@@ -197,7 +197,7 @@ def from_dict(cls, scenario_dict):
197197
)
198198

199199
@classmethod
200-
def load_from_file(cls, scenario_file):
200+
def load_from_file(cls, scenario_file: str) -> Scenario:
201201
scenario_dict = load_scenario_dict(scenario_file)
202202
_validate_scenario_dict(scenario_dict)
203203
return cls.from_dict(scenario_dict)
@@ -439,7 +439,6 @@ def get_entry_point_nodes(
439439
def load_simulator_agents(
440440
attack_graph: AttackGraph,
441441
scenario_agents: dict[str, Any],
442-
as_dicts=False
443442
) -> list[AgentConfig]:
444443
"""Load agents to be registered in MALSimulator
445444

tests/envs/test_vectorized_obs_mal_simulator.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,6 @@ def test_create_blank_observation_observability_given(
176176
elif node.model_asset and node.model_asset.name == 'User:3' and node.name in ('phishing'):
177177
assert observable
178178
else:
179-
if observable:
180-
breakpoint()
181179
assert not observable
182180

183181
def test_create_blank_observation_actionability_given(

tests/test_scenario.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from maltoolbox.attackgraph import create_attack_graph
88
from malsim.scenario import (
9+
AttackerAgentConfig,
910
apply_scenario_node_property,
1011
load_scenario,
1112
_validate_scenario_node_property_config
@@ -53,6 +54,7 @@ def test_load_scenario() -> None:
5354

5455
# Verify attacker entrypoint was added
5556
attack_step = get_node(scenario.attack_graph, 'OS App:fullAccess')
57+
assert isinstance(scenario.agents[0], AttackerAgentConfig)
5658
assert attack_step in scenario.agents[0].entry_points
5759

5860
assert isinstance(scenario.agents[0].agent, BreadthFirstAttacker)
@@ -232,7 +234,7 @@ def test_apply_scenario_observability() -> None:
232234

233235
# Apply observability rules
234236
observable = apply_scenario_node_property(
235-
scenario.attack_graph, observability_rules, default_value = 0
237+
scenario.attack_graph, observability_rules,
236238
)
237239

238240
# Make sure all attack steps are observable
@@ -261,53 +263,46 @@ def test_apply_scenario_observability_faulty() -> None:
261263
apply_scenario_node_property(
262264
scenario.attack_graph,
263265
{'NotAllowedKey': {'Data': ['read', 'write', 'delete']}},
264-
default_value = 0
265266
)
266267

267268
# Correct asset type and attack step
268269
apply_scenario_node_property(
269270
scenario.attack_graph,
270271
{'by_asset_type': { 'Application': ['read']}},
271-
default_value = 0
272272
)
273273

274274
# Wrong asset type in rule asset type to step dict
275275
with pytest.raises(AssertionError):
276276
apply_scenario_node_property(
277277
scenario.attack_graph,
278278
{'by_asset_type': {'NonExistingType': ['read']}},
279-
default_value = 0
280279
)
281280

282281
# Wrong attack step name in rule asset type to step dict
283282
with pytest.raises(AssertionError):
284283
apply_scenario_node_property(
285284
scenario.attack_graph,
286285
{'by_asset_type': {'Data': ['nonExistingAttackStep']}},
287-
default_value = 0
288286
)
289287

290288
# Correct asset name and attack step
291289
apply_scenario_node_property(
292290
scenario.attack_graph,
293291
{'by_asset_name': { 'OS App': ['read']}},
294-
default_value = 0
295292
)
296293

297294
# Wrong asset name in rule asset name to step dict
298295
with pytest.raises(AssertionError):
299296
apply_scenario_node_property(
300297
scenario.attack_graph,
301298
{'by_asset_name': { 'NonExistingName': ['read']}},
302-
default_value = 0
303299
)
304300

305301
# Wrong attack step name in rule asset name to step dict
306302
with pytest.raises(AssertionError):
307303
apply_scenario_node_property(
308304
scenario.attack_graph,
309305
{'by_asset_name': {'OS App': ['nonExistingAttackStep']}},
310-
default_value = 0
311306
)
312307

313308

0 commit comments

Comments
 (0)