forked from apache/mxnet
Symbol.pm
package AI::MXNet::Symbol;
=head1 NAME
AI::MXNet::Symbol - Symbolic interface of MXNet.
=cut
use strict;
use warnings;
use AI::MXNet::Base;
use AI::MXNet::Symbol::Base;
use AI::MXNet::Types;
use Mouse;
use AI::MXNet::Function::Parameters;
use overload
'""' => \&stringify,
'+' => \&add,
'-' => \&subtract,
'*' => \&multiply,
'/' => \&divide,
'/=' => \&idivide,
'**' => \&power,
'%' => \&mod,
'==' => \&equal,
'!=' => \&not_equal,
'>' => \&greater,
'>=' => \&greater_equal,
'<' => \&lesser,
'<=' => \&lesser_equal,
'&{}' => sub { my $self = shift; sub { $self->call(@_) } },
'@{}' => sub { my $self = shift; [map { $self->slice($_) } @{ $self->list_outputs }] };
extends 'AI::MXNet::Symbol::Base';
has 'handle' => (is => 'rw', isa => 'SymbolHandle', required => 1);
sub DEMOLISH
{
check_call(AI::NNVMCAPI::SymbolFree(shift->handle));
}
method STORABLE_freeze($cloning)
{
return $self->tojson();
}
method STORABLE_thaw($cloning, $json)
{
my $handle = check_call(
AI::MXNetCAPI::SymbolCreateFromJSON(
$json
)
);
$self->handle($handle);
}
method stringify($other=, $reverse=)
{
my $name = $self->name;
sprintf("<%s %s>", ref($self), $name ? $name : 'Grouped');
}
method add(AI::MXNet::Symbol|Num $other, $reverse=)
{
return _ufunc_helper(
$self,
$other,
qw/_Plus _PlusScalar/
);
}
method subtract(AI::MXNet::Symbol|Num $other, $reverse=)
{
return _ufunc_helper(
$self,
$other,
qw/_Minus _MinusScalar _RMinusScalar/,
$reverse
);
}
method multiply(AI::MXNet::Symbol|Num $other, $reverse=)
{
return _ufunc_helper(
$self,
$other,
qw/_Mul _MulScalar/
);
}
method divide(AI::MXNet::Symbol|Num $other, $reverse=)
{
return _ufunc_helper(
$self,
$other,
qw/_Div _DivScalar _RDivScalar/,
$reverse
);
}
method power(AI::MXNet::Symbol|Num $other, $reverse=)
{
return _ufunc_helper(
$self,
$other,
qw/_Power _PowerScalar _RPowerScalar/,
$reverse
);
}
method equal(AI::MXNet::Symbol|Num $other, $reverse=)
{
return _ufunc_helper(
$self,
$other,
qw/_equal _equal_scalar/
);
}
method not_equal(AI::MXNet::Symbol|Num $other, $reverse=)
{
return _ufunc_helper(
$self,
$other,
qw/_not_equal _not_equal_scalar/
);
}
method greater(AI::MXNet::Symbol|Num $other, $reverse=)
{
return _ufunc_helper(
$self,
$other,
qw/_greater _greater_scalar _lesser_scalar/,
$reverse
);
}
method greater_equal(AI::MXNet::Symbol|Num $other, $reverse=)
{
return _ufunc_helper(
$self,
$other,
qw/_greater_equal _greater_equal_scalar _lesser_equal_scalar/,
$reverse
);
}
method lesser(AI::MXNet::Symbol|Num $other, $reverse=)
{
return _ufunc_helper(
$self,
$other,
qw/_lesser _lesser_scalar _greater_scalar/,
$reverse
);
}
method lesser_equal(AI::MXNet::Symbol|Num $other, $reverse=)
{
return _ufunc_helper(
$self,
$other,
qw/_lesser_equal _lesser_equal_scalar _greater_equal_scalar/,
$reverse
);
}
method true_divide(AI::MXNet::Symbol|Num $other, $reverse=)
{
return $self->divide($other, $reverse);
}
method mod(AI::MXNet::Symbol|Num $other, $reverse=)
{
return _ufunc_helper(
$self,
$other,
qw/_Mod _ModScalar _RModScalar/,
$reverse
);
}
method maximum(AI::MXNet::Symbol|Num $other)
{
return _ufunc_helper(
$self,
$other,
qw/_Maximum _MaximumScalar/
);
}
method minimum(AI::MXNet::Symbol|Num $other)
{
return _ufunc_helper(
$self,
$other,
qw/_Minimum _MinimumScalar/
);
}
method hypot(AI::MXNet::Symbol|Num $other)
{
return _ufunc_helper(
$self,
$other,
qw/_Hypot _HypotScalar/
);
}
method deepcopy()
{
my $handle = check_call(AI::MXNetCAPI::SymbolCopy($self->handle));
return __PACKAGE__->new(handle => $handle);
}
method call(@args)
{
my $s = $self->deepcopy();
$s->_compose(@args);
return $s;
}
method slice(Str|Index $index)
{
## __getitem__ tie needs to die
if(not find_type_constraint('Index')->check($index))
{
my $i = 0;
my $idx;
for my $name (@{ $self->list_outputs() })
{
if($name eq $index)
{
if(defined $idx)
{
confess(qq/There are multiple outputs with name "$index"/);
}
$idx = $i;
}
$i++;
}
confess(qq/Cannot find output that matches name "$index"/) unless defined $idx;
$index = $idx;
}
elsif($index >= @{ $self->list_outputs() })
{
confess("Index: [$index] is outside of the range of the symbol: $self outputs");
}
my $handle = check_call(AI::MXNetCAPI::SymbolGetOutput($self->handle, $index));
return __PACKAGE__->new(handle => $handle);
}
=head2 name
Get the name string from the symbol. This function only works for non-grouped symbols.
Returns
-------
value : str
The name of this symbol, or undef for a grouped symbol.
=cut
method name()
{
my ($name, $success) = check_call(AI::MXNetCAPI::SymbolGetName($self->handle));
return $success ? $name : undef;
}
=head2 attr
Get an attribute string from the symbol. This function only works for non-grouped symbols.
Parameters
----------
key : str
The key to get the attribute from.
Returns
-------
value : str
The attribute value of the key, or undef if the attribute does not exist.
=cut
method attr(Str $key)
{
my ($attr, $success) = check_call(
AI::MXNetCAPI::SymbolGetAttr($self->handle, $key)
);
return $success ? $attr : undef;
}
=head2 list_attr
Get all attributes of this symbol (attributes of its children are not included; see attr_dict).
Returns
-------
ret : hash ref of str to str
A hash ref mapping attribute keys to values.
=cut
method list_attr()
{
my %ret;
my @attrs = @{ check_call(AI::MXNetCAPI::SymbolListAttrShallow($self->handle)) };
while(@attrs)
{
my $k = shift(@attrs);
my $v = shift(@attrs);
$ret{ $k } = $v;
}
return \%ret;
}
=head2 attr_dict
Recursively get all attributes from the symbol and its children.
Returns
-------
ret : hash ref of str to hash ref
Returns a hash ref whose keys are the names of the symbol and its children.
Each value is a hash ref that maps attribute keys to values.
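The split on `$` in the method body shows that the C API reports each attribute as a flat `node_name$attr_key`, `value` pair. attr_dict regroups those pairs by node. Below is a minimal pure-Perl sketch of that regrouping. The helper `group_attrs` and the sample attribute names are hypothetical. Unlike the method, the sketch passes a split limit of 2, so any further `$` inside the key is kept.

```perl
use strict;
use warnings;

# Hypothetical helper mirroring attr_dict: turns a flat
# (name$key, value, name$key, value, ...) list into a nested hash ref.
sub group_attrs {
    my @flat = @_;
    my %ret;
    while (@flat) {
        my ($compound, $val) = splice(@flat, 0, 2);
        # Split on the first literal '$' only.
        my ($name, $key) = split(/\$/, $compound, 2);
        $ret{$name}{$key} = $val;
    }
    return \%ret;
}

# Hypothetical sample of what the C API might return.
my $attrs = group_attrs(
    'fc1$lr_mult'    => '2.0',
    'fc1$wd_mult'    => '0.5',
    'data$ctx_group' => 'dev1',
);
print $attrs->{fc1}{lr_mult}, "\n";    # prints 2.0
```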
=cut
method attr_dict()
{
my %ret;
my @attrs = @{ check_call(AI::MXNetCAPI::SymbolListAttr($self->handle)) };
my $size = @attrs/2;
for (my $i = 0; $i < $size; $i++)
{
my ($name, $key) = split(/\$/, $attrs[$i*2]);
my $val = $attrs[$i*2+1];
$ret{ $name }{ $key } = $val;
}
return \%ret;
}
method _set_attr(Str @args)
{
my %kwargs = @args;
while(my ($key, $val) = each(%kwargs))
{
check_call(
AI::MXNetCAPI::SymbolSetAttr(
$self->handle, $key, $val
)
);
}
}
=head2 get_internals
Get a new grouped symbol whose output contains all the internal outputs of this symbol.
Returns
-------
sgroup : AI::MXNet::Symbol
The internal symbol of the symbol.
=cut
method get_internals()
{
my $handle = check_call(AI::MXNetCAPI::SymbolGetInternals($self->handle));
return __PACKAGE__->new(handle => $handle);
}
=head2 get_children
Get a new grouped symbol whose outputs are the inputs
to the output nodes of the original symbol.
Returns
-------
sgroup : AI::MXNet::Symbol or undef
The children of the head node. If the symbol has no
inputs, undef is returned.
=cut
method get_children()
{
my $handle = check_call(AI::MXNetCAPI::SymbolGetChildren($self->handle));
my $ret = __PACKAGE__->new(handle => $handle);
return undef unless @{ $ret->list_outputs };
return $ret;
}
=head2 list_arguments
List all the arguments in the symbol.
Returns
-------
args : array ref of strings
=cut
method list_arguments()
{
return scalar(check_call(AI::MXNetCAPI::SymbolListArguments($self->handle)));
}
=head2 list_outputs()
List all outputs in the symbol.
Returns
-------
$out : array ref of strings.
=cut
method list_outputs()
{
return scalar(check_call(AI::MXNetCAPI::SymbolListOutputs($self->handle)));
}
=head2 list_auxiliary_states()
List all auxiliary states in the symbol.
Returns
-------
aux_states : array ref of string
List the names of the auxiliary states.
Notes
-----
Auxiliary states are special states of symbols that do not correspond to an argument
and have no gradient, but are still useful for specific operations.
A common example of auxiliary states is the moving_mean and moving_variance in BatchNorm.
Most operators do not have auxiliary states.
=cut
method list_auxiliary_states()
{
return scalar(check_call(AI::MXNetCAPI::SymbolListAuxiliaryStates($self->handle)));
}
=head2 list_inputs
Lists all arguments and auxiliary states of this Symbol.
Returns
-------
inputs : array ref of str
List of all inputs.
Examples
--------
>>> my $bn = mx->sym->BatchNorm(name=>'bn');
>>> print "@{ $bn->list_inputs }\n";
bn_data bn_gamma bn_beta bn_moving_mean bn_moving_var
=cut
method list_inputs()
{
return scalar(check_call(AI::NNVMCAPI::SymbolListInputNames($self->handle, 0)));
}
=head2 infer_type
Infer the types of outputs and arguments given the known types of some arguments.
The known types can be passed either positionally or as keyword (name => dtype) pairs.
A list of three undefs is returned if there is not enough information to complete the inference.
An error is raised if the known types passed in are inconsistent.
Parameters
----------
args : Array
Provide the types of arguments positionally, in the order of list_arguments().
Unknown types can be marked as undef.
kwargs : list of name => Dtype pairs
Provide the known types as keyword arguments (for example, data => 'float32').
Keyword arguments cannot be mixed with positional ones.
Returns
-------
arg_types : array ref of Dtype or undef
List of types of arguments.
The order is the same as list_arguments().
out_types : array ref of Dtype or undef
List of types of outputs.
The order is the same as list_outputs().
aux_types : array ref of Dtype or undef
List of types of auxiliary states.
The order is the same as list_auxiliary_states().
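In the positional form, each dtype name is converted to MXNet's numeric type code, and an undef entry is sent as -1 (unknown). The method body does this with DTYPE_STR_TO_MX. Below is a minimal sketch of that encoding step. The helper name is hypothetical. The small code table is an assumed subset mirroring MXNet's usual type flags (float32 is 0, float64 is 1, and so on); it is not a copy of DTYPE_STR_TO_MX.

```perl
use strict;
use warnings;

# Assumed subset of the dtype-name => type-code table; DTYPE_STR_TO_MX
# in AI::MXNet::Base is the authoritative source.
my %dtype_to_code = (
    float32 => 0,
    float64 => 1,
    float16 => 2,
    uint8   => 3,
    int32   => 4,
);

# Hypothetical helper: encode positional dtypes, marking unknown ones as -1.
sub encode_positional_dtypes {
    my @dtypes = @_;
    return [ map { defined($_) ? $dtype_to_code{$_} : -1 } @dtypes ];
}

my $sdata = encode_positional_dtypes('float32', undef, 'int32');
print join(',', @$sdata), "\n";    # prints 0,-1,4
```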
=cut
method infer_type(Str|Undef @args)
{
my ($positional_arguments, $kwargs, $kwargs_order) = _parse_arguments("Dtype", @args);
my $sdata = [];
my $keys = [];
if(@$positional_arguments)
{
@{ $sdata } = map { defined($_) ? DTYPE_STR_TO_MX->{ $_ } : -1 } @{ $positional_arguments };
}
else
{
@{ $keys } = @{ $kwargs_order };
@{ $sdata } = map { DTYPE_STR_TO_MX->{ $_ } } @{ $kwargs }{ @{ $kwargs_order } };
}
my ($arg_type, $out_type, $aux_type, $complete) = check_call(AI::MXNetCAPI::SymbolInferType(
$self->handle,
scalar(@{ $sdata }),
$keys,
$sdata
)
);
if($complete)
{
return (
[ map { DTYPE_MX_TO_STR->{ $_ } } @{ $arg_type }],
[ map { DTYPE_MX_TO_STR->{ $_ } } @{ $out_type }],
[ map { DTYPE_MX_TO_STR->{ $_ } } @{ $aux_type }]
);
}
else
{
return (undef, undef, undef);
}
}
=head2 infer_shape
Infer the shapes of outputs and arguments given the known shapes of some arguments.
The known shapes can be passed either positionally or as keyword (name => shape) pairs.
A list of three undefs is returned (and a warning naming the undecided arguments is logged)
if there is not enough information to complete the inference.
An error is raised if the known shapes passed in are inconsistent.
Parameters
----------
args : Array
Provide the shapes of arguments positionally, in the order of list_arguments().
Unknown shapes can be marked as undef.
kwargs : list of name => Shape pairs
Provide the known shapes as keyword arguments (for example, data => [32, 3, 224, 224]).
Returns
-------
arg_shapes : array ref of Shape or undef
List of shapes of arguments.
The order is the same as list_arguments().
out_shapes : array ref of Shape or undef
List of shapes of outputs.
The order is the same as list_outputs().
aux_shapes : array ref of Shape or undef
List of shapes of auxiliary states.
The order is the same as list_auxiliary_states().
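Before calling the C API, _infer_shape_impl flattens the known shapes into a compressed (CSR-style) layout. `sdata` holds all dimensions back to back, and `indptr` marks the boundaries, so argument i has shape `sdata[indptr[i] .. indptr[i+1]-1]`. Below is a minimal pure-Perl sketch of that packing for the keyword form. `pack_shapes` and `unpack_shape` are hypothetical helpers, not part of AI::MXNet.

```perl
use strict;
use warnings;

# Hypothetical helper: pack name => shape pairs (in the given order) the way
# _infer_shape_impl does for keyword arguments.
sub pack_shapes {
    my @pairs = @_;
    my (@keys, @sdata);
    my @indptr = (0);
    while (@pairs) {
        my ($name, $shape) = splice(@pairs, 0, 2);
        push @keys, $name;
        push @sdata, @$shape;
        push @indptr, scalar(@sdata);
    }
    return (\@keys, \@indptr, \@sdata);
}

# Hypothetical helper: recover the i-th shape from the packed form.
sub unpack_shape {
    my ($indptr, $sdata, $i) = @_;
    return [ @{$sdata}[ $indptr->[$i] .. $indptr->[$i + 1] - 1 ] ];
}

my ($keys, $indptr, $sdata) = pack_shapes(
    data  => [32, 3, 224, 224],
    label => [32],
);
print "@$indptr\n";    # prints 0 4 5
```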
=cut
method infer_shape(Maybe[Str|Shape] @args)
{
my @res = $self->_infer_shape_impl(0, @args);
if(not defined $res[1])
{
my ($arg_shapes) = $self->_infer_shape_impl(1, @args);
my $arg_names = $self->list_arguments;
my @unknowns;
zip(sub {
my ($name, $shape) = @_;
if(not ref $shape or not @$shape or not product(@$shape))
{
if(@unknowns >= 10)
{
$unknowns[10] = '...';
}
else
{
my @shape = eval { @$shape };
push @unknowns, "$name @shape";
}
}
}, $arg_names, $arg_shapes);
AI::MXNet::Logging->warning(
"Cannot decide shape for the following arguments "
."(0s in shape means unknown dimensions). "
."Consider providing them as input:\n\t"
.join(", ", @unknowns)
);
}
return @res;
}
=head2 infer_shape_partial
Partially infer the shape. The same as infer_shape, except that the partial
results can be returned.
=cut
method infer_shape_partial(Maybe[Str|Shape] @args)
{
$self->_infer_shape_impl(1, @args)
}
# The actual implementation for calling shape inference API.
method _infer_shape_impl(Maybe[Str|Shape] @args)
{
my $partial = shift(@args);
my ($positional_arguments, $kwargs, $kwargs_order) = _parse_arguments("Shape", @args);
my $sdata = [];
my $indptr = [0];
my $keys = [];
if(@{ $positional_arguments })
{
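# An undef (unknown) shape still occupies its slot as an empty shape,
# so the following shapes stay aligned with list_arguments().
for my $shape (@{ $positional_arguments })
{
push @{ $sdata }, @{ $shape // [] };
push @{ $indptr }, scalar(@{ $sdata });
}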
}
{
for my $k (@{ $kwargs_order })
{
push @{ $keys }, $k;
push @{ $sdata }, @{ $kwargs->{ $k } };
push @{ $indptr }, scalar(@{ $sdata });
}
}
my $infer_func = $partial ? \&AI::MXNetCAPI::SymbolInferShapePartial : \&AI::MXNetCAPI::SymbolInferShape;
my ($arg_shapes, $out_shapes, $aux_shapes, $complete) = check_call(
$infer_func->(
$self->handle,
scalar(@{ $indptr }) - 1,
$keys,
$indptr,
$sdata,
)
);
if($complete)
{
return $arg_shapes, $out_shapes, $aux_shapes;
}
else
{
return (undef, undef, undef);
}
}
=head2 debug_str
The debug string.
Returns
-------
debug_str : string
Debug string of the symbol.
=cut
method debug_str()
{
return scalar(check_call(AI::MXNetCAPI::SymbolPrint($self->handle)));
}
=head2 save
Save the symbol into a file.
You can also use Storable to do the job if you only work with Perl.
The advantage of save/load is that the file format is language agnostic:
a file written with save can be loaded by the other language bindings of MXNet.
You can also save to and load from cloud storage (S3, HDFS) directly.
Parameters
----------
fname : str
The name of the file, for example:
- s3://my-bucket/path/my-s3-symbol
- hdfs://my-bucket/path/my-hdfs-symbol
- /path-to/my-local-symbol
See Also
--------
load : Used to load symbol from file.
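The Storable route relies on the STORABLE_freeze/STORABLE_thaw hooks defined on this class. Those hooks serialize the symbol with tojson and rebuild it from that JSON. Below is a minimal sketch of the same hook pattern using only core Storable. `Toy::Graph` is a hypothetical stand-in that stores a JSON string instead of a native handle.

```perl
use strict;
use warnings;
use Storable qw(freeze thaw);

package Toy::Graph;

sub new {
    my ($class, %args) = @_;
    return bless { json => $args{json} }, $class;
}

sub tojson { return $_[0]{json} }

# Storable calls this when freezing; the returned string is the serialized form.
sub STORABLE_freeze {
    my ($self, $cloning) = @_;
    return $self->tojson;
}

# Storable calls this on a freshly blessed, empty object when thawing.
sub STORABLE_thaw {
    my ($self, $cloning, $json) = @_;
    $self->{json} = $json;
    return;
}

package main;

my $g    = Toy::Graph->new(json => '{"nodes":[]}');
my $copy = thaw(freeze($g));
print ref($copy), ' ', $copy->tojson, "\n";    # prints Toy::Graph {"nodes":[]}
```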
=cut
method save(Str $fname)
{
check_call(AI::MXNetCAPI::SymbolSaveToFile($self->handle, $fname));
}
=head2 tojson
Save the symbol into a JSON string.
See Also
--------
load_json : Used to load symbol from JSON string.
=cut
method tojson()
{
return scalar(check_call(AI::MXNetCAPI::SymbolSaveToJSON($self->handle)));
}
method _get_ndarray_inputs(
Str $arg_key,
HashRef[AI::MXNet::NDArray]|ArrayRef[AI::MXNet::NDArray] $args,
ArrayRef[Str] $arg_names,
Bool $allow_missing=0
)
{
my ($arg_handles, $arg_arrays) = ([], []);
if(ref $args eq 'ARRAY')
{
confess("Length of $arg_key does not match number of arguments")
unless @$args == @$arg_names;
@{ $arg_handles } = map { $_->handle } @{ $args };
$arg_arrays = $args;
}
else
{
my %tmp = ((map { $_ => undef } @$arg_names), %$args);
if(not $allow_missing and grep { not defined } values %tmp)
{
my ($missing) = grep { not defined $tmp{ $_ } } (keys %tmp);
confess("key $missing is missing in $arg_key");
}
for my $name (@$arg_names)
{
push @$arg_handles, defined($tmp{ $name }) ? $tmp{ $name }->handle : undef;
push @$arg_arrays, defined($tmp{ $name }) ? $tmp{ $name } : undef;
}
}
return ($arg_handles, $arg_arrays);
}
=head2 simple_bind
Bind the current symbol to get an executor, allocating all the NDArrays needed.
Allows specifying data types.
The caller provides the shapes (and optionally the types) of the inputs, and the
function automatically allocates the NDArrays for the arguments, their gradients
(where requested) and the auxiliary states.
Parameters
----------
:$ctx : AI::MXNet::Context
The device context the generated executor runs on.
:$grad_req : string
{'write', 'add', 'null'}, or array ref of str or hash ref of str to str, optional
Specifies how the gradient should be written to the args_grad.
- 'write' means the gradient is written to the specified args_grad NDArray on every backward pass.
- 'add' means the gradient is added to the specified NDArray on every backward pass.
- 'null' means no action is taken; the gradient may not be calculated.
:$type_dict : hash ref of str->Dtype
Input type map, name->dtype
:$group2ctx : hash ref of string to AI::MXNet::Context
The mapping of the ctx_group attribute to the context assignment.
:$shapes : hash ref of str->Shape
Input shape map, name->shape
:$shared_arg_names : Maybe[ArrayRef[Str]]
The argument names whose 'NDArray' of shared_exec can be reused for initializing
the current executor.
:$shared_exec : Maybe[AI::MXNet::Executor]
The executor whose arg_arrays, grad_arrays, and aux_arrays can be
reused for initializing the current executor.
:$shared_buffer : Maybe[HashRef[AI::MXNet::NDArray]]
The hash ref mapping argument names to the `NDArray` that can be reused for initializing
the current executor. This buffer will be checked for reuse if one argument name
of the current executor is not found in `shared_arg_names`. On return, the hash ref
is updated with the arrays reported back by the executor.
Returns
-------
$executor : AI::MXNet::Executor
The generated Executor
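simple_bind accepts :$grad_req as a string, an array ref, or a hash ref. It flattens each form into the parallel name/type lists, plus a list length, that it passes to ExecutorSimpleBind; a single string leaves the length at 0. Below is a minimal pure-Perl sketch of that normalization. `normalize_grad_req` is hypothetical and uses `die` in place of assert. It also sorts hash keys for deterministic output, whereas the method iterates with `each`.

```perl
use strict;
use warnings;

# Hypothetical helper mirroring how simple_bind flattens grad_req.
# Returns (list_len, \@names, \@types).
sub normalize_grad_req {
    my ($grad_req) = @_;
    my (@names, @types);
    my $list_len = 0;
    if (not ref $grad_req) {
        # A single string: one type, no names, length left at 0.
        push @types, $grad_req;
    }
    elsif (ref $grad_req eq 'ARRAY') {
        die "grad_req cannot be an empty list\n" unless @$grad_req;
        @types    = @$grad_req;
        $list_len = @types;
    }
    elsif (ref $grad_req eq 'HASH') {
        die "grad_req cannot be an empty hash\n" unless %$grad_req;
        # Sorted here only so the sketch's output order is deterministic.
        for my $k (sort keys %$grad_req) {
            push @names, $k;
            push @types, $grad_req->{$k};
        }
        $list_len = @types;
    }
    return ($list_len, \@names, \@types);
}

my ($len, $names, $types) = normalize_grad_req({ fc1_weight => 'write', data => 'null' });
print "$len: @$names / @$types\n";    # prints 2: data fc1_weight / null write
```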
=cut
method simple_bind(
AI::MXNet::Context :$ctx=AI::MXNet::Context->current_ctx,
GradReq|ArrayRef[GradReq]|HashRef[GradReq] :$grad_req='write',
Maybe[HashRef[Shape]] :$shapes=,
Maybe[HashRef[Dtype]] :$type_dict=,
Maybe[HashRef[AI::MXNet::Context]] :$group2ctx=,
Maybe[ArrayRef[Str]] :$shared_arg_names=,
Maybe[AI::MXNet::Executor] :$shared_exec=,
Maybe[HashRef[AI::MXNet::NDArray]] :$shared_buffer=
)
{
my $num_provided_arg_types;
my @provided_arg_type_names;
my @provided_arg_type_data;
if(defined $type_dict)
{
while(my ($k, $v) = each %{ $type_dict })
{
push @provided_arg_type_names, $k;
push @provided_arg_type_data, DTYPE_STR_TO_MX->{$v};
}
$num_provided_arg_types = @provided_arg_type_names;
}
my @provided_arg_shape_data;
# argument shape index in sdata,
# e.g. [sdata[indptr[0]], sdata[indptr[1]]) is the shape of the first arg
my @provided_arg_shape_idx = (0);
my @provided_arg_shape_names;
while(my ($k, $v) = each %{ $shapes//{} })
{
push @provided_arg_shape_names, $k;
push @provided_arg_shape_data, @{ $v };
push @provided_arg_shape_idx, scalar(@provided_arg_shape_data);
}
$num_provided_arg_types = @provided_arg_type_names;
my $provided_req_type_list_len = 0;
my @provided_grad_req_types;
my @provided_grad_req_names;
if(defined $grad_req)
{
if(not ref $grad_req)
{
push @provided_grad_req_types, $grad_req;
}
elsif(ref $grad_req eq 'ARRAY')
{
assert((@{ $grad_req } != 0), 'grad_req in simple_bind cannot be an empty list');
@provided_grad_req_types = @{ $grad_req };
$provided_req_type_list_len = @provided_grad_req_types;
}
elsif(ref $grad_req eq 'HASH')
{
assert((keys %{ $grad_req } != 0), 'grad_req in simple_bind cannot be an empty hash');
while(my ($k, $v) = each %{ $grad_req })
{
push @provided_grad_req_names, $k;
push @provided_grad_req_types, $v;
}
$provided_req_type_list_len = @provided_grad_req_types;
}
}
my $num_ctx_map_keys = 0;
my @ctx_map_keys;
my @ctx_map_dev_types;
my @ctx_map_dev_ids;
if(defined $group2ctx)
{
while(my ($k, $v) = each %{ $group2ctx })
{
push @ctx_map_keys, $k;
push @ctx_map_dev_types, $v->device_type_id;
push @ctx_map_dev_ids, $v->device_id;
}
$num_ctx_map_keys = @ctx_map_keys;
}
my @shared_arg_name_list;
if(defined $shared_arg_names)
{
@shared_arg_name_list = @{ $shared_arg_names };
}
my %shared_data;
if(defined $shared_buffer)
{
while(my ($k, $v) = each %{ $shared_buffer })
{
$shared_data{$k} = $v->handle;
}
}
my $shared_exec_handle = defined $shared_exec ? $shared_exec->handle : undef;
my (
$updated_shared_data,
$in_arg_handles,
$arg_grad_handles,
$aux_state_handles,
$exe_handle
);
eval {
($updated_shared_data, $in_arg_handles, $arg_grad_handles, $aux_state_handles, $exe_handle)
=
check_call(
AI::MXNetCAPI::ExecutorSimpleBind(
$self->handle,
$ctx->device_type_id,
$ctx->device_id,
$num_ctx_map_keys,
\@ctx_map_keys,
\@ctx_map_dev_types,
\@ctx_map_dev_ids,
$provided_req_type_list_len,
\@provided_grad_req_names,
\@provided_grad_req_types,
scalar(@provided_arg_shape_names),
\@provided_arg_shape_names,
\@provided_arg_shape_data,
\@provided_arg_shape_idx,
$num_provided_arg_types,
\@provided_arg_type_names,
\@provided_arg_type_data,
scalar(@shared_arg_name_list),
\@shared_arg_name_list,
defined $shared_buffer ? \%shared_data : undef,
$shared_exec_handle
)
);
};
if($@)
{
confess(
"simple_bind failed: Error: $@; Arguments: ".
Data::Dumper->new(
[$shapes//{}]
)->Purity(1)->Deepcopy(1)->Terse(1)->Dump
);
}
if(defined $shared_buffer)
{
while(my ($k, $v) = each %{ $updated_shared_data })
{
$shared_buffer->{$k} = AI::MXNet::NDArray->new(handle => $v);
}
}
my @arg_arrays = map { AI::MXNet::NDArray->new(handle => $_) } @{ $in_arg_handles };
my @grad_arrays = map { defined $_ ? AI::MXNet::NDArray->new(handle => $_) : undef } @{ $arg_grad_handles };
my @aux_arrays = map { AI::MXNet::NDArray->new(handle => $_) } @{ $aux_state_handles };
my $executor = AI::MXNet::Executor->new(
handle => $exe_handle,
symbol => $self,
ctx => $ctx,
grad_req => $grad_req,
group2ctx => $group2ctx
);
$executor->arg_arrays(\@arg_arrays);
$executor->grad_arrays(\@grad_arrays);
$executor->aux_arrays(\@aux_arrays);
return $executor;
}
=head2 bind
Bind current symbol to get an executor.
Parameters
----------
:$ctx : AI::MXNet::Context
The device context the generated executor runs on.
:$args : HashRef[AI::MXNet::NDArray]|ArrayRef[AI::MXNet::NDArray]
Input arguments to the symbol.
- If it is an array ref of NDArray, the positions follow the order of list_arguments().
- If it is a hash ref of str to NDArray, it maps the names of arguments
to the corresponding NDArrays.
- In either case, all the arguments must be provided.
:$args_grad : Maybe[HashRef[AI::MXNet::NDArray]|ArrayRef[AI::MXNet::NDArray]]
When specified, args_grad provides the NDArrays that hold
the gradient values computed in the backward pass.
- If it is an array ref of NDArray, the positions follow the order of list_arguments().
- If it is a hash ref of str to NDArray, it maps the names of arguments
to the corresponding NDArrays.
- When it is a hash ref, users only need to provide entries
for the arguments whose gradients are needed.
Only the gradients of the specified arguments are calculated.
:$grad_req : {'write', 'add', 'null'}, or array ref of str or hash ref of str to str, optional
Specifies how the gradient should be written to the args_grad.
- 'write' means the gradient is written to the specified args_grad NDArray on every backward pass.
- 'add' means the gradient is added to the specified NDArray on every backward pass.
- 'null' means no action is taken; the gradient may not be calculated.
:$aux_states : array ref of NDArray, or hash ref of str to NDArray, optional
Input auxiliary states to the symbol; only needed when
list_auxiliary_states() is not empty.
- If it is an array ref of NDArray, the positions follow the order of list_auxiliary_states().
- If it is a hash ref of str to NDArray, it maps the names of auxiliary states
to the corresponding NDArrays.
- In either case, all the auxiliary states must be provided.
:$group2ctx : hash ref of string to AI::MXNet::Context
The mapping of the ctx_group attribute to the context assignment.
:$shared_exec : AI::MXNet::Executor
Executor to share memory with. This is intended for runtime reshaping, variable length
sequences, etc. The returned executor shares state with shared_exec, and should not be
used in parallel with it.
Returns
-------
$executor : AI::MXNet::Executor
The generated Executor
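Array-ref and hash-ref inputs are resolved against the argument names by the private _get_ndarray_inputs method. An array ref must have exactly one entry per name. A hash ref is reordered by name, and missing names are an error unless missing entries are allowed (presumably the mode used for :$args_grad, given the description above). Below is a minimal sketch with plain strings standing in for NDArray objects; `resolve_inputs` is a hypothetical helper.

```perl
use strict;
use warnings;
use Carp qw(confess);

# Hypothetical helper mirroring _get_ndarray_inputs, with plain values
# standing in for NDArray objects.
sub resolve_inputs {
    my ($arg_key, $args, $arg_names, $allow_missing) = @_;
    if (ref $args eq 'ARRAY') {
        confess("Length of $arg_key does not match number of arguments")
            unless @$args == @$arg_names;
        return [@$args];
    }
    # Start with every expected name set to undef, then overlay what was given.
    my %tmp = ((map { $_ => undef } @$arg_names), %$args);
    if (not $allow_missing) {
        my ($missing) = grep { not defined $tmp{$_} } @$arg_names;
        confess("key $missing is missing in $arg_key") if defined $missing;
    }
    # Reorder by the symbol's argument order; missing entries stay undef.
    return [ map { $tmp{$_} } @$arg_names ];
}

my $names   = ['data', 'fc1_weight', 'fc1_bias'];
my $by_name = resolve_inputs('args', { fc1_bias => 'B', data => 'D', fc1_weight => 'W' }, $names);
print "@$by_name\n";    # prints D W B
```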
Notes
-----
Auxiliary states are special states of symbols that do not correspond to an argument
and have no gradient, but are still useful for specific operations.
A common example of auxiliary state is the moving_mean and moving_variance in BatchNorm.