@@ -155,6 +155,12 @@ const SIMPLE_MODEL: tensorflow.IGraphDef = {
155
155
input : [ 'BiasAdd' ] ,
156
156
attr : { DstT : { type : tensorflow . DataType . DT_HALF } }
157
157
} ,
158
+ {
159
+ name : 'Cast4' ,
160
+ op : 'Cast' ,
161
+ input : [ 'BiasAdd' ] ,
162
+ attr : { DstT : { type : tensorflow . DataType . DT_COMPLEX64 } }
163
+ }
158
164
] ,
159
165
library : {
160
166
function : [
@@ -310,7 +316,7 @@ describe('operationMapper without signature', () => {
310
316
it ( 'should find the graph output nodes' , ( ) => {
311
317
expect ( convertedGraph . outputs . map ( node => node . name ) ) . toEqual ( [
312
318
'Fill' , 'Squeeze' , 'Squeeze2' , 'Split' , 'LogicalNot' ,
313
- 'FusedBatchNorm' , 'Cast2' , 'Cast3'
319
+ 'FusedBatchNorm' , 'Cast2' , 'Cast3' , 'Cast4'
314
320
] ) ;
315
321
} ) ;
316
322
@@ -324,7 +330,7 @@ describe('operationMapper without signature', () => {
324
330
expect ( Object . keys ( convertedGraph . nodes ) ) . toEqual ( [
325
331
'image_placeholder' , 'Const' , 'Shape' , 'Value' , 'Fill' , 'Conv2D' ,
326
332
'BiasAdd' , 'Cast' , 'Squeeze' , 'Squeeze2' , 'Split' , 'LogicalNot' ,
327
- 'FusedBatchNorm' , 'Cast2' , 'Cast3'
333
+ 'FusedBatchNorm' , 'Cast2' , 'Cast3' , 'Cast4'
328
334
] ) ;
329
335
} ) ;
330
336
} ) ;
@@ -447,6 +453,10 @@ describe('operationMapper without signature', () => {
447
453
expect ( convertedGraph . nodes [ 'Cast' ] . attrParams [ 'dtype' ] . value )
448
454
. toEqual ( 'int32' ) ;
449
455
} ) ;
456
+ it ( 'should map params with complex64 dtype' , ( ) => {
457
+ expect ( convertedGraph . nodes [ 'Cast4' ] . attrParams [ 'dtype' ] . value )
458
+ . toEqual ( 'complex64' ) ;
459
+ } ) ;
450
460
} ) ;
451
461
} ) ;
452
462
} ) ;
@@ -486,7 +496,7 @@ describe('operationMapper with signature', () => {
486
496
expect ( Object . keys ( convertedGraph . nodes ) ) . toEqual ( [
487
497
'image_placeholder' , 'Const' , 'Shape' , 'Value' , 'Fill' , 'Conv2D' ,
488
498
'BiasAdd' , 'Cast' , 'Squeeze' , 'Squeeze2' , 'Split' , 'LogicalNot' ,
489
- 'FusedBatchNorm' , 'Cast2' , 'Cast3'
499
+ 'FusedBatchNorm' , 'Cast2' , 'Cast3' , 'Cast4'
490
500
] ) ;
491
501
} ) ;
492
502
} ) ;
@@ -552,6 +562,10 @@ describe('operationMapper with signature', () => {
552
562
expect ( convertedGraph . nodes [ 'Cast3' ] . attrParams [ 'dtype' ] . value )
553
563
. toEqual ( 'float32' ) ;
554
564
} ) ;
565
+ it ( 'should map params with complex64 dtype' , ( ) => {
566
+ expect ( convertedGraph . nodes [ 'Cast4' ] . attrParams [ 'dtype' ] . value )
567
+ . toEqual ( 'complex64' ) ;
568
+ } ) ;
555
569
} ) ;
556
570
} ) ;
557
571
} ) ;
0 commit comments