diff --git a/tests/buffer_test.py b/tests/buffer_test.py index b5581a9..1d6ada8 100644 --- a/tests/buffer_test.py +++ b/tests/buffer_test.py @@ -8,8 +8,8 @@ from torch import optim from lge.buffer import LGEBuffer -from lge.modules.ae_module import AEModule, CNNAEModule -from lge.modules.forward_module import CNNForwardModule, ForwardModule +from lge.modules.ae_module import AEModule +from lge.modules.forward_module import ForwardModule from lge.modules.inverse_module import CNNInverseModule, InverseModule from lge.utils import get_shape, get_size, preprocess from tests.utils import DummyEnv diff --git a/tests/learners_test.py b/tests/learners_test.py index 77db6c3..d33cdb9 100644 --- a/tests/learners_test.py +++ b/tests/learners_test.py @@ -7,8 +7,8 @@ from lge.buffer import LGEBuffer from lge.learners import AEModuleLearner, ForwardModuleLearner, InverseModuleLearner -from lge.modules.ae_module import AEModule, CNNAEModule -from lge.modules.forward_module import CNNForwardModule, ForwardModule +from lge.modules.ae_module import AEModule +from lge.modules.forward_module import ForwardModule from lge.modules.inverse_module import CNNInverseModule, InverseModule from lge.utils import get_shape, get_size, preprocess from tests.utils import DummyEnv diff --git a/tests/lge_test.py b/tests/lge_test.py index 3540322..ac6ed0d 100644 --- a/tests/lge_test.py +++ b/tests/lge_test.py @@ -8,9 +8,9 @@ from lge import LatentGoExplore from lge.buffer import LGEBuffer -from lge.lge import Goalify -from lge.modules.ae_module import AEModule, CNNAEModule -from lge.modules.forward_module import CNNForwardModule, ForwardModule +from lge.lge import VecGoalify +from lge.modules.ae_module import AEModule +from lge.modules.forward_module import ForwardModule from lge.modules.inverse_module import CNNInverseModule, InverseModule from lge.utils import get_shape, get_size from tests.utils import BitFlippingEnv, DummyEnv @@ -45,7 +45,7 @@ @pytest.mark.parametrize("action_space", ACTION_SPACES) def test_goalify(observation_space, action_space): env = DummyEnv(observation_space, action_space) - env = Goalify(env) + env = VecGoalify(env) assert "observation" in env.observation_space.keys() assert "goal" in env.observation_space.keys() assert env.observation_space["observation"].__class__ == env.observation_space["goal"].__class__ @@ -56,7 +56,7 @@ def test_goalify(observation_space, action_space): @pytest.mark.parametrize("action_space", ACTION_SPACES) def test_reset_no_buffer(observation_space, action_space): env = DummyEnv(observation_space, action_space) - env = Goalify(env) + env = VecGoalify(env) with pytest.raises(AssertionError): env.reset() @@ -89,7 +89,7 @@ def test_goalify_reset(observation_space, action_space, module_class): # Create environment def env_func(): env = DummyEnv(observation_space, action_space) - env = Goalify(env) + env = VecGoalify(env) return env venv = make_vec_env(env_func, N_ENVS) @@ -129,7 +129,7 @@ def test_goalify_step(action_type, observation_type, module_class): # Create environment def env_func(): env = BitFlippingEnv(8, action_type, observation_type) - env = Goalify(env, distance_threshold=distance_threshold) + env = VecGoalify(env, distance_threshold=distance_threshold) return env venv = make_vec_env(env_func, N_ENVS) diff --git a/tests/modules_test.py b/tests/modules_test.py index e602252..69a1f13 100644 --- a/tests/modules_test.py +++ b/tests/modules_test.py @@ -2,8 +2,8 @@ import torch from stable_baselines3.common.utils import set_random_seed -from lge.modules.ae_module import AEModule, CNNAEModule -from lge.modules.forward_module import CNNForwardModule, ForwardModule +from lge.modules.ae_module import AEModule +from lge.modules.forward_module import ForwardModule from lge.modules.inverse_module import CNNInverseModule, InverseModule BATCH_SIZE = 32