Skip to content
This repository was archived by the owner on Apr 24, 2024. It is now read-only.

Commit 56f856e

Browse files
committed
drafting export_torchscript
1 parent 29799cb commit 56f856e

File tree

3 files changed

+68
-20
lines changed

3 files changed

+68
-20
lines changed

src/equisolve/nn/module_tensor.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,36 @@ def from_module(
242242
module = torch.nn.Linear(in_features, out_features, bias, device, dtype)
243243
return ModuleTensorMap.from_module(in_keys, module, many_to_one, out_tensor)
244244

245+
@classmethod
246+
def from_weights(
247+
cls,
248+
weights: TensorMap,
249+
bias: Optional[TensorMap] = None
250+
):
251+
"""
252+
:param weights:
253+
The weight tensor map from which we create the linear modules.
254+
255+
:param bias:
256+
The weight tensor map from which we create the linear layers.
257+
"""
258+
module_map = ModuleDict()
259+
for key, weights_block in weights.items():
260+
module_key = ModuleTensorMap.module_key(key)
261+
module = torch.nn.Linear(
262+
len(weights_block.samples),
263+
len(weights_block.properties),
264+
bias=False,
265+
device=weights_block.values.device,
266+
dtype=weights_block.values.dtype,
267+
)
268+
module.weight = torch.nn.Parameter(weights_block.values)
269+
if bias is not None:
270+
module.bias = torch.nn.Parameter(bias.block(key).values)
271+
module_map[module_key] = module
272+
273+
return ModuleTensorMap(module_map, weights)
274+
245275
def forward(self, tensor: TensorMap) -> TensorMap:
246276
# added to appear in doc, :inherited-members: is not compatible with torch
247277
return super().forward(tensor)

src/equisolve/numpy/models/linear_model.py

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,12 @@
1313
import scipy.linalg
1414
from metatensor import Labels, TensorBlock, TensorMap
1515

16-
from ... import HAS_TORCH
1716
from ...module import NumpyModule, _Estimator
1817
from ...utils.metrics import rmse
1918
from ..utils import array_from_block
2019

2120

22-
class _Ridge(_Estimator):
21+
class Ridge(_Estimator):
2322
r"""Linear least squares with l2 regularization for :class:`metatensor.Tensormap`'s.
2423
2524
Weights :math:`w` are calculated according to
@@ -352,21 +351,15 @@ def score(self, X: TensorMap, y: TensorMap, parameter_key: str) -> float:
352351
y_pred = self.predict(X)
353352
return rmse(y, y_pred, parameter_key)
354353

355-
356-
class NumpyRidge(_Ridge, NumpyModule):
357-
def __init__(self) -> None:
358-
NumpyModule.__init__(self)
359-
_Ridge.__init__(self)
360-
361-
362-
if HAS_TORCH:
363-
import torch
364-
365-
class TorchRidge(_Ridge, torch.nn.Module):
366-
def __init__(self) -> None:
367-
torch.nn.Module.__init__(self)
368-
_Ridge.__init__(self)
369-
370-
Ridge = TorchRidge
371-
else:
372-
Ridge = NumpyRidge
354+
def export_torchscript(self):
355+
from ... import HAS_METATENSOR_TORCH
356+
if not HAS_METATENSOR_TORCH:
357+
raise ImportError(
358+
"To export your model to TorchScript torch needs to be installed. "
359+
"Please install torch, then reimport equisolve or "
360+
"use equisolve.refresh_global_flags()."
361+
)
362+
from ...nn import Linear
363+
# TODO should be metatensor.torch
364+
weights = metatensor.to(self.weights, backend="torch")
365+
return Linear.from_weights(weights)

tests/equisolve_tests/numpy/models/linear_model.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414
from equisolve.numpy.models import Ridge
1515

16+
from equisolve import HAS_METATENSOR_TORCH
17+
1618
from ..utilities import tensor_to_tensormap
1719

1820

@@ -79,6 +81,29 @@ def equisolve_solver_from_numpy_arrays(
7981
clf.fit(X=X, y=y, alpha=alpha, sample_weight=sw, solver=solver)
8082
return clf
8183

84+
@pytest.mark.skipif(
85+
not (HAS_METATENSOR_TORCH), reason="requires metatensor-torch to be run"
86+
)
87+
def test_export_torchscript(self):
88+
"""Test if ridge is working and all shapes are converted correctly.
89+
Test is performed for two blocks.
90+
"""
91+
92+
num_targets = 50
93+
num_properties = 5
94+
95+
# Create input values
96+
X_arr = self.rng.random([2, num_targets, num_properties])
97+
y_arr = self.rng.random([2, num_targets, 1])
98+
99+
X = tensor_to_tensormap(X_arr)
100+
y = tensor_to_tensormap(y_arr)
101+
102+
clf = Ridge()
103+
clf.fit(X=X, y=y)
104+
module = clf.export_torchscript()
105+
# TODO test torchscriptability and check if output match with original Ridge
106+
82107
num_properties = np.array([91])
83108
num_targets = np.array([1000])
84109
means = np.array([-0.5, 0, 0.1])

0 commit comments

Comments
 (0)