|
| 1 | +import pytest |
| 2 | + |
| 3 | +import torch |
| 4 | +from mlagents.trainers.torch.networks import ( |
| 5 | + NetworkBody, |
| 6 | + ValueNetwork, |
| 7 | + SimpleActor, |
| 8 | + SharedActorCritic, |
| 9 | + SeparateActorCritic, |
| 10 | +) |
| 11 | +from mlagents.trainers.settings import NetworkSettings |
| 12 | +from mlagents_envs.base_env import ActionType |
| 13 | +from mlagents.trainers.torch.distributions import ( |
| 14 | + GaussianDistInstance, |
| 15 | + CategoricalDistInstance, |
| 16 | +) |
| 17 | + |
| 18 | + |
def test_networkbody_vector():
    """Check that a vector-only NetworkBody can be optimised towards ones.

    Builds a NetworkBody over a single vector observation plus an encoded
    action, runs 100 Adam steps minimising the MSE between the encoding and a
    tensor of ones, and asserts that every element of the last computed
    encoding is within 0.1 of 1.0.
    """
    # Seed the global torch RNG before the network is built so that any random
    # state it draws on is reproducible; test_networkbody_visual seeds the same
    # way, and without it the fixed 100-step budget can pass or fail by chance.
    torch.random.manual_seed(0)
    obs_size = 4
    act_size = 2
    network_settings = NetworkSettings()
    obs_shapes = [(obs_size,)]

    networkbody = NetworkBody(obs_shapes, network_settings, encoded_act_size=act_size)
    optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-3)
    sample_obs = torch.ones((1, obs_size))
    sample_act = torch.ones((1, act_size))

    for _ in range(100):
        encoded, _ = networkbody([sample_obs], [], sample_act)
        assert encoded.shape == (1, network_settings.hidden_units)
        # Try to force output to 1
        loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # `encoded` is the output computed just before the final optimizer step;
    # its values should already be close to 1.
    for _enc in encoded.flatten():
        assert _enc == pytest.approx(1.0, abs=0.1)
| 40 | + |
| 41 | + |
def test_networkbody_lstm():
    """Check that a NetworkBody configured with memory can be optimised towards ones.

    Builds a NetworkBody with MemorySettings (sequence length 16, memory size
    4), feeds it a constant observation sequence and constant memories for 100
    Adam steps minimising the MSE against ones, and asserts that every element
    of the last computed encoding is within 0.1 of 1.0.
    """
    # Seed the global torch RNG before the network is built so that any random
    # state it draws on is reproducible; test_networkbody_visual seeds the same
    # way, and without it the fixed 100-step budget can pass or fail by chance.
    torch.random.manual_seed(0)
    obs_size = 4
    seq_len = 16
    # Single source of truth for the memory width, used both in the settings
    # and in the shape of the memories tensor passed to the network.
    memory_size = 4
    network_settings = NetworkSettings(
        memory=NetworkSettings.MemorySettings(
            sequence_length=seq_len, memory_size=memory_size
        )
    )
    obs_shapes = [(obs_size,)]

    networkbody = NetworkBody(obs_shapes, network_settings)
    optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-3)
    sample_obs = torch.ones((1, seq_len, obs_size))

    for _ in range(100):
        encoded, _ = networkbody(
            [sample_obs], [], memories=torch.ones(1, seq_len, memory_size)
        )
        # Try to force output to 1
        loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # `encoded` is the output computed just before the final optimizer step;
    # its values should already be close to 1.
    for _enc in encoded.flatten():
        assert _enc == pytest.approx(1.0, abs=0.1)
| 64 | + |
| 65 | + |
def test_networkbody_visual():
    """Check that a NetworkBody with vector and visual inputs can be optimised towards ones.

    The network receives one vector observation of size 4 and one 84x84x3
    observation. After 100 seeded Adam steps minimising the MSE against ones,
    every element of the last computed encoding must be within 0.1 of 1.0.
    """
    vector_dim = 4
    image_shape = (84, 84, 3)
    settings = NetworkSettings()
    # Fix the global torch RNG before the network is constructed.
    torch.random.manual_seed(0)

    body = NetworkBody([(vector_dim,), image_shape], settings)
    adam = torch.optim.Adam(body.parameters(), lr=3e-3)
    image_batch = torch.ones((1, *image_shape))
    vector_batch = torch.ones((1, vector_dim))

    for _step in range(100):
        encoding, _ = body([vector_batch], [image_batch])
        assert encoding.shape == (1, settings.hidden_units)
        # Push every element of the encoding towards 1.
        target = torch.ones(encoding.shape)
        step_loss = torch.nn.functional.mse_loss(encoding, target)
        adam.zero_grad()
        step_loss.backward()
        adam.step()

    # The encoding from the final iteration should already be close to 1.
    for element in encoding.flatten():
        assert element == pytest.approx(1.0, abs=0.1)
