package bgu.spl.mics.application.objects;

+ import java.util.ArrayDeque;
import java.util.Collection;
+ import java.util.Queue;
+ import java.util.concurrent.ExecutorService;
+ import java.util.concurrent.Executors;
+ import java.util.concurrent.ThreadPoolExecutor;
+ import java.util.concurrent.TimeUnit;
+
+ import org.junit.experimental.theories.Theories;

import bgu.spl.mics.MessageBusImpl;
import bgu.spl.mics.application.services.GPUService;
+ import bgu.spl.mics.application.messages.TestModelEvent;
+ import bgu.spl.mics.application.messages.TickBroadcast;
+ import bgu.spl.mics.application.messages.TrainModelEvent;
+ import bgu.spl.mics.Callback;
+ import bgu.spl.mics.Event;

/**
 * Passive object representing a single GPU.
 * Add all the fields described in the assignment as private fields.
 * Add fields and methods to this class as you see fit (including public methods and constructors).
 */
public class GPU {
+
    /**
     * Enum representing the type of the GPU.
     */
-     public enum Type {RTX3090, RTX2080, GTX1080}
+     public enum Type { RTX3090, RTX2080, GTX1080 }


    // region According to assignment instructions
    private Type type;
-     private Cluster cluster;
-     private Model model;
+     private Model model; // the model being taken care of currently
+     private static Cluster cluster = Cluster.getInstance();
    // endregion According to assignment instructions


    // region Added fields
-     private GPUService service;
+     // private GPUService service;
    // private Collection<DataBatch> processedBatches;
-     private final byte vRAM;
-     private int processedBatchesNum;
+     private Queue<Event<Model>> modelEventsQueue;
+     private final int trainingDelay; // according to the type of the gpu
+     private final byte vRAM; // according to the type of the gpu
+     private int storedProcessedBatchesNumber;
+     private int ticksToTrainBatch;
+     private boolean training;
+     private boolean testing;
    // endregion Added fields

    /**
@@ -58,7 +77,57 @@ public enum Type {RTX3090, RTX2080, GTX1080}
    // }


-     private Type typeFromString(String _type) {
+     /**
+      * @return The type of the GPU
+      */
+     // public Type getType() {
+     //     return type;
+     // }
+
+
+     // public void runService() { // TODO remove ?
+     //     service.run();
+     // }
+
+
+     /**
+      * @return Reference to the corresponding GPUService
+      */
+     // public GPUService getService() {
+     //     return service;
+     // }
+
+
+     // public Collection<DataBatch> getProcessedBatches() {
+     //     return processedBatches;
+     // }
+
+
+     // region for serialization from json
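+     /**
+      * Builds a GPU from its type name as read from the input json
+      * ("RTX3090" / "RTX2080" / "GTX1080"); vRAM and trainingDelay (ticks per
+      * trained batch) are derived from that type in the switch below.
+      */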
+     public GPU(String _type) {
+         this.type = typeFromString(_type);
+         model = null;
+         storedProcessedBatchesNumber = 0;
+
+         switch (type) {
+             case GTX1080:
+                 vRAM = 8;
+                 trainingDelay = 4;
+                 break;
+             case RTX2080:
+                 vRAM = 16;
+                 trainingDelay = 2;
+                 break;
+             default: // RTX3090:
+                 vRAM = 32;
+                 trainingDelay = 1;
+         }
+
+         modelEventsQueue = new ArrayDeque<>();
+     }
+
+
+     private Type typeFromString(String _type) {
        Type returnType;
        String uppercaseType = _type.toUpperCase();

@@ -71,75 +140,75 @@ else if (uppercaseType == "RTX2080")

        return returnType;
    }
+     // endregion for serialization from json


-     /**
-      * @return The type of the GPU
-      */
-     public Type getType() {
-         return type;
-     }
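+     /**
+      * Advances the GPU by one clock tick (presumably invoked by the owning
+      * GPUService when a TickBroadcast arrives). While training, each batch takes
+      * trainingDelay ticks to finish: 1 on RTX3090, 2 on RTX2080, 4 on GTX1080.
+      */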
+     public void gotTick() {
+         // TODO: treat the case of last tick??

+         // TODO: split to smaller functions (queries + actions)
+         if (training) { // && storedProcessedBatchesNumber != 0) {
+             --ticksToTrainBatch;
+
+             if (ticksToTrainBatch == 0) { // if finished training batch
+                 ticksToTrainBatch = trainingDelay; // reset the counter
+                 finishTrainingBatch();

-     // public void runService() { // TODO remove ?
-     //     service.run();
-     // }
+                 if (storedProcessedBatchesNumber == 0) {
+                     training = false; // TODO: maybe incorrect since someone could tell me to test a model while i haven't gotten all the batches from cpu yet
+                     // TODO set Future to done / send more batches to cpu for processing
+                 }
+             }
+         }
+         else if (testing) {
+             // TODO
+         }
+         // else if (storedProcessedBatchesNumber != 0) {
+         //     training = true;
+         //     --ticksToTrainBatch;
+
+         //     if (trainingDelay == 0) {
+         //         gpu.finishTrainingBatch();
+
+         //         if (gpu.getProcessedBatchesNum() == 0) {
+         //             // TODO set Future to done
+         //         }
+         //     }
+         // }
+     }


-     /**
-      * @return Reference to the corresponding GPUService
-      */
-     public GPUService getService() {
-         return service;
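+     /**
+      * Stores an incoming model event (e.g. a TrainModelEvent) in modelEventsQueue
+      * until ticks are available to handle it.
+      */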
+     public void gotModelEvent(Event<Model> modelEvent) {
+         // TODO: put aside to wait for ticks
+         modelEventsQueue.add(modelEvent);
    }


-     // public Collection<DataBatch> getProcessedBatches() {
-     //     return processedBatches;
+     // public void gotModelToTest(TestModelEvent testModelEvent) {
+     //     // TODO: put aside to wait for ticks
    // }


    /**
     * @return The amount of vRAM the GPU has
     */
-     public byte getVRAM() {
-         return vRAM;
-     }
+     // public byte getVRAM() {
+     //     return vRAM;
+     // }

+
    /**
     * @return Number of processed batches currently in training
     */
-     public int getProcessedBatchesNum() {
-         return processedBatchesNum;
-     }
+     // public int getProcessedBatchesNum() {
+     //     return storedProcessedBatchesNumber;
+     // }


    /**
     * @inv processedBatchesNum >= 0
     */
    public void finishTrainingBatch() {
-         --processedBatchesNum;
+         --storedProcessedBatchesNumber;
    }
-
-
-     // region for serialization from json
-     public GPU(String _type) {
-         this.type = typeFromString(_type);
-         // this.service = new GPUService(_name, this);
-
-         processedBatchesNum = 0;
-
-         switch (type) {
-             case GTX1080:
-                 vRAM = 8;
-                 break;
-             case RTX2080:
-                 vRAM = 16;
-                 break;
-             default: // RTX3090:
-                 vRAM = 32;
-         }
-     }
-     // endregion for serialization from json
-
- }
+ }