Skip to content

Commit d94193d

Browse files
authored
Add support for complex64 data type in parseDtypeParam function (#8083)
* Add support for complex64 data type in `parseDtypeParam` function * Add parse complex64 test to operation_mapper_test.ts
1 parent b8a0023 commit d94193d

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

tfjs-converter/src/operations/operation_mapper.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,9 @@ export function parseDtypeParam(value: string|tensorflow.DataType): DataType {
499499
return 'float32';
500500
case tensorflow.DataType.DT_STRING:
501501
return 'string';
502+
case tensorflow.DataType.DT_COMPLEX64:
503+
case tensorflow.DataType.DT_COMPLEX128:
504+
return 'complex64';
502505
default:
503506
// Unknown dtype error will happen at runtime (instead of parse time),
504507
// since these nodes might not be used by the actual subgraph execution.

tfjs-converter/src/operations/operation_mapper_test.ts

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,12 @@ const SIMPLE_MODEL: tensorflow.IGraphDef = {
155155
input: ['BiasAdd'],
156156
attr: {DstT: {type: tensorflow.DataType.DT_HALF}}
157157
},
158+
{
159+
name: 'Cast4',
160+
op: 'Cast',
161+
input: ['BiasAdd'],
162+
attr: {DstT: {type: tensorflow.DataType.DT_COMPLEX64}}
163+
}
158164
],
159165
library: {
160166
function: [
@@ -310,7 +316,7 @@ describe('operationMapper without signature', () => {
310316
it('should find the graph output nodes', () => {
311317
expect(convertedGraph.outputs.map(node => node.name)).toEqual([
312318
'Fill', 'Squeeze', 'Squeeze2', 'Split', 'LogicalNot',
313-
'FusedBatchNorm', 'Cast2', 'Cast3'
319+
'FusedBatchNorm', 'Cast2', 'Cast3', 'Cast4'
314320
]);
315321
});
316322

@@ -324,7 +330,7 @@ describe('operationMapper without signature', () => {
324330
expect(Object.keys(convertedGraph.nodes)).toEqual([
325331
'image_placeholder', 'Const', 'Shape', 'Value', 'Fill', 'Conv2D',
326332
'BiasAdd', 'Cast', 'Squeeze', 'Squeeze2', 'Split', 'LogicalNot',
327-
'FusedBatchNorm', 'Cast2', 'Cast3'
333+
'FusedBatchNorm', 'Cast2', 'Cast3', 'Cast4'
328334
]);
329335
});
330336
});
@@ -447,6 +453,10 @@ describe('operationMapper without signature', () => {
447453
expect(convertedGraph.nodes['Cast'].attrParams['dtype'].value)
448454
.toEqual('int32');
449455
});
456+
it('should map params with complex64 dtype', () => {
457+
expect(convertedGraph.nodes['Cast4'].attrParams['dtype'].value)
458+
.toEqual('complex64');
459+
});
450460
});
451461
});
452462
});
@@ -486,7 +496,7 @@ describe('operationMapper with signature', () => {
486496
expect(Object.keys(convertedGraph.nodes)).toEqual([
487497
'image_placeholder', 'Const', 'Shape', 'Value', 'Fill', 'Conv2D',
488498
'BiasAdd', 'Cast', 'Squeeze', 'Squeeze2', 'Split', 'LogicalNot',
489-
'FusedBatchNorm', 'Cast2', 'Cast3'
499+
'FusedBatchNorm', 'Cast2', 'Cast3', 'Cast4'
490500
]);
491501
});
492502
});
@@ -552,6 +562,10 @@ describe('operationMapper with signature', () => {
552562
expect(convertedGraph.nodes['Cast3'].attrParams['dtype'].value)
553563
.toEqual('float32');
554564
});
565+
it('should map params with complex64 dtype', () => {
566+
expect(convertedGraph.nodes['Cast4'].attrParams['dtype'].value)
567+
.toEqual('complex64');
568+
});
555569
});
556570
});
557571
});

0 commit comments

Comments
 (0)