Skip to content

Commit 06b4d58

Browse files
committed
Go back to having agent configs in dicts, backwards compatible
1 parent e34eee0 commit 06b4d58

File tree

7 files changed, with 81 additions and 84 deletions.

7 files changed, with 81 additions and 84 deletions.

malsim/agents/defenders/heuristic_agent.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@ def get_next_action(
3434

3535
"""Return an action that disables a compromised node"""
3636

37-
self.compromised_nodes |= agent_state.step_all_compromised_nodes
37+
self.compromised_nodes |= agent_state.step_compromised_nodes
3838

3939
selected_node_cost = math.inf
4040
selected_node = None
@@ -92,7 +92,7 @@ def get_next_action(
9292

9393
"""Return an action that disables a compromised node"""
9494

95-
self.compromised_nodes |= agent_state.step_all_compromised_nodes
95+
self.compromised_nodes |= agent_state.step_compromised_nodes
9696

9797
selected_node_cost = math.inf
9898
selected_node = None

malsim/envs/gym_envs.py

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
import numpy as np
1414

1515
from ..scenario import load_scenario
16-
from ..mal_simulator import MalSimulator, AgentConfig, AttackerAgentConfig
16+
from ..mal_simulator import MalSimulator, AgentType
1717
from ..envs import MalSimVectorizedObsEnv
1818
from ..agents import DecisionAgent
1919

@@ -36,7 +36,7 @@ def __init__(self, scenario_file: str, **kwargs: Any) -> None:
3636

3737
attacker_agents = [
3838
agent for agent in scenario.agents
39-
if isinstance(agent, AttackerAgentConfig)
39+
if agent['type'] == AgentType.ATTACKER
4040
]
4141

4242
assert len(attacker_agents) == 1, (
@@ -45,11 +45,11 @@ def __init__(self, scenario_file: str, **kwargs: Any) -> None:
4545
)
4646

4747
attacker_agent = attacker_agents[0]
48-
self.attacker_agent_name = attacker_agent.name
48+
self.attacker_agent_name = attacker_agent['name']
4949

5050
self.sim.register_attacker(
5151
self.attacker_agent_name,
52-
attacker_agent.entry_points
52+
attacker_agent['entry_points']
5353
)
5454
self.sim.reset()
5555

@@ -127,26 +127,26 @@ def __init__(self, scenario_file: str, **kwargs: Any) -> None:
127127
self.action_space = \
128128
self.sim.action_space(self.defender_agent_name)
129129

130-
def _register_attacker_agents(self, agents: list[AgentConfig]) -> None:
130+
def _register_attacker_agents(self, agents: list[dict[str, Any]]) -> None:
131131
"""Register attackers in simulator"""
132132
for agent_config in agents:
133-
if isinstance(agent_config, AttackerAgentConfig):
133+
if agent_config['type'] == AgentType.ATTACKER:
134134
self.sim.register_attacker(
135-
agent_config.name, agent_config.entry_points
135+
agent_config['name'], agent_config['entry_points']
136136
)
137137

138138
def _create_attacker_decision_agents(
139-
self, agents: list[AgentConfig], seed: Optional[int] = None
139+
self, agents: list[dict[str, Any]], seed: Optional[int] = None
140140
) -> dict[str, DecisionAgent]:
141141
"""Create decision agents for each attacker"""
142142

143143
attacker_agents = {}
144144
for agent_config in agents:
145-
if isinstance(agent_config, AttackerAgentConfig):
146-
agent_name = agent_config.name
147-
if agent_config.agent_class:
145+
if agent_config['type'] == AgentType.ATTACKER:
146+
agent_name = agent_config['name']
147+
if agent_config['agent_class']:
148148
attacker_agents[agent_name] = (
149-
agent_config.agent_class(
149+
agent_config['agent_class'](
150150
{'seed': seed, 'randomize': self.randomize}
151151
)
152152
)

malsim/mal_simulator.py

Lines changed: 12 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -96,7 +96,7 @@ class MalSimDefenderState(MalSimAgentState):
9696
@property
9797
def step_all_compromised_nodes(self) -> frozenset[AttackGraphNode]:
9898
print(
99-
"Deprecated in mal-simulator 1.1.0, "
99+
"'step_all_compromised_nodes' deprecated in mal-simulator 1.1.0, "
100100
"please use 'step_compromised_nodes'"
101101
)
102102
return self.step_compromised_nodes
@@ -1038,7 +1038,7 @@ def _defender_step_reward(
10381038
- reward_mode: which way to calculate reward
10391039
"""
10401040
step_enabled_defenses = defender_state.step_performed_nodes
1041-
step_compromised_nodes = defender_state.step_all_compromised_nodes
1041+
step_compromised_nodes = defender_state.step_compromised_nodes
10421042

10431043
# Defender is penalized for compromised steps and enabled defenses
10441044
step_reward = - sum(
@@ -1112,7 +1112,7 @@ def step(
11121112
self.recording[self.cur_iter] = {}
11131113

11141114
# Populate these from the results for all agents' actions.
1115-
step_all_compromised_nodes: set[AttackGraphNode] = set()
1115+
step_compromised_nodes: set[AttackGraphNode] = set()
11161116
step_enabled_defenses: set[AttackGraphNode] = set()
11171117
step_nodes_made_unviable: set[AttackGraphNode] = set()
11181118

@@ -1130,7 +1130,7 @@ def step(
11301130
agent_compromised, agent_attempted = self._attacker_step(
11311131
attacker_state, actions.get(attacker_state.name, [])
11321132
)
1133-
step_all_compromised_nodes |= agent_compromised
1133+
step_compromised_nodes |= agent_compromised
11341134
self.recording[self.cur_iter][attacker_state.name] = (
11351135
list(agent_compromised)
11361136
)
@@ -1160,7 +1160,7 @@ def step(
11601160
# Update defender state
11611161
updated_defender_state = self._update_defender_state(
11621162
agent_state,
1163-
step_all_compromised_nodes,
1163+
step_compromised_nodes,
11641164
step_enabled_defenses,
11651165
step_nodes_made_unviable
11661166
)
@@ -1189,14 +1189,14 @@ def render(self) -> None:
11891189

11901190

11911191
def run_simulation(
1192-
sim: MalSimulator, agents: list[AgentConfig]
1192+
sim: MalSimulator, agents: list[dict[str, Any]]
11931193
) -> dict[str, list[AttackGraphNode]]:
11941194
"""Run a simulation with agents
11951195
11961196
Return selected actions by each agent in each step
11971197
"""
11981198
agent_actions: dict[str, list[AttackGraphNode]] = {}
1199-
total_rewards = {agent_config.name: 0.0 for agent_config in agents}
1199+
total_rewards = {agent_config['name']: 0.0 for agent_config in agents}
12001200

12011201
logger.info("Starting CLI env simulator.")
12021202
states = sim.reset()
@@ -1207,8 +1207,8 @@ def run_simulation(
12071207

12081208
# Select actions for each agent
12091209
for agent_config in agents:
1210-
decision_agent: Optional[DecisionAgent] = agent_config.agent
1211-
agent_name = agent_config.name
1210+
decision_agent: Optional[DecisionAgent] = agent_config['agent']
1211+
agent_name = agent_config['name']
12121212
if decision_agent is None:
12131213
print(
12141214
f'Agent "{agent_name}" has no decision agent class '
@@ -1232,14 +1232,15 @@ def run_simulation(
12321232
# Perform next step of simulation
12331233
states = sim.step(actions)
12341234
for agent_config in agents:
1235-
total_rewards[agent_config.name] += sim.agent_reward(agent_config.name)
1235+
total_rewards[agent_config['name']] += sim.agent_reward(agent_config['name'])
12361236

12371237
print("---")
12381238

12391239
print(f"Simulation over after {sim.cur_iter} steps.")
12401240

12411241
# Print total rewards
12421242
for agent_config in agents:
1243-
print(f'Total reward "{agent_config.name}"', total_rewards[agent_config.name])
1243+
agent_name = agent_config['name']
1244+
print(f'Total reward "{agent_name}"', total_rewards[agent_config['name']])
12441245

12451246
return agent_actions

malsim/scenario.py

Lines changed: 26 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212
"""
1313
from __future__ import annotations
1414
import os
15-
from dataclasses import dataclass, asdict
15+
from dataclasses import dataclass
1616
from typing import Any, Optional, TextIO
1717
from enum import Enum
1818
import logging
@@ -83,6 +83,7 @@ class AgentType(Enum):
8383

8484
@dataclass
8585
class AgentConfig:
86+
# Will be used for agents in the future instead of dicts
8687
name: str
8788
agent_class: Any
8889
agent: Any
@@ -151,20 +152,18 @@ def __init__(
151152
self.attack_graph, false_negative_rates or {}
152153
)
153154
self.is_observable = apply_scenario_node_property(
154-
self.attack_graph, is_observable or {}
155+
self.attack_graph, is_observable or {}, default_value=False
155156
)
156157
self.is_actionable = apply_scenario_node_property(
157-
self.attack_graph, is_actionable or {}
158+
self.attack_graph, is_actionable or {}, default_value=False
158159
)
159160

160161
def to_dict(self) -> dict[str, Any]:
161162
assert self._lang_file, "Can not save scenario to file if lang file was not given"
162163
scenario_dict = {
163164
# 'version': ?
164165
'lang_file': self._lang_file,
165-
'agents': {
166-
a.name: asdict(a) for a in self.agents
167-
},
166+
'agents': self.agents,
168167
'rewards': {},
169168
'false_positive_rates': {},
170169
'false_negative_rates': {},
@@ -437,9 +436,8 @@ def get_entry_point_nodes(
437436

438437

439438
def load_simulator_agents(
440-
attack_graph: AttackGraph,
441-
scenario_agents: dict[str, Any],
442-
) -> list[AgentConfig]:
439+
attack_graph: AttackGraph, scenario_agents: dict[str, Any]
440+
) -> list[dict[str, Any]]:
443441
"""Load agents to be registered in MALSimulator
444442
445443
Create the agents from the specified classes,
@@ -449,7 +447,7 @@ def load_simulator_agents(
449447
- attack_graph: the attack graph
450448
- scenario: the scenario in question as a dict
451449
Return:
452-
- agents: a list of agent configurations
450+
- agents: a list of agent configurations (dicts)
453451
"""
454452

455453
# Create list of agents dicts
@@ -470,29 +468,29 @@ def load_simulator_agents(
470468
)
471469

472470
if agent_type == AgentType.ATTACKER:
473-
agent_config = AttackerAgentConfig(
474-
name=agent_name,
475-
agent_class=agent_class,
476-
agent=agent,
477-
policy=policy,
478-
config=agent_config,
479-
entry_points=get_entry_point_nodes(
471+
agent_config = {
472+
'name': agent_name,
473+
'agent_class': agent_class,
474+
'agent': agent,
475+
'policy': policy,
476+
'config': agent_config,
477+
'entry_points': get_entry_point_nodes(
480478
attack_graph, agent_info['entry_points'] # Required
481479
),
482-
goals=get_entry_point_nodes(
480+
'goals': get_entry_point_nodes(
483481
attack_graph, agent_info.get('goals', []) # Optional
484482
),
485-
type=AgentType.ATTACKER
486-
)
483+
'type': AgentType.ATTACKER
484+
}
487485
elif agent_type == AgentType.DEFENDER:
488-
agent_config = DefenderAgentConfig(
489-
name=agent_name,
490-
agent_class=agent_class,
491-
agent=agent,
492-
policy=policy,
493-
config=agent_config,
494-
type=AgentType.DEFENDER
495-
)
486+
agent_config = {
487+
'name': agent_name,
488+
'agent_class': agent_class,
489+
'agent': agent,
490+
'policy': policy,
491+
'config': agent_config,
492+
'type':AgentType.DEFENDER
493+
}
496494

497495
agents.append(agent_config)
498496

tests/envs/test_example_scenarios.py

Lines changed: 20 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -41,12 +41,12 @@ def test_bfs_vs_bfs_state_and_reward() -> None:
4141
attacker_agent_name = "attacker1"
4242

4343
attacker_agent = next(
44-
agent.agent for agent in scenario.agents
45-
if agent.name == attacker_agent_name
44+
agent['agent'] for agent in scenario.agents
45+
if agent['name'] == attacker_agent_name
4646
)
4747
defender_agent = next(
48-
agent.agent for agent in scenario.agents
49-
if agent.name == defender_agent_name
48+
agent['agent'] for agent in scenario.agents
49+
if agent['name'] == defender_agent_name
5050
)
5151

5252
total_reward_defender = 0.0
@@ -78,7 +78,7 @@ def test_bfs_vs_bfs_state_and_reward() -> None:
7878
# If actions were performed, add them to respective list
7979
if attacker_node and attacker_node in attacker_state.step_performed_nodes:
8080
attacker_actions.append(attacker_node.full_name)
81-
assert attacker_node in defender_state.step_all_compromised_nodes
81+
assert attacker_node in defender_state.step_compromised_nodes
8282

8383
if defender_node and defender_node in defender_state.step_performed_nodes:
8484
defender_actions.append(defender_node.full_name)
@@ -194,12 +194,12 @@ def test_bfs_vs_bfs_state_and_reward_per_step_ttc() -> None:
194194
attacker_agent_name = "attacker1"
195195

196196
attacker_agent = next(
197-
agent.agent for agent in scenario.agents
198-
if agent.name == attacker_agent_name
197+
agent['agent'] for agent in scenario.agents
198+
if agent['name'] == attacker_agent_name
199199
)
200200
defender_agent = next(
201-
agent.agent for agent in scenario.agents
202-
if agent.name == defender_agent_name
201+
agent['agent'] for agent in scenario.agents
202+
if agent['name'] == defender_agent_name
203203
)
204204

205205
total_reward_defender = 0.0
@@ -230,7 +230,7 @@ def test_bfs_vs_bfs_state_and_reward_per_step_ttc() -> None:
230230
# If actions were performed, add them to respective list
231231
if attacker_node and attacker_node in attacker_state.step_performed_nodes:
232232
attacker_actions.append(attacker_node.full_name)
233-
assert attacker_node in defender_state.step_all_compromised_nodes
233+
assert attacker_node in defender_state.step_compromised_nodes
234234

235235
if defender_node and defender_node in \
236236
states['defender1'].step_performed_nodes:
@@ -325,12 +325,12 @@ def test_bfs_vs_bfs_state_and_reward_per_step_effort_based() -> None:
325325
attacker_agent_name = "attacker1"
326326

327327
attacker_agent = next(
328-
agent_info.agent for agent_info in scenario.agents
329-
if agent_info.name == attacker_agent_name
328+
agent['agent'] for agent in scenario.agents
329+
if agent['name'] == attacker_agent_name
330330
)
331331
defender_agent = next(
332-
agent_info.agent for agent_info in scenario.agents
333-
if agent_info.name == defender_agent_name
332+
agent['agent'] for agent in scenario.agents
333+
if agent['name'] == defender_agent_name
334334
)
335335

336336
total_reward_defender = 0.0
@@ -361,7 +361,7 @@ def test_bfs_vs_bfs_state_and_reward_per_step_effort_based() -> None:
361361
# If actions were performed, add them to respective list
362362
if attacker_node and attacker_node in attacker_state.step_performed_nodes:
363363
attacker_actions.append(attacker_node.full_name)
364-
assert attacker_node in defender_state.step_all_compromised_nodes
364+
assert attacker_node in defender_state.step_compromised_nodes
365365

366366
if defender_node and defender_node in defender_state.step_performed_nodes:
367367
defender_actions.append(defender_node.full_name)
@@ -432,12 +432,12 @@ def test_bfs_vs_bfs_state_and_reward_expected_value_ttc() -> None:
432432
attacker_agent_name = "attacker1"
433433

434434
attacker_agent = next(
435-
agent_info.agent for agent_info in scenario.agents
436-
if agent_info.name == attacker_agent_name
435+
agent['agent'] for agent in scenario.agents
436+
if agent['name'] == attacker_agent_name
437437
)
438438
defender_agent = next(
439-
agent_info.agent for agent_info in scenario.agents
440-
if agent_info.name == defender_agent_name
439+
agent['agent'] for agent in scenario.agents
440+
if agent['name'] == defender_agent_name
441441
)
442442

443443
total_reward_defender = 0.0
@@ -468,7 +468,7 @@ def test_bfs_vs_bfs_state_and_reward_expected_value_ttc() -> None:
468468
# If actions were performed, add them to respective list
469469
if attacker_node and attacker_node in attacker_state.step_performed_nodes:
470470
attacker_actions.append(attacker_node.full_name)
471-
assert attacker_node in defender_state.step_all_compromised_nodes
471+
assert attacker_node in defender_state.step_compromised_nodes
472472

473473
if defender_node and defender_node in defender_state.step_performed_nodes:
474474
defender_actions.append(defender_node.full_name)

0 commit comments

Comments
 (0)