@@ -32,7 +32,6 @@
from pytensor.tensor.rewriting.basic import (
    register_canonicalize,
    register_specialize,
-    register_stabilize,
    register_useless,
    topo_constant_folding,
)
@@ -749,51 +748,43 @@ def apply(self, fgraph):
pytensor.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10)


-def local_reshape_chain(op):
-    @node_rewriter([op])
-    def f(fgraph, node):
-        """
-        Reshape(Reshape(shape1),shape2) -> Reshape(shape2)
-
-        """
-        if not check_chain(node, op, op):
-            return False
-
-        # TODO: this can permit a failing program to run by eliminating
-        # the lower reshape
-        rval = node.op(node.inputs[0].owner.inputs[0], node.inputs[1])
-
-        # Copy over stacktrace from previous output node, as any error
-        # in new computational graph would have been caused by last op
-        # in the old computational graph.
-        copy_stack_trace(node.outputs, rval)
-
-        # It might happen that the desired output of this node has a
-        # broadcastable pattern that does not match that of 'rval'. This is
-        # when originally, we were able to figure out that one of the
-        # dimensions of the reshape is one, but some other transformation
-        # replaced the shape by one for which this cannot be guessed.
-        # We should try to figure out why we lost the information about this
-        # constant value... but in the meantime, better not apply this
-        # rewrite.
-        if rval.type.ndim == node.outputs[0].type.ndim and all(
-            s1 == s2
-            for s1, s2 in zip(rval.type.shape, node.outputs[0].type.shape)
-            if s1 == 1 or s2 == 1
-        ):
-            return [rval]
-        else:
-            return False
-
-    return f
-
-
-register_canonicalize(local_reshape_chain(Reshape), name="local_reshape_chain")
+@register_canonicalize("shape_unsafe")
+@register_specialize("shape_unsafe")
+@node_rewriter([Reshape])
+def local_reshape_chain(fgraph, node):
+    """
+    Reshape(Reshape(x, shape1), shape2) -> Reshape(x, shape2)
+
+    """
+    if not check_chain(node, Reshape, Reshape):
+        return False
+
+    rval = node.op(node.inputs[0].owner.inputs[0], node.inputs[1])
+
+    # Copy over stacktrace from previous output node, as any error
+    # in new computational graph would have been caused by last op
+    # in the old computational graph.
+    copy_stack_trace(node.outputs, rval)
+
+    # It might happen that the desired output of this node has a
+    # broadcastable pattern that does not match that of 'rval'. This is
+    # when originally, we were able to figure out that one of the
+    # dimensions of the reshape is one, but some other transformation
+    # replaced the shape by one for which this cannot be guessed.
+    # We should try to figure out why we lost the information about this
+    # constant value... but in the meantime, better not apply this
+    # rewrite.
+    if rval.type.ndim == node.outputs[0].type.ndim and all(
+        s1 == s2
+        for s1, s2 in zip(rval.type.shape, node.outputs[0].type.shape)
+        if s1 == 1 or s2 == 1
+    ):
+        return [rval]


-@register_useless
-@register_canonicalize
-@register_stabilize
+@register_useless("shape_unsafe")
+@register_canonicalize("shape_unsafe")
+@register_specialize("shape_unsafe")
@node_rewriter([Reshape])
def local_useless_reshape(fgraph, node):
    """Remove two kinds of useless `Reshape`.
@@ -802,24 +793,17 @@ def local_useless_reshape(fgraph, node):
    - Remove `Reshape` when reshaping to the shape of the input.

    """
-    inp = node.inputs[0]
-    output = node.outputs[0]
-    output_shape = node.inputs[1]
+    inp, output_shape = node.inputs
+    [output] = node.outputs

    if inp.type.ndim != output.type.ndim:
        return False

    # Simple case: both input and output have a single dimension.
-    # TODO FIXME XXX: This could hide errors if the user provides inconsistent
-    # shapes.
    if (
        inp.type.ndim == 1
        and output.type.ndim == 1
-        and all(
-            s1 == s2
-            for s1, s2 in zip(inp.type.shape, output.type.shape)
-            if s1 == 1 or s2 == 1
-        )
+        and inp.type.broadcastable == output.type.broadcastable
    ):
        return [inp]

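Note: this 1-d fast path drops the Reshape as soon as the static broadcast patterns agree; it never proves that the requested length equals the input's length, which is precisely what the "shape_unsafe" tag advertises. A minimal sketch of a graph it accepts:

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")  # 1-d, broadcastable == (False,)
# The output is also 1-d and non-broadcastable, so the Reshape is
# removed even though the length 7 is never checked against x.
y = x.reshape((7,))
f = pytensor.function([x], y)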
@@ -832,8 +816,15 @@ def local_useless_reshape(fgraph, node):

    # Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for
    # broadcastable and constant dimensions
-    if output_shape.owner and isinstance(output_shape.owner.op, MakeVector):
-        output_shape_is = output_shape.owner.inputs
+    if isinstance(output_shape, Constant) or (
+        output_shape.owner and isinstance(output_shape.owner.op, MakeVector)
+    ):
+        if isinstance(output_shape, Constant):
+            output_shape_is = [
+                as_tensor_variable(dim, ndim=0) for dim in output_shape.data
+            ]
+        else:
+            output_shape_is = output_shape.owner.inputs

        shape_feature = getattr(fgraph, "shape_feature", None)

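Note: the widened condition now covers the two common forms a target shape takes in the graph. A sketch of both (illustrative variable names):

import pytensor.tensor as pt

x = pt.tensor3("x")

# Symbolic entries: the shape input of Reshape is a MakeVector node.
y_symbolic = x.reshape((x.shape[0], x.shape[1], x.shape[2]))

# Literal entries: the shape input is a single TensorConstant, which the
# new Constant branch unpacks into 0-d constants before matching.
y_constant = x.reshape((2, 3, 4))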
@@ -865,9 +856,9 @@ def local_useless_reshape(fgraph, node):
                        shape_match[dim] = True
                        continue

-            # Match 1 if input.type.shape[dim] == 1
+            # Match constant if input.type.shape[dim] == constant
            cst_outshp_i = extract_constant(outshp_i, only_process_constants=1)
-            if inp.type.shape[dim] == 1 and cst_outshp_i == 1:
+            if inp.type.shape[dim] == cst_outshp_i:
                shape_match[dim] = True
                continue

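Note: this generalizes the old rule from "a static 1 matches a literal 1" to "any statically known dimension matches its literal". A sketch, assuming the pt.tensor helper that accepts a static shape:

import pytensor.tensor as pt

x = pt.tensor("x", shape=(5, None))  # dim 0 statically 5, dim 1 unknown

# dim 0 should now match the literal 5 directly; dim 1 can match through
# the shape feature, making the whole Reshape removable.
y = x.reshape((5, x.shape[1]))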
@@ -881,17 +872,18 @@ def local_useless_reshape(fgraph, node):
            if shape_feature:
                inpshp_i = shape_feature.get_shape(inp, dim)
                if inpshp_i == outshp_i or (
-                    extract_constant(inpshp_i, only_process_constants=1)
-                    == extract_constant(outshp_i, only_process_constants=1)
+                    extract_constant(inpshp_i, only_process_constants=True)
+                    == extract_constant(outshp_i, only_process_constants=True)
                ):
                    shape_match[dim] = True
                    continue

-        if all(shape_match) and nb_m1 <= 1:
+        if nb_m1 <= 1 and all(shape_match):
+            return [inp]
+
+        if (nb_m1 == 0) and (shape_match.count(False) == output.type.ndim - 1):
            return [inp]

-        # TODO later: if all the shapes except one match, we may want to
-        # consider it useless as well, like we do in the 1-dim case.
    return False

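Note: the new second return leans on Reshape's size invariant: with no -1 entries the element count is preserved, so for a matrix a single provably matching dimension pins down the other one, which is the "all the shapes except one match" case the deleted TODO described (assuming the graph's shapes are consistent, hence "shape_unsafe"). A sketch:

import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
# dim 0 matches x.shape[0] via the shape feature; dim 1 (the literal 7)
# is never verified, but the size invariant forces it, so the Reshape
# should be dropped.
y = x.reshape((x.shape[0], 7))
f = pytensor.function([x], y)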
@@ -910,9 +902,8 @@ def local_reshape_to_dimshuffle(fgraph, node):
          -> DimShuffle{x,0,x,1,x,x}(Reshape(x, (m, n)))
    """
    op = node.op
-    inp = node.inputs[0]
-    output = node.outputs[0]
-    output_shape = node.inputs[1]
+    inp, output_shape = node.inputs
+    [output] = node.outputs

    dimshuffle_new_order = []
    new_output_shape = []
@@ -944,7 +935,7 @@ def local_reshape_to_dimshuffle(fgraph, node):


@register_canonicalize
-@register_stabilize
+@register_specialize
@node_rewriter([Reshape])
def local_reshape_lift(fgraph, node):
    """
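Note: since the reshape-removal rewrites above are now tagged "shape_unsafe", graphs whose shape consistency cannot be assumed can opt out by tag. A sketch using Mode.excluding:

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = x.reshape((2, 3)).reshape((6,))

# Excluding the tag keeps the Reshape nodes (and their runtime checks).
safe_mode = pytensor.compile.mode.get_default_mode().excluding("shape_unsafe")
f = pytensor.function([x], y, mode=safe_mode)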