Skip to content

Commit

Permalink
new test
Browse files Browse the repository at this point in the history
  • Loading branch information
qgallouedec committed Mar 19, 2024
1 parent ad39920 commit c45d945
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 13 deletions.
4 changes: 2 additions & 2 deletions tests/buffer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from torch import optim

from lge.buffer import LGEBuffer
from lge.modules.ae_module import AEModule, CNNAEModule
from lge.modules.forward_module import CNNForwardModule, ForwardModule
from lge.modules.ae_module import AEModule
from lge.modules.forward_module import ForwardModule
from lge.modules.inverse_module import CNNInverseModule, InverseModule
from lge.utils import get_shape, get_size, preprocess
from tests.utils import DummyEnv
Expand Down
4 changes: 2 additions & 2 deletions tests/learners_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

from lge.buffer import LGEBuffer
from lge.learners import AEModuleLearner, ForwardModuleLearner, InverseModuleLearner
from lge.modules.ae_module import AEModule, CNNAEModule
from lge.modules.forward_module import CNNForwardModule, ForwardModule
from lge.modules.ae_module import AEModule
from lge.modules.forward_module import ForwardModule
from lge.modules.inverse_module import CNNInverseModule, InverseModule
from lge.utils import get_shape, get_size, preprocess
from tests.utils import DummyEnv
Expand Down
14 changes: 7 additions & 7 deletions tests/lge_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@

from lge import LatentGoExplore
from lge.buffer import LGEBuffer
from lge.lge import Goalify
from lge.modules.ae_module import AEModule, CNNAEModule
from lge.modules.forward_module import CNNForwardModule, ForwardModule
from lge.lge import VecGoalify
from lge.modules.ae_module import AEModule
from lge.modules.forward_module import ForwardModule
from lge.modules.inverse_module import CNNInverseModule, InverseModule
from lge.utils import get_shape, get_size
from tests.utils import BitFlippingEnv, DummyEnv
Expand Down Expand Up @@ -45,7 +45,7 @@
@pytest.mark.parametrize("action_space", ACTION_SPACES)
def test_goalify(observation_space, action_space):
env = DummyEnv(observation_space, action_space)
env = Goalify(env)
env = VecGoalify(env)
assert "observation" in env.observation_space.keys()
assert "goal" in env.observation_space.keys()
assert env.observation_space["observation"].__class__ == env.observation_space["goal"].__class__
Expand All @@ -56,7 +56,7 @@ def test_goalify(observation_space, action_space):
@pytest.mark.parametrize("action_space", ACTION_SPACES)
def test_reset_no_buffer(observation_space, action_space):
env = DummyEnv(observation_space, action_space)
env = Goalify(env)
env = VecGoalify(env)
with pytest.raises(AssertionError):
env.reset()

Expand Down Expand Up @@ -89,7 +89,7 @@ def test_goalify_reset(observation_space, action_space, module_class):
# Create environment
def env_func():
env = DummyEnv(observation_space, action_space)
env = Goalify(env)
env = VecGoalify(env)
return env

venv = make_vec_env(env_func, N_ENVS)
Expand Down Expand Up @@ -129,7 +129,7 @@ def test_goalify_step(action_type, observation_type, module_class):
# Create environment
def env_func():
env = BitFlippingEnv(8, action_type, observation_type)
env = Goalify(env, distance_threshold=distance_threshold)
env = VecGoalify(env, distance_threshold=distance_threshold)
return env

venv = make_vec_env(env_func, N_ENVS)
Expand Down
4 changes: 2 additions & 2 deletions tests/modules_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import torch
from stable_baselines3.common.utils import set_random_seed

from lge.modules.ae_module import AEModule, CNNAEModule
from lge.modules.forward_module import CNNForwardModule, ForwardModule
from lge.modules.ae_module import AEModule
from lge.modules.forward_module import ForwardModule
from lge.modules.inverse_module import CNNInverseModule, InverseModule

BATCH_SIZE = 32
Expand Down

0 comments on commit c45d945

Please sign in to comment.