@@ -16,16 +16,16 @@ def _hijack_detach_and_clone(ptensor: torch.Tensor) -> torch.Tensor:
 
     def new_detach(self):
         t_ = self._unpad_detach()
-        t_.padding_dim = self.padding_dim
-        t_.origin_length = self.origin_length
-        t_.current_length = self.current_length
+        t_._padding_dim = self._padding_dim
+        t_._origin_length = self._origin_length
+        t_._current_length = self._current_length
         return t_
 
     def new_clone(self, *args, **kwargs):
         t_ = self._unpad_clone(*args, **kwargs)
-        t_.padding_dim = self.padding_dim
-        t_.origin_length = self.origin_length
-        t_.current_length = self.current_length
+        t_._padding_dim = self._padding_dim
+        t_._origin_length = self._origin_length
+        t_._current_length = self._current_length
         return t_
 
     # bind the new methods to the tensor
@@ -63,7 +63,7 @@ def is_padded_tensor(tensor: torch.Tensor) -> bool:
     Returns:
         bool: Whether the given tensor is a padding tensor.
     """
-    return hasattr(tensor, "padding_dim")
+    return hasattr(tensor, "_padding_dim")
 
 
 def to_padded_tensor(
@@ -89,9 +89,9 @@ def to_padded_tensor(
     )
     tensor.data = torch.cat((tensor.data, padding_data), dim=padding_dim).contiguous()
 
-    setattr(tensor, "padding_dim", padding_dim)
-    setattr(tensor, "origin_length", origin_length)
-    setattr(tensor, "current_length", current_length)
+    tensor._padding_dim = padding_dim
+    tensor._origin_length = origin_length
+    tensor._current_length = current_length
 
     _hijack_detach_and_clone(tensor)
 
@@ -103,25 +103,25 @@ def to_unpadded_tensor(ptensor: torch.Tensor):
         return ptensor
 
     unpad_slices = [slice(None)] * ptensor.dim()
-    unpad_slices[ptensor.padding_dim] = slice(None, ptensor.origin_length)
+    unpad_slices[ptensor._padding_dim] = slice(None, ptensor._origin_length)
     ptensor.data = ptensor.data[tuple(unpad_slices)]
 
-    delattr(ptensor, "padding_dim")
-    delattr(ptensor, "origin_length")
-    delattr(ptensor, "current_length")
+    delattr(ptensor, "_padding_dim")
+    delattr(ptensor, "_origin_length")
+    delattr(ptensor, "_current_length")
 
     _hijack_back_detach_and_clone(ptensor)
 
     return ptensor
 
 
-def init_as_ptensor(tensor: torch.Tensor, current_length: int, origin_length: int, padding_dim: int):
+def init_as_padded_tensor(tensor: torch.Tensor, current_length: int, origin_length: int, padding_dim: int):
     if is_padded_tensor(tensor):
         return tensor
 
-    setattr(tensor, "padding_dim", padding_dim)
-    setattr(tensor, "origin_length", origin_length)
-    setattr(tensor, "current_length", current_length)
+    tensor._padding_dim = padding_dim
+    tensor._origin_length = origin_length
+    tensor._current_length = current_length
 
     _hijack_detach_and_clone(tensor)
 
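
For reference, a minimal usage sketch of the renamed helpers after this change. It assumes is_padded_tensor, to_padded_tensor, and to_unpadded_tensor are importable from the module patched above; the import path and the exact signature of to_padded_tensor are not shown in this diff, so the keyword arguments below are assumptions inferred from the attribute names.

import torch

# Hypothetical import; the actual module path is not part of this diff.
# from padded_tensor_module import is_padded_tensor, to_padded_tensor, to_unpadded_tensor

t = torch.randn(4, 8)

# Pad dim 0 from its original length (4) up to current_length (6); the helper
# now records the metadata in the underscore-prefixed attributes.
pt = to_padded_tensor(t, current_length=6, padding_dim=0)  # signature assumed
assert is_padded_tensor(pt)  # True: checks for the "_padding_dim" attribute
assert pt.shape[0] == 6

# detach()/clone() are hijacked, so the padding metadata survives both calls.
cloned = pt.clone()
assert is_padded_tensor(cloned)

# Removing the padding deletes the metadata and restores detach()/clone().
t_back = to_unpadded_tensor(pt)
assert not is_padded_tensor(t_back)
assert t_back.shape[0] == 4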