@@ -41,6 +41,7 @@ def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
         super().__init__(name, scheme, seed)
         self.weights = ParameterizedDefaultDict(self._create_weight)
         self.inverses = ParameterizedDefaultDict(self._create_inverse)
+        self._shared_tensors_device = None
 
     def create_transform(self, module: Module, args: TransformArgs):
         """
@@ -52,19 +53,34 @@ def create_transform(self, module: Module, args: TransformArgs):
         """
         assert hasattr(module, "weight")
         size = get_transform_size(module, args.location, self.scheme.head_dim)
-        dtype = self.scheme.precision
         device = get_offloaded_device(module)
 
-        weight = self.weights[size, dtype, device]
+        factory_kwargs = {"device": device}
+        weight = self.weights.get(size, factory_kwargs=factory_kwargs)
         if args.inverse:
             weight = self.inverses[weight]
 
         return RandomMatrixTransform(weight, self.scheme, args, type(module))
 
-    def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
-        # TODO: verify that weight is invertible (has non-zero determinant)
+    def _create_weight(self, size: int, device: device) -> Parameter:
+        # check that the shared-tensors device is consistent
+        if self._shared_tensors_device is None:
+            self._shared_tensors_device = device
+
+        if device != self._shared_tensors_device:
+            raise NotImplementedError(
+                "Creating multi-GPU transform weights is not supported at this "
+                "time due to the limitations of sharing tensors across GPUs"
+                # in the future, tensors could be shared within each GPU
+                # and all-reduced during updates and compression
+            )
+
+        # TODO: construct the weight such that it is invertible (non-zero determinant)
         data = torch.rand(
-            (size, size), generator=self.generator, dtype=dtype, device=device
+            (size, size),
+            generator=self.generator,
+            dtype=self.scheme.precision,
+            device=device,
         )
         return Parameter(data, requires_grad=self.scheme.requires_grad)
 
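Note on the cache-key change: since `device` is no longer part of the key passed to `self.weights`, every module requesting a transform of a given size now shares a single weight tensor, which is why `_create_weight` pins the first device it sees and rejects any other. Below is a standalone sketch of that pattern; `DeviceConsistentCache` is a hypothetical stand-in for illustration, not the library's `ParameterizedDefaultDict`:

```python
from typing import Callable, Dict, Optional

import torch


class DeviceConsistentCache:
    # Hypothetical illustration: missing keys are built by a factory, the
    # first build pins the shared device, and any later request targeting a
    # different device is rejected, mirroring the NotImplementedError above.
    def __init__(self, factory: Callable[..., torch.Tensor]):
        self._factory = factory
        self._cache: Dict[int, torch.Tensor] = {}
        self._shared_device: Optional[torch.device] = None

    def get(self, size: int, device: torch.device) -> torch.Tensor:
        device = torch.device(device)
        if self._shared_device is None:
            self._shared_device = device
        if device != self._shared_device:
            raise NotImplementedError(
                "all shared tensors must live on a single device"
            )
        if size not in self._cache:
            self._cache[size] = self._factory(size, device=device)
        return self._cache[size]


cache = DeviceConsistentCache(
    lambda size, device: torch.rand((size, size), device=device)
)
a = cache.get(64, "cpu")
b = cache.get(64, "cpu")
assert a is b  # same size -> the cached tensor is shared, not rebuilt
```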
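On the remaining TODO: a uniform random matrix is invertible with probability 1, but nothing enforces it, and near-singular draws can make the inverse numerically unstable. One possible construction (a sketch, not the library's implementation; `random_orthogonal` is a hypothetical helper) is to take the Q factor of a QR decomposition, which is orthogonal and therefore exactly invertible:

```python
from typing import Optional

import torch


def random_orthogonal(
    size: int,
    generator: Optional[torch.Generator] = None,
    dtype: torch.dtype = torch.float32,
    device: str = "cpu",
) -> torch.Tensor:
    # Hypothetical helper: the Q factor of the QR decomposition of a Gaussian
    # matrix is orthogonal, hence always invertible (determinant is +/-1),
    # and its inverse is simply its transpose.
    gaussian = torch.randn(
        (size, size), generator=generator, dtype=dtype, device=device
    )
    q, _ = torch.linalg.qr(gaussian)
    return q


weight = random_orthogonal(64)
assert torch.allclose(weight @ weight.T, torch.eye(64), atol=1e-5)
```

An orthogonal weight would also make the `inverses` cache trivial, since the inverse is the transpose rather than a numerically computed matrix inverse.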