Skip to content
This repository was archived by the owner on Apr 24, 2024. It is now read-only.

Commit 29799cb

Browse files
committed
add support for TensorMap as argument for bias
1 parent 177b4ef commit 29799cb

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

src/equisolve/nn/module_tensor.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from metatensor import Labels, LabelsEntry, TensorBlock, TensorMap
88

99
from copy import deepcopy
10-
from typing import List, Optional
10+
from typing import List, Optional, Union
1111

1212
import torch
1313
from torch.nn import Module, ModuleDict
@@ -167,23 +167,36 @@ class Linear(ModuleTensorMap):
167167
properties, the labels of the properties cannot be persevered.
168168
169169
:param bias:
170-
See :py:class:`torch.nn.Linear`
170+
See :py:class:`torch.nn.Linear` for bool as input. For each TensorMap key the
171+
bias can be also individually tuend by using a TensorMap with one value for the
172+
bool.
171173
"""
172174

173175
def __init__(
174176
self,
175177
in_tensor: TensorMap,
176178
out_tensor: TensorMap,
177-
bias: bool = True,
179+
bias: Union[bool, TensorMap] = True,
178180
):
181+
if isinstance(bias, bool):
182+
blocks = [
183+
TensorBlock(
184+
values=torch.tensor(bias).reshape(1, 1),
185+
samples=Labels.range("_", 1),
186+
components=[],
187+
properties=Labels.range("_", 1),
188+
)
189+
for _ in in_tensor.keys
190+
]
191+
bias = TensorMap(keys=in_tensor.keys, blocks=blocks)
179192
module_map = ModuleDict()
180193
for key, in_block in in_tensor.items():
181194
module_key = ModuleTensorMap.module_key(key)
182195
out_block = out_tensor.block(key)
183196
module = torch.nn.Linear(
184197
len(in_block.properties),
185198
len(out_block.properties),
186-
bias,
199+
bias.block(key).values.flatten()[0],
187200
in_block.values.device,
188201
in_block.values.dtype,
189202
)

0 commit comments

Comments
 (0)