|
7 | 7 | from metatensor import Labels, LabelsEntry, TensorBlock, TensorMap |
8 | 8 |
|
9 | 9 | from copy import deepcopy |
10 | | -from typing import List, Optional |
| 10 | +from typing import List, Optional, Union |
11 | 11 |
|
12 | 12 | import torch |
13 | 13 | from torch.nn import Module, ModuleDict |
@@ -167,23 +167,36 @@ class Linear(ModuleTensorMap): |
167 | 167 | properties, the labels of the properties cannot be preserved. |
168 | 168 |
|
169 | 169 | :param bias: |
170 | | - See :py:class:`torch.nn.Linear` |
| 170 | + See :py:class:`torch.nn.Linear` for bool as input. For each TensorMap key the |
| 171 | + bias can also be individually tuned by using a TensorMap with one value for the |
| 172 | + bool. |
171 | 173 | """ |
172 | 174 |
|
173 | 175 | def __init__( |
174 | 176 | self, |
175 | 177 | in_tensor: TensorMap, |
176 | 178 | out_tensor: TensorMap, |
177 | | - bias: bool = True, |
| 179 | + bias: Union[bool, TensorMap] = True, |
178 | 180 | ): |
| 181 | + if isinstance(bias, bool): |
| 182 | + blocks = [ |
| 183 | + TensorBlock( |
| 184 | + values=torch.tensor(bias).reshape(1, 1), |
| 185 | + samples=Labels.range("_", 1), |
| 186 | + components=[], |
| 187 | + properties=Labels.range("_", 1), |
| 188 | + ) |
| 189 | + for _ in in_tensor.keys |
| 190 | + ] |
| 191 | + bias = TensorMap(keys=in_tensor.keys, blocks=blocks) |
179 | 192 | module_map = ModuleDict() |
180 | 193 | for key, in_block in in_tensor.items(): |
181 | 194 | module_key = ModuleTensorMap.module_key(key) |
182 | 195 | out_block = out_tensor.block(key) |
183 | 196 | module = torch.nn.Linear( |
184 | 197 | len(in_block.properties), |
185 | 198 | len(out_block.properties), |
186 | | - bias, |
| 199 | + bias.block(key).values.flatten()[0], |
187 | 200 | in_block.values.device, |
188 | 201 | in_block.values.dtype, |
189 | 202 | ) |
|
0 commit comments