@@ -1731,11 +1731,12 @@ def from_tensor(cls, x):
17311731 return Data (name = str (x .op .name ), shape = (), batch_dim_axis = None , dtype = x .dtype .name , placeholder = x )
17321732
17331733 @classmethod
1734- def template_from_constant (cls , x , name , dtype = None , with_batch_dim = False ):
1734+ def template_from_constant (cls , x , name , dtype = None , shape = None , with_batch_dim = False ):
17351735 """
17361736 :param int|float|bool|numpy.ndarray x:
17371737 :param str name:
17381738 :param str|None dtype:
1739+ :param list[DimensionTag|int]|tuple[DimensionTag|int]|None shape: for verification, and defining dim tags
17391740 :param bool with_batch_dim:
17401741 :rtype: Data
17411742 """
@@ -1750,12 +1751,29 @@ def template_from_constant(cls, x, name, dtype=None, with_batch_dim=False):
17501751 elif isinstance (x , numpy .ndarray ):
17511752 dtype = str (x .dtype )
17521753 else :
1753- raise TypeError ("cannot handle value %r of type %r" % (x , type (x )))
1754- shape = x .shape if isinstance (x , numpy .ndarray ) else ()
1755- return Data (
1756- name = name ,
1757- shape = shape , batch_dim_axis = 0 if with_batch_dim else None , time_dim_axis = None ,
1758- dtype = dtype )
1754+ raise TypeError ("%r: cannot handle value %r of type %r" % (name , x , type (x )))
1755+ shape_ = x .shape if isinstance (x , numpy .ndarray ) else ()
1756+ if shape is not None :
1757+ assert len (shape ) == len (shape_ ), "%r: shape does not match in ndim, %r vs %r" % (name , shape , shape_ )
1758+ else :
1759+ shape = shape_
1760+ dim_tags = []
1761+ for i , (d , d_ ) in enumerate (zip (shape , shape_ )):
1762+ assert isinstance (d_ , int )
1763+ if isinstance (d , DimensionTag ):
1764+ assert d .dimension == d_
1765+ elif isinstance (d , int ):
1766+ assert d == d_
1767+ d = DimensionTag (
1768+ kind = DimensionTag .Types .Spatial if i < len (shape ) - 1 else DimensionTag .Types .Feature ,
1769+ description = "%s:static:%i" % (name , i ),
1770+ dimension = d )
1771+ else :
1772+ raise TypeError ("%r shape[%i] invalid type %r in shape %r" % (name , i , type (d ), shape ))
1773+ dim_tags .append (d )
1774+ if with_batch_dim :
1775+ dim_tags .insert (0 , BatchDim )
1776+ return Data (name = name , dim_tags = dim_tags , dtype = dtype )
17591777
17601778 def sanity_check (self , ignore_placeholder = False , assume_complete = True ):
17611779 """
0 commit comments