@@ -502,16 +502,8 @@ def hybridize(self, active=True, **kwargs):
502
502
----------
503
503
active : bool, default True
504
504
Whether to turn hybrid on or off.
505
- static_alloc : bool, default False
506
- Statically allocate memory to improve speed. Memory usage may increase.
507
- static_shape : bool, default False
508
- Optimize for invariant input shapes between iterations. Must also
509
- set static_alloc to True. Change of input shapes is still allowed
510
- but slower.
511
- forward_bulk_size : int, default 15
512
- Segment size of bulk execution during forward pass.
513
- backward_bulk_size : int, default 15
514
- Segment size of bulk execution during backward pass.
505
+ **kwargs : string
506
+ Additional flags for hybridized operator.
515
507
"""
516
508
for cld in self ._children .values ():
517
509
cld .hybridize (active , ** kwargs )
@@ -704,7 +696,7 @@ def __init__(self, prefix=None, params=None):
704
696
self ._out_format = None
705
697
self ._in_format = None
706
698
self ._active = False
707
- self ._flags = []
699
+ self ._flags = {}
708
700
709
701
def __setattr__ (self , name , value ):
710
702
"""Registers parameters."""
@@ -731,43 +723,39 @@ def _get_graph(self, *args):
731
723
return self ._cached_graph
732
724
733
725
def _build_cache(self, *args):
    """Build the ``CachedOp`` for this block and store it in ``self._cached_op``.

    The graph is obtained from ``self._get_graph(*args)``. Every name the
    graph lists as an input must be either one of the graph's data inputs
    or one of the parameters returned by ``self.collect_params()``. Data
    inputs and parameters that the graph never reads are reported with a
    warning but are otherwise tolerated.

    Parameters
    ----------
    *args
        Positional inputs of the block. They are forwarded to
        ``self._get_graph`` and, when parameter initialization was
        deferred, to ``self._deferred_infer_shape``.

    Raises
    ------
    AssertionError
        If the graph expects an input that is neither a data input nor a
        known parameter.
    """
    inputs, out = self._get_graph(*args)
    # Kept as a list: the position of each data input matters for the
    # "N-th input" warning below and the list is handed to CachedOp as-is.
    input_names = [i.name for i in inputs]

    params = self.collect_params()
    param_names = set(params.keys())
    expected_names = set(out.list_inputs())

    # One set lookup per expected name instead of scanning the input list.
    known_names = param_names.union(input_names)
    for name in expected_names:
        assert name in known_names, \
            "Unknown input to HybridBlock: %s" % name

    unused_inputs = ['%d-th' % i for i, name in enumerate(input_names)
                     if name not in expected_names]
    if unused_inputs:
        warnings.warn("The %s input to HybridBlock is not used by any "
                      "computation. Is this intended?" % ', '.join(unused_inputs),
                      stacklevel=4)

    unused_param_names = param_names - expected_names
    if unused_param_names:
        # Sorted so the warning text does not depend on set iteration order.
        warnings.warn("Parameter %s is not used by any computation. "
                      "Is this intended?" % ', '.join(sorted(unused_param_names)),
                      stacklevel=4)

    # Only parameters the graph actually consumes are handed to the CachedOp.
    used_params = {k: params[k] for k in sorted(param_names & expected_names)}

    def _collect_param_data():
        # Maps each used parameter name to the result of its list_data() call.
        return {k: v.list_data() for k, v in used_params.items()}

    try:
        param_dict = _collect_param_data()
    except DeferredInitializationError:
        # Some parameter is not initialized yet: infer shapes from the
        # actual inputs, finish initialization, then collect the data again.
        self._deferred_infer_shape(*args)
        for param in used_params.values():
            param._finish_deferred_init()
        param_dict = _collect_param_data()

    self._cached_op = ndarray.CachedOp(out, self._flags, input_names, param_dict)
771
759
772
760
def _deferred_infer_shape (self , * args ):
773
761
try :
@@ -783,19 +771,7 @@ def _call_cached_op(self, *args):
783
771
784
772
args , fmt = _flatten (args , "input" )
785
773
assert fmt == self ._in_format , "Invalid input format"
786
- try :
787
- cargs = [args [i ] if is_arg else i .data ()
788
- for is_arg , i in self ._cached_op_args ]
789
- except DeferredInitializationError :
790
- self ._deferred_infer_shape (* args )
791
- cargs = []
792
- for is_arg , i in self ._cached_op_args :
793
- if is_arg :
794
- cargs .append (args [i ])
795
- else :
796
- i ._finish_deferred_init ()
797
- cargs .append (i .data ())
798
- out = self ._cached_op (* cargs )
774
+ out = self ._cached_op (* args )
799
775
if isinstance (out , NDArray ):
800
776
out = [out ]
801
777
return _regroup (out , self ._out_format )[0 ]
@@ -816,7 +792,7 @@ def register_child(self, block, name=None):
816
792
817
793
def hybridize (self , active = True , ** kwargs ):
818
794
self ._active = active
819
- self ._flags = list ( kwargs .items () )
795
+ self ._flags = kwargs .items ()
820
796
self ._clear_cached_op ()
821
797
if active and self ._forward_hooks or self ._forward_pre_hooks :
822
798
warnings .warn ('"{}" is being hybridized while still having forward hook/pre-hook. '
0 commit comments