| 89 | + |
| 90 | + |
def test_valuenetwork():
    """Check that every output of every ValueNetwork stream can be optimised towards ones.

    Builds a ValueNetwork with four named streams and two outputs per stream,
    runs 50 Adam steps minimising the summed per-stream MSE against ones, and
    asserts that all outputs of all streams from the last forward pass are
    within 0.1 of 1.0.
    """
    # Seed the global torch RNG before the network is built so that any random
    # state it draws on is reproducible; test_networkbody_visual seeds the same
    # way, and without it the fixed 50-step budget can pass or fail by chance.
    torch.random.manual_seed(0)
    obs_size = 4
    num_outputs = 2
    network_settings = NetworkSettings()
    obs_shapes = [(obs_size,)]

    stream_names = [f"stream_name{n}" for n in range(4)]
    value_net = ValueNetwork(
        stream_names, obs_shapes, network_settings, outputs_per_stream=num_outputs
    )
    optimizer = torch.optim.Adam(value_net.parameters(), lr=3e-3)

    for _ in range(50):
        sample_obs = torch.ones((1, obs_size))
        values, _ = value_net([sample_obs], [])
        loss = 0
        for s_name in stream_names:
            assert values[s_name].shape == (1, num_outputs)
            # Try to force output to 1
            loss += torch.nn.functional.mse_loss(
                values[s_name], torch.ones((1, num_outputs))
            )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # `values` holds the outputs computed just before the final optimizer step.
    # Check every output of every stream: the loss targets all `num_outputs`
    # values, so verifying only index 0 would leave the rest untested.
    for value in values.values():
        for _out in value.flatten():
            assert _out == pytest.approx(1.0, abs=0.1)
| 121 | + |
| 122 | + |
@pytest.mark.parametrize("action_type", [ActionType.DISCRETE, ActionType.CONTINUOUS])
def test_simple_actor(action_type):
    """Exercise SimpleActor.get_dists, sample_action and forward for one action type.

    For continuous actions the distributions must be GaussianDistInstance and
    no mask is passed; for discrete actions they must be
    CategoricalDistInstance and a (1, 1) tensor of ones is passed as the mask.
    The shapes of sampled and forwarded actions and the metadata returned by
    forward are checked as well.
    """
    continuous = action_type == ActionType.CONTINUOUS
    obs_dim = 4
    branch_sizes = [2]
    settings = NetworkSettings()

    # Collect everything that depends on the action type in one place.
    if continuous:
        action_masks = None
        expected_dist_cls = GaussianDistInstance
        sampled_shape = (1, branch_sizes[0])
        # This is different from the sampled shape for ONNX export
        forward_shape = (branch_sizes[0], 1)
    else:
        action_masks = torch.ones((1, 1))
        expected_dist_cls = CategoricalDistInstance
        sampled_shape = (1, 1)
        forward_shape = (1, 1)

    policy = SimpleActor([(obs_dim,)], settings, action_type, branch_sizes)
    obs_batch = torch.ones((1, obs_dim))

    # get_dists: one distribution instance of the expected class per entry.
    dists, _ = policy.get_dists([obs_batch], [], masks=action_masks)
    for dist in dists:
        assert isinstance(dist, expected_dist_cls)

    # sample_action: check the shape of each sampled action.
    for sampled in policy.sample_action(dists):
        assert sampled.shape == sampled_shape

    # forward: returns six values; only some of them are checked here.
    forward_out = policy.forward([obs_batch], [], masks=action_masks)
    fwd_actions, _probs, _ver_num, mem_size, is_cont, act_size_vec = forward_out
    for fwd_act in fwd_actions:
        assert fwd_act.shape == forward_shape

    # TODO: Once export works properly. fix the shapes here.
    assert mem_size == 0
    assert is_cont == int(continuous)
    assert act_size_vec == torch.tensor(branch_sizes)
| 165 | + |
| 166 | + |
@pytest.mark.parametrize("ac_type", [SharedActorCritic, SeparateActorCritic])
@pytest.mark.parametrize("lstm", [True, False])
def test_actor_critic(ac_type, lstm):
    """Exercise critic_pass and get_dist_and_value on an actor-critic class.

    Runs for both SharedActorCritic and SeparateActorCritic, with and without
    default MemorySettings. Each value stream must have shape
    (sequence_length,) when memory is enabled and (1,) otherwise, and every
    returned distribution must be a GaussianDistInstance.
    """
    obs_dim = 4
    memory_settings = NetworkSettings.MemorySettings() if lstm else None
    settings = NetworkSettings(memory=memory_settings)
    reward_streams = [f"stream_name{n}" for n in range(4)]
    model = ac_type(
        [(obs_dim,)], settings, ActionType.CONTINUOUS, [2], reward_streams
    )

    if not lstm:
        obs_batch = torch.ones((1, obs_dim))
        memories = None
        expected_value_shape = (1,)
    else:
        steps = settings.memory.sequence_length
        obs_batch = torch.ones((1, steps, obs_dim))
        memories = torch.ones((1, steps, settings.memory.memory_size))
        expected_value_shape = (steps,)

    def _assert_value_shapes(stream_values):
        # Every named stream must be present with the expected shape.
        for name in reward_streams:
            assert stream_values[name].shape == expected_value_shape

    # Critic-only pass.
    _assert_value_shapes(model.critic_pass([obs_batch], [], memories=memories))

    # Combined distribution + value pass.
    dists, values, _ = model.get_dist_and_value([obs_batch], [], memories=memories)
    for dist in dists:
        assert isinstance(dist, GaussianDistInstance)
    _assert_value_shapes(values)
0 commit comments