Skip to content

Commit 9d0fad2

Browse files
author
Ervin T
authored
[tests] Add tests for core PyTorch files (#4292)
1 parent 02e35fd commit 9d0fad2

File tree

7 files changed

+482
-24
lines changed

7 files changed

+482
-24
lines changed
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import pytest
2+
import torch
3+
4+
from mlagents.trainers.torch.decoders import ValueHeads
5+
6+
7+
def test_valueheads():
8+
stream_names = [f"reward_signal_{num}" for num in range(5)]
9+
input_size = 5
10+
batch_size = 4
11+
12+
# Test default 1 value per head
13+
value_heads = ValueHeads(stream_names, input_size)
14+
input_data = torch.ones((batch_size, input_size))
15+
value_out, _ = value_heads(input_data) # Note: mean value will be removed shortly
16+
17+
for stream_name in stream_names:
18+
assert value_out[stream_name].shape == (batch_size,)
19+
20+
# Test that inputting the wrong size input will throw an error
21+
with pytest.raises(Exception):
22+
value_out = value_heads(torch.ones((batch_size, input_size + 2)))
23+
24+
# Test multiple values per head (e.g. discrete Q function)
25+
output_size = 4
26+
value_heads = ValueHeads(stream_names, input_size, output_size)
27+
input_data = torch.ones((batch_size, input_size))
28+
value_out, _ = value_heads(input_data)
29+
30+
for stream_name in stream_names:
31+
assert value_out[stream_name].shape == (batch_size, output_size)
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
import pytest
2+
import torch
3+
4+
from mlagents.trainers.torch.distributions import (
5+
GaussianDistribution,
6+
MultiCategoricalDistribution,
7+
GaussianDistInstance,
8+
TanhGaussianDistInstance,
9+
CategoricalDistInstance,
10+
)
11+
12+
13+
@pytest.mark.parametrize("tanh_squash", [True, False])
14+
@pytest.mark.parametrize("conditional_sigma", [True, False])
15+
def test_gaussian_distribution(conditional_sigma, tanh_squash):
16+
torch.manual_seed(0)
17+
hidden_size = 16
18+
act_size = 4
19+
sample_embedding = torch.ones((1, 16))
20+
gauss_dist = GaussianDistribution(
21+
hidden_size,
22+
act_size,
23+
conditional_sigma=conditional_sigma,
24+
tanh_squash=tanh_squash,
25+
)
26+
27+
# Make sure backprop works
28+
force_action = torch.zeros((1, act_size))
29+
optimizer = torch.optim.Adam(gauss_dist.parameters(), lr=3e-3)
30+
31+
for _ in range(50):
32+
dist_inst = gauss_dist(sample_embedding)[0]
33+
if tanh_squash:
34+
assert isinstance(dist_inst, TanhGaussianDistInstance)
35+
else:
36+
assert isinstance(dist_inst, GaussianDistInstance)
37+
log_prob = dist_inst.log_prob(force_action)
38+
loss = torch.nn.functional.mse_loss(log_prob, -2 * torch.ones(log_prob.shape))
39+
optimizer.zero_grad()
40+
loss.backward()
41+
optimizer.step()
42+
for prob in log_prob.flatten():
43+
assert prob == pytest.approx(-2, abs=0.1)
44+
45+
46+
def test_multi_categorical_distribution():
47+
torch.manual_seed(0)
48+
hidden_size = 16
49+
act_size = [3, 3, 4]
50+
sample_embedding = torch.ones((1, 16))
51+
gauss_dist = MultiCategoricalDistribution(hidden_size, act_size)
52+
53+
# Make sure backprop works
54+
optimizer = torch.optim.Adam(gauss_dist.parameters(), lr=3e-3)
55+
56+
def create_test_prob(size: int) -> torch.Tensor:
57+
test_prob = torch.tensor(
58+
[[1.0 - 0.01 * (size - 1)] + [0.01] * (size - 1)]
59+
) # High prob for first action
60+
return test_prob.log()
61+
62+
for _ in range(100):
63+
dist_insts = gauss_dist(sample_embedding, masks=torch.ones((1, sum(act_size))))
64+
loss = 0
65+
for i, dist_inst in enumerate(dist_insts):
66+
assert isinstance(dist_inst, CategoricalDistInstance)
67+
log_prob = dist_inst.all_log_prob()
68+
test_log_prob = create_test_prob(act_size[i])
69+
# Force log_probs to match the high probability for the first action generated by
70+
# create_test_prob
71+
loss += torch.nn.functional.mse_loss(log_prob, test_log_prob)
72+
optimizer.zero_grad()
73+
loss.backward()
74+
optimizer.step()
75+
for dist_inst, size in zip(dist_insts, act_size):
76+
# Check that the log probs are close to the fake ones that we generated.
77+
test_log_probs = create_test_prob(size)
78+
for _prob, _test_prob in zip(
79+
dist_inst.all_log_prob().flatten().tolist(),
80+
test_log_probs.flatten().tolist(),
81+
):
82+
assert _prob == pytest.approx(_test_prob, abs=0.1)
83+
84+
# Test masks
85+
masks = []
86+
for branch in act_size:
87+
masks += [0] * (branch - 1) + [1]
88+
masks = torch.tensor([masks])
89+
dist_insts = gauss_dist(sample_embedding, masks=masks)
90+
for dist_inst in dist_insts:
91+
log_prob = dist_inst.all_log_prob()
92+
assert log_prob.flatten()[-1] == pytest.approx(0, abs=0.001)
93+
94+
95+
def test_gaussian_dist_instance():
96+
torch.manual_seed(0)
97+
act_size = 4
98+
dist_instance = GaussianDistInstance(
99+
torch.zeros(1, act_size), torch.ones(1, act_size)
100+
)
101+
action = dist_instance.sample()
102+
assert action.shape == (1, act_size)
103+
for log_prob in dist_instance.log_prob(torch.zeros((1, act_size))).flatten():
104+
# Log prob of standard normal at 0
105+
assert log_prob == pytest.approx(-0.919, abs=0.01)
106+
107+
for ent in dist_instance.entropy().flatten():
108+
# entropy of standard normal at 0
109+
assert ent == pytest.approx(2.83, abs=0.01)
110+
111+
112+
def test_tanh_gaussian_dist_instance():
113+
torch.manual_seed(0)
114+
act_size = 4
115+
dist_instance = GaussianDistInstance(
116+
torch.zeros(1, act_size), torch.ones(1, act_size)
117+
)
118+
for _ in range(10):
119+
action = dist_instance.sample()
120+
assert action.shape == (1, act_size)
121+
assert torch.max(action) < 1.0 and torch.min(action) > -1.0
122+
123+
124+
def test_categorical_dist_instance():
125+
torch.manual_seed(0)
126+
act_size = 4
127+
test_prob = torch.tensor(
128+
[1.0 - 0.1 * (act_size - 1)] + [0.1] * (act_size - 1)
129+
) # High prob for first action
130+
dist_instance = CategoricalDistInstance(test_prob)
131+
132+
for _ in range(10):
133+
action = dist_instance.sample()
134+
assert action.shape == (1,)
135+
assert action < act_size
136+
137+
# Make sure the first action as higher probability than the others.
138+
prob_first_action = dist_instance.log_prob(torch.tensor([0]))
139+
140+
for i in range(1, act_size):
141+
assert dist_instance.log_prob(torch.tensor([i])) < prob_first_action
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
import torch
2+
from unittest import mock
3+
import pytest
4+
5+
from mlagents.trainers.torch.encoders import (
6+
VectorEncoder,
7+
VectorAndUnnormalizedInputEncoder,
8+
Normalizer,
9+
SimpleVisualEncoder,
10+
ResNetVisualEncoder,
11+
NatureVisualEncoder,
12+
)
13+
14+
15+
# This test will also reveal issues with states not being saved in the state_dict.
16+
def compare_models(module_1, module_2):
17+
is_same = True
18+
for key_item_1, key_item_2 in zip(
19+
module_1.state_dict().items(), module_2.state_dict().items()
20+
):
21+
# Compare tensors in state_dict and not the keys.
22+
is_same = torch.equal(key_item_1[1], key_item_2[1]) and is_same
23+
return is_same
24+
25+
26+
def test_normalizer():
27+
input_size = 2
28+
norm = Normalizer(input_size)
29+
30+
# These three inputs should mean to 0.5, and variance 2
31+
# with the steps starting at 1
32+
vec_input1 = torch.tensor([[1, 1]])
33+
vec_input2 = torch.tensor([[1, 1]])
34+
vec_input3 = torch.tensor([[0, 0]])
35+
norm.update(vec_input1)
36+
norm.update(vec_input2)
37+
norm.update(vec_input3)
38+
39+
# Test normalization
40+
for val in norm(vec_input1)[0]:
41+
assert val == pytest.approx(0.707, abs=0.001)
42+
43+
# Test copy normalization
44+
norm2 = Normalizer(input_size)
45+
assert not compare_models(norm, norm2)
46+
norm2.copy_from(norm)
47+
assert compare_models(norm, norm2)
48+
for val in norm2(vec_input1)[0]:
49+
assert val == pytest.approx(0.707, abs=0.001)
50+
51+
52+
@mock.patch("mlagents.trainers.torch.encoders.Normalizer")
53+
def test_vector_encoder(mock_normalizer):
54+
mock_normalizer_inst = mock.Mock()
55+
mock_normalizer.return_value = mock_normalizer_inst
56+
input_size = 64
57+
hidden_size = 128
58+
num_layers = 3
59+
normalize = False
60+
vector_encoder = VectorEncoder(input_size, hidden_size, num_layers, normalize)
61+
output = vector_encoder(torch.ones((1, input_size)))
62+
assert output.shape == (1, hidden_size)
63+
64+
normalize = True
65+
vector_encoder = VectorEncoder(input_size, hidden_size, num_layers, normalize)
66+
new_vec = torch.ones((1, input_size))
67+
vector_encoder.update_normalization(new_vec)
68+
69+
mock_normalizer.assert_called_with(input_size)
70+
mock_normalizer_inst.update.assert_called_with(new_vec)
71+
72+
vector_encoder2 = VectorEncoder(input_size, hidden_size, num_layers, normalize)
73+
vector_encoder.copy_normalization(vector_encoder2)
74+
mock_normalizer_inst.copy_from.assert_called_with(mock_normalizer_inst)
75+
76+
77+
@mock.patch("mlagents.trainers.torch.encoders.Normalizer")
78+
def test_vector_and_unnormalized_encoder(mock_normalizer):
79+
mock_normalizer_inst = mock.Mock()
80+
mock_normalizer.return_value = mock_normalizer_inst
81+
input_size = 64
82+
unnormalized_size = 32
83+
hidden_size = 128
84+
num_layers = 3
85+
normalize = True
86+
mock_normalizer_inst.return_value = torch.ones((1, input_size))
87+
vector_encoder = VectorAndUnnormalizedInputEncoder(
88+
input_size, hidden_size, unnormalized_size, num_layers, normalize
89+
)
90+
# Make sure normalizer is only called on input_size
91+
mock_normalizer.assert_called_with(input_size)
92+
normal_input = torch.ones((1, input_size))
93+
94+
unnormalized_input = torch.ones((1, 32))
95+
output = vector_encoder(normal_input, unnormalized_input)
96+
mock_normalizer_inst.assert_called_with(normal_input)
97+
assert output.shape == (1, hidden_size)
98+
99+
100+
@pytest.mark.parametrize("image_size", [(36, 36, 3), (84, 84, 4), (256, 256, 5)])
101+
@pytest.mark.parametrize(
102+
"vis_class", [SimpleVisualEncoder, ResNetVisualEncoder, NatureVisualEncoder]
103+
)
104+
def test_visual_encoder(vis_class, image_size):
105+
num_outputs = 128
106+
enc = vis_class(image_size[0], image_size[1], image_size[2], num_outputs)
107+
# Note: NCHW not NHWC
108+
sample_input = torch.ones((1, image_size[2], image_size[0], image_size[1]))
109+
encoding = enc(sample_input)
110+
assert encoding.shape == (1, num_outputs)

0 commit comments

Comments
 (0)