1
- ing System ;
1
+ using System ;
2
2
using System . Collections ;
3
3
using System . Collections . Generic ;
4
4
using UnityEngine ;
13
13
14
14
namespace Unity . MLAgents . Tests
15
15
{
16
- internal class TestPolicy : IPolicy
17
- {
18
- public Action OnRequestDecision ;
19
- ObservationWriter m_ObsWriter = new ObservationWriter ( ) ;
20
- static ActionSpec s_ActionSpec = ActionSpec . MakeContinuous ( 1 ) ;
21
- static ActionBuffers s_EmptyActionBuffers = new ActionBuffers ( new float [ 1 ] , Array . Empty < int > ( ) ) ;
22
- public void RequestDecision ( AgentInfo info , List < ISensor > sensors )
23
- {
24
- foreach ( var sensor in sensors )
25
- {
26
- sensor . GetObservationProto ( m_ObsWriter ) ;
27
- }
28
- OnRequestDecision ? . Invoke ( ) ;
29
- }
30
-
31
- public ref readonly ActionBuffers DecideAction ( ) { return ref s_EmptyActionBuffers ; }
32
-
33
- public void Dispose ( ) { }
34
- }
35
-
36
- public class TestAgent : Agent
37
- {
38
- internal AgentInfo _Info
39
- {
40
- get
41
- {
42
- return ( AgentInfo ) typeof ( Agent ) . GetField ( "m_Info" , BindingFlags . Instance | BindingFlags . NonPublic ) . GetValue ( this ) ;
43
- }
44
- set
45
- {
46
- typeof ( Agent ) . GetField ( "m_Info" , BindingFlags . Instance | BindingFlags . NonPublic ) . SetValue ( this , value ) ;
47
- }
48
- }
49
-
50
- internal void SetPolicy ( IPolicy policy )
51
- {
52
- typeof ( Agent ) . GetField ( "m_Brain" , BindingFlags . Instance | BindingFlags . NonPublic ) . SetValue ( this , policy ) ;
53
- }
54
-
55
- internal IPolicy GetPolicy ( )
56
- {
57
- return ( IPolicy ) typeof ( Agent ) . GetField ( "m_Brain" , BindingFlags . Instance | BindingFlags . NonPublic ) . GetValue ( this ) ;
58
- }
59
-
60
- public int initializeAgentCalls ;
61
- public int collectObservationsCalls ;
62
- public int collectObservationsCallsForEpisode ;
63
- public int agentActionCalls ;
64
- public int agentActionCallsForEpisode ;
65
- public int agentOnEpisodeBeginCalls ;
66
- public int heuristicCalls ;
67
- public TestSensor sensor1 ;
68
- public TestSensor sensor2 ;
69
-
70
- [ Observable ( "observableFloat" ) ]
71
- public float observableFloat ;
72
-
73
- public override void Initialize ( )
74
- {
75
- initializeAgentCalls += 1 ;
76
-
77
- // Add in some custom Sensors so we can confirm they get sorted as expected.
78
- sensor1 = new TestSensor ( "testsensor1" ) ;
79
- sensor2 = new TestSensor ( "testsensor2" ) ;
80
- sensor2 . compressionType = SensorCompressionType . PNG ;
81
-
82
- sensors . Add ( sensor2 ) ;
83
- sensors . Add ( sensor1 ) ;
84
- }
85
-
86
- public override void CollectObservations ( VectorSensor sensor )
87
- {
88
- collectObservationsCalls += 1 ;
89
- collectObservationsCallsForEpisode += 1 ;
90
- sensor . AddObservation ( collectObservationsCallsForEpisode ) ;
91
- }
92
-
93
- public override void OnActionReceived ( ActionBuffers buffers )
94
- {
95
- agentActionCalls += 1 ;
96
- agentActionCallsForEpisode += 1 ;
97
- AddReward ( 0.1f ) ;
98
- }
99
-
100
- public override void OnEpisodeBegin ( )
101
- {
102
- agentOnEpisodeBeginCalls += 1 ;
103
- collectObservationsCallsForEpisode = 0 ;
104
- agentActionCallsForEpisode = 0 ;
105
- }
106
-
107
- public override void Heuristic ( in ActionBuffers actionsOut )
108
- {
109
- var obs = GetObservations ( ) ;
110
- var continuousActions = actionsOut . ContinuousActions ;
111
- continuousActions [ 0 ] = ( int ) obs [ 0 ] ;
112
- heuristicCalls ++ ;
113
- }
114
- }
115
-
116
- public class TestSensor : ISensor
117
- {
118
- public string sensorName ;
119
- public int numWriteCalls ;
120
- public int numCompressedCalls ;
121
- public int numResetCalls ;
122
- public SensorCompressionType compressionType = SensorCompressionType . None ;
123
-
124
- public TestSensor ( string n )
125
- {
126
- sensorName = n ;
127
- }
128
-
129
- public int [ ] GetObservationShape ( )
130
- {
131
- return new [ ] { 0 } ;
132
- }
133
-
134
- public ObservationSpec GetObservationSpec ( )
135
- {
136
- return ObservationSpec . Vector ( 0 ) ;
137
- }
138
-
139
- public int Write ( ObservationWriter writer )
140
- {
141
- numWriteCalls ++ ;
142
- // No-op
143
- return 0 ;
144
- }
145
-
146
- public byte [ ] GetCompressedObservation ( )
147
- {
148
- numCompressedCalls ++ ;
149
- return new byte [ ] { 0 } ;
150
- }
151
-
152
- public SensorCompressionType GetCompressionType ( )
153
- {
154
- return compressionType ;
155
- }
156
-
157
- public string GetName ( )
158
- {
159
- return sensorName ;
160
- }
161
-
162
- public void Update ( ) { }
163
-
164
- public void Reset ( )
165
- {
166
- numResetCalls ++ ;
167
- }
168
- }
169
16
170
17
[ TestFixture ]
171
18
public class EditModeTestGeneration
@@ -175,7 +22,8 @@ public void SetUp()
175
22
{
176
23
if ( Academy . IsInitialized )
177
24
{
178
- Academy . Instance . Di se }
25
+ Academy . Instance . Dispose ( ) ;
26
+ }
179
27
}
180
28
181
29
[ Test ]
0 commit comments