Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions malsim/agents/defenders/heuristic_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def get_next_action(

"""Return an action that disables a compromised node"""

self.compromised_nodes |= agent_state.step_all_compromised_nodes
self.compromised_nodes |= agent_state.step_compromised_nodes

selected_node_cost = math.inf
selected_node = None
Expand Down Expand Up @@ -92,7 +92,7 @@ def get_next_action(

"""Return an action that disables a compromised node"""

self.compromised_nodes |= agent_state.step_all_compromised_nodes
self.compromised_nodes |= agent_state.step_compromised_nodes

selected_node_cost = math.inf
selected_node = None
Expand Down
19 changes: 8 additions & 11 deletions malsim/envs/gym_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,10 @@ def __init__(self, scenario_file: str, **kwargs: Any) -> None:

def _register_attacker_agents(self, agents: list[dict[str, Any]]) -> None:
"""Register attackers in simulator"""
for agent_info in agents:
if agent_info['type'] == AgentType.ATTACKER:
for agent_config in agents:
if agent_config['type'] == AgentType.ATTACKER:
self.sim.register_attacker(
agent_info['name'],
agent_info['entry_points']
agent_config['name'], agent_config['entry_points']
)

def _create_attacker_decision_agents(
Expand All @@ -142,14 +141,12 @@ def _create_attacker_decision_agents(
"""Create decision agents for each attacker"""

attacker_agents = {}

for agent_info in agents:
if agent_info['type'] == AgentType.ATTACKER:
agent_name = agent_info['name']
agent_class = agent_info.get('agent_class')
if agent_class:
for agent_config in agents:
if agent_config['type'] == AgentType.ATTACKER:
agent_name = agent_config['name']
if agent_config['agent_class']:
attacker_agents[agent_name] = (
agent_class(
agent_config['agent_class'](
{'seed': seed, 'randomize': self.randomize}
)
)
Expand Down
71 changes: 46 additions & 25 deletions malsim/mal_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
make_node_unviable,
)

from malsim.scenario import AgentType, load_scenario
from malsim.scenario import AgentType, load_scenario, AgentConfig, AttackerAgentConfig, DefenderAgentConfig

if TYPE_CHECKING:
from malsim.scenario import Scenario
Expand Down Expand Up @@ -96,7 +96,7 @@ class MalSimDefenderState(MalSimAgentState):
@property
def step_all_compromised_nodes(self) -> frozenset[AttackGraphNode]:
print(
"Deprecated in mal-simulator 1.1.0, "
"'step_all_compromised_nodes' deprecated in mal-simulator 1.1.0, "
"please use 'step_compromised_nodes'"
)
return self.step_compromised_nodes
Expand Down Expand Up @@ -242,6 +242,32 @@ def from_scenario(
) -> MalSimulator:
"""Create a MalSimulator object from a Scenario"""

def register_agent_dict(agent_config: dict[str, Any]) -> None:
"""Register an agent specified in a dictionary"""
logger.warning(
"Having agent configs in dictionaries will be deprecated in "
"mal-simulator 1.1.0. Please use malsim.scenario.AgentConfig."
)
if agent_config['type'] == AgentType.ATTACKER:
sim.register_attacker(
agent_config['name'],
agent_config['entry_points'],
agent_config.get('goals')
)
elif agent_config['type'] == AgentType.DEFENDER:
sim.register_defender(agent_config['name'])

def register_agent_config(agent_config: AgentConfig) -> None:
"""Register an agent config in simulator"""
if isinstance(agent_config, AttackerAgentConfig):
sim.register_attacker(
agent_config.name,
agent_config.entry_points,
agent_config.goals
)
elif isinstance(agent_config, DefenderAgentConfig):
sim.register_defender(agent_config.name)

if isinstance(scenario, str):
# Load scenario if file was given
scenario = load_scenario(scenario)
Expand All @@ -259,15 +285,11 @@ def from_scenario(
)

if register_agents:
for agent_info in scenario.agents:
if agent_info['type'] == AgentType.ATTACKER:
sim.register_attacker(
agent_info['name'],
agent_info['entry_points'],
agent_info.get('goals')
)
elif agent_info['type'] == AgentType.DEFENDER:
sim.register_defender(agent_info['name'])
for agent_config in scenario.agents:
if isinstance(agent_config, dict):
register_agent_dict(agent_config)
elif isinstance(agent_config, AgentConfig):
register_agent_config(agent_config)

return sim

Expand Down Expand Up @@ -1016,7 +1038,7 @@ def _defender_step_reward(
- reward_mode: which way to calculate reward
"""
step_enabled_defenses = defender_state.step_performed_nodes
step_compromised_nodes = defender_state.step_all_compromised_nodes
step_compromised_nodes = defender_state.step_compromised_nodes

# Defender is penalized for compromised steps and enabled defenses
step_reward = - sum(
Expand Down Expand Up @@ -1090,7 +1112,7 @@ def step(
self.recording[self.cur_iter] = {}

# Populate these from the results for all agents' actions.
step_all_compromised_nodes: set[AttackGraphNode] = set()
step_compromised_nodes: set[AttackGraphNode] = set()
step_enabled_defenses: set[AttackGraphNode] = set()
step_nodes_made_unviable: set[AttackGraphNode] = set()

Expand All @@ -1108,7 +1130,7 @@ def step(
agent_compromised, agent_attempted = self._attacker_step(
attacker_state, actions.get(attacker_state.name, [])
)
step_all_compromised_nodes |= agent_compromised
step_compromised_nodes |= agent_compromised
self.recording[self.cur_iter][attacker_state.name] = (
list(agent_compromised)
)
Expand Down Expand Up @@ -1138,7 +1160,7 @@ def step(
# Update defender state
updated_defender_state = self._update_defender_state(
agent_state,
step_all_compromised_nodes,
step_compromised_nodes,
step_enabled_defenses,
step_nodes_made_unviable
)
Expand Down Expand Up @@ -1174,7 +1196,7 @@ def run_simulation(
Return selected actions by each agent in each step
"""
agent_actions: dict[str, list[AttackGraphNode]] = {}
total_rewards = {agent_dict['name']: 0.0 for agent_dict in agents}
total_rewards = {agent_config['name']: 0.0 for agent_config in agents}

logger.info("Starting CLI env simulator.")
states = sim.reset()
Expand All @@ -1184,9 +1206,9 @@ def run_simulation(
actions = {}

# Select actions for each agent
for agent_dict in agents:
decision_agent: Optional[DecisionAgent] = agent_dict.get('agent')
agent_name = agent_dict['name']
for agent_config in agents:
decision_agent: Optional[DecisionAgent] = agent_config['agent']
agent_name = agent_config['name']
if decision_agent is None:
print(
f'Agent "{agent_name}" has no decision agent class '
Expand All @@ -1209,17 +1231,16 @@ def run_simulation(

# Perform next step of simulation
states = sim.step(actions)
for agent_dict in agents:
agent_name = agent_dict['name']
total_rewards[agent_name] += sim.agent_reward(agent_name)
for agent_config in agents:
total_rewards[agent_config['name']] += sim.agent_reward(agent_config['name'])

print("---")

print(f"Simulation over after {sim.cur_iter} steps.")

# Print total rewards
for agent_dict in agents:
agent_name = agent_dict['name']
print(f'Total reward "{agent_name}"', total_rewards[agent_name])
for agent_config in agents:
agent_name = agent_config['name']
print(f'Total reward "{agent_name}"', total_rewards[agent_config['name']])

return agent_actions
Loading