fix linter
Limmen committed May 28, 2024
1 parent caed7d0 commit 90e0232
Showing 6 changed files with 126 additions and 270 deletions.
@@ -3,7 +3,6 @@
)
from csle_tolerance.util.intrusion_recovery_pomdp_util import IntrusionRecoveryPomdpUtil
import pytest_mock
import numpy as np


class TestIntrusionRecoveryPomdpConfigSuite:
@@ -20,9 +20,7 @@ def test__state_space(self) -> None:
:return: None
"""
assert (
isinstance(item, int) for item in IntrusionRecoveryPomdpUtil.state_space()
)
assert (isinstance(item, int) for item in IntrusionRecoveryPomdpUtil.state_space())
assert IntrusionRecoveryPomdpUtil.state_space() is not None
assert IntrusionRecoveryPomdpUtil.state_space() == [0, 1, 2]

@@ -40,9 +38,7 @@ def test_action_space(self) -> None:
:return: None
"""
assert (
isinstance(item, int) for item in IntrusionRecoveryPomdpUtil.action_space()
)
assert (isinstance(item, int) for item in IntrusionRecoveryPomdpUtil.action_space())
assert IntrusionRecoveryPomdpUtil.action_space() is not None
assert IntrusionRecoveryPomdpUtil.action_space() == [0, 1]

@@ -79,10 +75,7 @@ def test_cost_tensor(self) -> None:
actions = [0]
negate = False
expected = [[0, 0.5]]
assert (
IntrusionRecoveryPomdpUtil.cost_tensor(eta, states, actions, negate)
== expected
)
assert IntrusionRecoveryPomdpUtil.cost_tensor(eta, states, actions, negate) == expected

def test_observation_function(self) -> None:
"""
@@ -93,9 +86,7 @@ def test_observation_function(self) -> None:
s = 1
o = 1
num_observations = 2
assert round(
IntrusionRecoveryPomdpUtil.observation_function(s, o, num_observations), 1
)
assert round(IntrusionRecoveryPomdpUtil.observation_function(s, o, num_observations), 1)

def test_observation_tensor(self) -> None:
"""
@@ -126,15 +117,7 @@ def test_transition_function(self) -> None:
p_c_1 = 0.1
p_c_2 = 0.2
p_u = 0.5
assert (
round(
IntrusionRecoveryPomdpUtil.transition_function(
s, s_prime, a, p_a, p_c_1, p_c_2, p_u
),
1,
)
== 0.2
)
assert (round(IntrusionRecoveryPomdpUtil.transition_function(s, s_prime, a, p_a, p_c_1, p_c_2, p_u), 1) == 0.2)

def test_transition_function_game(self) -> None:
"""
@@ -148,15 +131,7 @@ def test_transition_function_game(self) -> None:
a2 = 1
p_a = 0.2
p_c_1 = 0.1
assert (
round(
IntrusionRecoveryPomdpUtil.transition_function_game(
s, s_prime, a1, a2, p_a, p_c_1
),
2,
)
== 0.18
)
assert (round(IntrusionRecoveryPomdpUtil.transition_function_game(s, s_prime, a1, a2, p_a, p_c_1), 2) == 0.18)

def test_transition_tensor(self) -> None:
"""
@@ -171,19 +146,15 @@ def test_transition_tensor(self) -> None:
p_c_2 = 0.2
p_u = 0.5
expected = [[[0.7, 0.2, 0.1], [0.4, 0.4, 0.2], [0, 0, 1.0]]]
transition_tensor = IntrusionRecoveryPomdpUtil.transition_tensor(
states, actions, p_a, p_c_1, p_c_2, p_u
)
transition_tensor = IntrusionRecoveryPomdpUtil.transition_tensor(states, actions, p_a, p_c_1, p_c_2, p_u)
for i in range(len(transition_tensor)):
for j in range(len(transition_tensor[i])):
for k in range(len(transition_tensor[i][j])):
transition_tensor[i][j][k] = round(transition_tensor[i][j][k], 1)
assert transition_tensor == expected
states = [0, 1]
with pytest.raises(AssertionError):
transition_tensor = IntrusionRecoveryPomdpUtil.transition_tensor(
states, actions, p_a, p_c_1, p_c_2, p_u
)
IntrusionRecoveryPomdpUtil.transition_tensor(states, actions, p_a, p_c_1, p_c_2, p_u)

def test_transition_tensor_game(self) -> None:
"""
@@ -196,14 +167,12 @@ def test_transition_tensor_game(self) -> None:
attacker_actions = [0, 1]
p_a = 0.5
p_c_1 = 0.3
result = IntrusionRecoveryPomdpUtil.transition_tensor_game(
states, defender_actions, attacker_actions, p_a, p_c_1
)
result = IntrusionRecoveryPomdpUtil.transition_tensor_game(states, defender_actions, attacker_actions, p_a,
p_c_1)
assert len(result) == len(defender_actions)
assert all(len(a1) == len(attacker_actions) for a1 in result)
assert all(len(a2) == len(states) for a1 in result for a2 in a1)
assert all(len(s) == len(states) for a1 in result for a2 in a1 for s in a2)

assert result[0][1][0][0] == (1 - p_a) * (1 - p_c_1)
assert result[1][0][1][1] == 0
assert result[1][1][2][2] == 1.0
@@ -234,12 +203,8 @@ def test_sampe_next_observation(self) -> None:
observation_tensor = [[0.8, 0.2], [0.4, 0.6]]
s_prime = 1
observations = [0, 1]
assert isinstance(
IntrusionRecoveryPomdpUtil.sample_next_observation(
observation_tensor, s_prime, observations
),
int,
)
assert isinstance(IntrusionRecoveryPomdpUtil.sample_next_observation(observation_tensor, s_prime, observations),
int)

def test_bayes_filter(self) -> None:
"""
@@ -256,22 +221,9 @@ def test_bayes_filter(self) -> None:
observation_tensor = [[0.8, 0.2], [0.4, 0.6]]
transition_tensor = [[[0.6, 0.4], [0.1, 0.9]]]
b_prime_s_prime = 0.7
assert (
round(
IntrusionRecoveryPomdpUtil.bayes_filter(
s_prime,
o,
a,
b,
states,
observations,
observation_tensor,
transition_tensor,
),
1,
)
== b_prime_s_prime
)
assert (round(IntrusionRecoveryPomdpUtil.bayes_filter(s_prime, o, a, b, states, observations,
observation_tensor, transition_tensor), 1)
== b_prime_s_prime)
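As background for this assertion: bayes_filter computes the standard POMDP belief update b'(s') proportional to Z[s'][o] * sum_s T[a][s][s'] * b(s). Below is a minimal, self-contained sketch of that update, not the library's implementation, assuming the indexing convention suggested by the tensors in the test (transition_tensor[a][s][s_prime], observation_tensor[s_prime][o]); the values for o, a and b in the call are illustrative, since the test's setup lines are collapsed in this view.

from typing import List


def bayes_filter_sketch(s_prime: int, o: int, a: int, b: List[float], states: List[int],
                        observation_tensor: List[List[float]],
                        transition_tensor: List[List[List[float]]]) -> float:
    # Unnormalized probability of reaching s_prime and observing o
    numerator = observation_tensor[s_prime][o] * sum(transition_tensor[a][s][s_prime] * b[s] for s in states)
    # Normalizing constant: the same quantity summed over all candidate next states
    denominator = sum(observation_tensor[s2][o] * transition_tensor[a][s][s2] * b[s]
                      for s2 in states for s in states)
    return numerator / denominator


print(round(bayes_filter_sketch(s_prime=1, o=0, a=0, b=[0.3, 0.7], states=[0, 1],
                                observation_tensor=[[0.8, 0.2], [0.4, 0.6]],
                                transition_tensor=[[[0.6, 0.4], [0.1, 0.9]]]), 2))  # 0.6 with these inputs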

def test_p_o_given_b_a1_a2(self) -> None:
"""
@@ -286,15 +238,8 @@ def test_p_o_given_b_a1_a2(self) -> None:
observation_tensor = [[0.8, 0.2], [0.4, 0.6]]
transition_tensor = [[[0.6, 0.4], [0.1, 0.9]]]
expected = 0.5
assert (
round(
IntrusionRecoveryPomdpUtil.p_o_given_b_a1_a2(
o, b, a, states, transition_tensor, observation_tensor
),
1,
)
== expected
)
assert (round(IntrusionRecoveryPomdpUtil.p_o_given_b_a1_a2(o, b, a, states, transition_tensor,
observation_tensor), 1) == expected)

def test_next_belief(self) -> None:
"""
@@ -309,23 +254,8 @@ def test_next_belief(self) -> None:
observations = [0, 1]
observation_tensor = [[0.8, 0.2], [0.4, 0.6]]
transition_tensor = [[[0.3, 0.7], [0.6, 0.4]]]
assert (
round(
sum(
IntrusionRecoveryPomdpUtil.next_belief(
o,
a,
b,
states,
observations,
observation_tensor,
transition_tensor,
)
),
1,
)
== 1
)
assert (round(sum(IntrusionRecoveryPomdpUtil.next_belief(o, a, b, states, observations, observation_tensor,
transition_tensor)), 1) == 1)

def test_pomdp_solver_file(self) -> None:
"""
@@ -334,33 +264,14 @@ def test_pomdp_solver_file(self) -> None:
:return: None
"""

assert (
IntrusionRecoveryPomdpUtil.pomdp_solver_file(
IntrusionRecoveryPomdpConfig(
eta=0.1,
p_a=0.2,
p_c_1=0.2,
p_c_2=0.3,
p_u=0.3,
BTR=1,
negate_costs=True,
seed=1,
discount_factor=0.5,
states=[0, 1],
actions=[0],
observations=[0, 1],
cost_tensor=[[0.1, 0.5], [0.5, 0.6]],
observation_tensor=[[0.8, 0.2], [0.4, 0.6]],
transition_tensor=[[[0.8, 0.2], [0.6, 0.4]]],
b1=[0.3, 0.7],
T=3,
simulation_env_name="env",
gym_env_name="gym",
max_horizon=np.inf,
)
)
is not None
)
assert (IntrusionRecoveryPomdpUtil.pomdp_solver_file(
IntrusionRecoveryPomdpConfig(eta=0.1, p_a=0.2, p_c_1=0.2, p_c_2=0.3, p_u=0.3, BTR=1, negate_costs=True,
seed=1, discount_factor=0.5, states=[0, 1], actions=[0], observations=[0, 1],
cost_tensor=[[0.1, 0.5], [0.5, 0.6]],
observation_tensor=[[0.8, 0.2], [0.4, 0.6]],
transition_tensor=[[[0.8, 0.2], [0.6, 0.4]]], b1=[0.3, 0.7], T=3,
simulation_env_name="env", gym_env_name="gym", max_horizon=np.inf))
is not None)

def test_sample_next_state_game(self) -> None:
"""
@@ -444,9 +355,7 @@ def test_generate_transitions(self) -> None:
gym_env_name="gym_env",
max_horizon=1000,
)
assert (
IntrusionRecoveryPomdpUtil.generate_transitions(dto)[0] == "0 0 0 0 0 0.06"
)
assert IntrusionRecoveryPomdpUtil.generate_transitions(dto)[0] == "0 0 0 0 0 0.06"

def test_generate_rewards(self) -> None:
"""
@@ -502,7 +411,11 @@ def test_generate_rewards(self) -> None:
assert IntrusionRecoveryPomdpUtil.generate_rewards(dto)[0] == "0 0 0 -1"

def test_generate_os_posg_game_file(self) -> None:
""" """
"""
Tests the generate_os_posg_game function
:return: None
"""

states = [0, 1, 2]
actions = [0, 1]
@@ -580,24 +493,13 @@ def test_generate_os_posg_game_file(self) -> None:

output_lines = game_file_str.split("\n")

assert (
output_lines[0] == expected_game_description
), f"Game description mismatch: {output_lines[0]}"
assert (
output_lines[1:4] == expected_state_descriptions
), f"State descriptions mismatch: {output_lines[1:4]}"
assert (
output_lines[4:6] == expected_player_1_actions
), f"Player 1 actions mismatch: {output_lines[4:6]}"
assert (
output_lines[6:8] == expected_player_2_actions
), f"Player 2 actions mismatch: {output_lines[6:8]}"
assert (
output_lines[8:10] == expected_obs_descriptions
), f"Observation descriptions mismatch: {output_lines[8:10]}"
assert (
output_lines[10:13] == expected_player_2_legal_actions
), f"Player 2 legal actions mismatch: {output_lines[10:13]}"
assert (
output_lines[13:14] == expected_player_1_legal_actions
), f"Player 1 legal actions mismatch: {output_lines[13:14]}"
assert (output_lines[0] == expected_game_description), f"Game description mismatch: {output_lines[0]}"
assert (output_lines[1:4] == expected_state_descriptions), f"State descriptions mismatch: {output_lines[1:4]}"
assert (output_lines[4:6] == expected_player_1_actions), f"Player 1 actions mismatch: {output_lines[4:6]}"
assert (output_lines[6:8] == expected_player_2_actions), f"Player 2 actions mismatch: {output_lines[6:8]}"
assert (output_lines[8:10] == expected_obs_descriptions), \
f"Observation descriptions mismatch: {output_lines[8:10]}"
assert (output_lines[10:13] == expected_player_2_legal_actions), \
f"Player 2 legal actions mismatch: {output_lines[10:13]}"
assert (output_lines[13:14] == expected_player_1_legal_actions), \
f"Player 1 legal actions mismatch: {output_lines[13:14]}"
@@ -72,7 +72,7 @@ def step(self, action_profile: Tuple[int, Tuple[npt.NDArray[Any], int]]) \
a1, a2_profile = action_profile
pi2, a2 = a2_profile
assert pi2.shape[0] == len(self.config.S)
assert pi2.shape[1] == len(self.config.A1)
assert pi2.shape[1] == len(self.config.A2)
done = False
info: Dict[str, Any] = {}

@@ -83,8 +83,7 @@ def step(self, action_profile: Tuple[int, Tuple[npt.NDArray[Any], int]]) \
else:
# Compute r, s', b',o'
r = self.config.R[self.state.l - 1][a1][a2][self.state.s]
self.state.s = StoppingGameUtil.sample_next_state(l=self.state.l, a1=a1, a2=a2,
T=self.config.T,
self.state.s = StoppingGameUtil.sample_next_state(l=self.state.l, a1=a1, a2=a2, T=self.config.T,
S=self.config.S, s=self.state.s)
o = StoppingGameUtil.sample_next_observation(Z=self.config.Z,
O=self.config.O, s_prime=self.state.s)
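Regarding the one-character fix above (A1 to A2): pi2 appears to be the attacker's behavioral strategy, i.e. a row-stochastic matrix with one row per state in S and one column per attacker action in A2, so its second dimension should be checked against A2 rather than the defender's action set A1. A minimal sketch of a strategy satisfying the corrected assertions, using hypothetical sizes for S and A2:

import numpy as np

S = [0, 1, 2]   # hypothetical state space
A2 = [0, 1]     # hypothetical attacker action space

# pi2[s][a2] = probability that the attacker plays action a2 in state s
pi2 = np.array([
    [1.0, 0.0],
    [0.5, 0.5],
    [0.2, 0.8],
])
assert pi2.shape[0] == len(S)
assert pi2.shape[1] == len(A2)              # the dimension the corrected assertion checks
assert np.allclose(pi2.sum(axis=1), 1.0)    # each row is a probability distribution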