Pickle Bugfix #392

Merged 4 commits on Feb 27, 2024
1 change: 1 addition & 0 deletions .docker/environment.yml
@@ -12,3 +12,4 @@ dependencies:
- gcc
- gxx
- make
- dill
2 changes: 1 addition & 1 deletion .github/environment.yml
@@ -13,4 +13,4 @@ dependencies:
- cereal >= 1.3
- nlopt >= 2.7
- pytorch

- dill
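
Note: the dill test dependency added above (in both environment files) is likely needed because the standard pickle module rejects constructs that dill can serialize, such as lambdas and locally defined functions. A minimal sketch of the difference, independent of MParT:

import pickle
import dill

f = lambda x: x + 1

# The standard library pickler rejects module-level lambdas...
try:
    pickle.dumps(f)
except pickle.PicklingError:
    pass  # expected: pickle stores functions by name, and a lambda has none

# ...while dill serializes the function body itself and round-trips it.
g = dill.loads(dill.dumps(f))
assert g(1) == 2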
1 change: 1 addition & 0 deletions .github/workflows/build-bindings.yml
@@ -3,6 +3,7 @@ name: binding-tests
on:
push:
branches:
- release
- main
pull_request: {}

1 change: 1 addition & 0 deletions .github/workflows/build-doc.yml
@@ -4,6 +4,7 @@ on:
push:
branches:
- main
- release

jobs:
build-docs:
1 change: 1 addition & 0 deletions .github/workflows/build-external-lib-tests.yml
@@ -4,6 +4,7 @@ on:
push:
branches:
- main
- release
pull_request: {}

env:
1 change: 1 addition & 0 deletions .github/workflows/build-push-docker.yml
@@ -4,6 +4,7 @@ on:
push:
branches:
- main
- release

jobs:
docker:
1 change: 1 addition & 0 deletions .github/workflows/build-tests.yml
@@ -4,6 +4,7 @@ on:
push:
branches:
- main
- release
pull_request: {}

jobs:
121 changes: 5 additions & 116 deletions bindings/python/package/torch.py
@@ -1,124 +1,13 @@
import torch

def ExtractTorchTensorData(tensor):
""" Extracts the pointer, shape, and stride from a pytorch tensor and returns a tuple
that can be passed to MParT functions that have been overloaded to accept
(double*, std::tuple<int,int>, std::tuple<int,int>) instead of a Kokkos::View.

Arguments:
------------
tensor: pytorch.Tensor
The pytorch tensor we want to eventually wrap with a Kokkos view.
from .torch_helpers import ExtractTorchTensorData, MpartTorchAutograd

Returns:
------------
Tuple[int, Tuple[int,int], Tuple[int,int]]
A python tuple that contains all information needed to construct a Kokkos::View.
After casting to c++ types using pybind, this output can be passed to the
mpart::ConstructViewFromPointer function.
"""

# Make sure the tensor has double data type
if tensor.dtype != torch.float64:
raise ValueError(f'Currently only tensors with float64 datatype can be converted. Current dtype is {tensor.dtype}')

if len(tensor.shape)==1:
return tensor.data_ptr(), tensor.shape[0], tensor.stride()[0]
elif len(tensor.shape)==2:
return tensor.data_ptr(), tuple(tensor.shape), tuple(tensor.stride())
else:
raise ValueError(f'Currently only 1d and 2d tensors can be converted.')


class MpartTorchAutograd(torch.autograd.Function):

@staticmethod
def forward(ctx, input, coeffs, f, return_logdet):
ctx.save_for_backward(input, coeffs)
ctx.f = f

coeffs_dbl = None
if coeffs is not None:
coeffs_dbl = coeffs.double()
f.WrapCoeffs(ExtractTorchTensorData(coeffs_dbl))
input_dbl = input.double()

output = torch.zeros(f.outputDim, input.shape[1], dtype=torch.double)
f.EvaluateImpl(ExtractTorchTensorData(input_dbl), ExtractTorchTensorData(output))

if return_logdet:
logdet = torch.zeros(input.shape[1], dtype=torch.double)
f.LogDeterminantImpl(ExtractTorchTensorData(input_dbl), ExtractTorchTensorData(logdet))
return output.type(input.dtype), logdet.type(input.dtype)
else:
return output.type(input.dtype)

@staticmethod
def backward(ctx, output_sens, logdet_sens=None):
input, coeffs = ctx.saved_tensors
f = ctx.f

coeffs_dbl = None
if coeffs is not None:
coeffs_dbl = coeffs.double()
f.WrapCoeffs(ExtractTorchTensorData(coeffs_dbl))
input_dbl = input.double()
output_sens_dbl = output_sens.double()

logdet_sens_dbl = None
if logdet_sens is not None:
logdet_sens_dbl = logdet_sens.double()

# Get the gradient wrt input
grad = None
if input.requires_grad:
grad = torch.zeros(f.inputDim, input.shape[1], dtype=torch.double)

f.GradientImpl(ExtractTorchTensorData(input_dbl),
ExtractTorchTensorData(output_sens_dbl),
ExtractTorchTensorData(grad))

if logdet_sens is not None:
grad2 = torch.zeros(f.inputDim, input.shape[1], dtype=torch.double)

f.LogDeterminantInputGradImpl(ExtractTorchTensorData(input_dbl),
ExtractTorchTensorData(grad2))
grad += grad2*logdet_sens_dbl[None,:]

coeff_grad = None
if coeffs is not None:
if coeffs.requires_grad:
coeff_grad = torch.zeros(f.numCoeffs, input.shape[1], dtype=torch.double)
f.CoeffGradImpl(ExtractTorchTensorData(input_dbl),
ExtractTorchTensorData(output_sens_dbl),
ExtractTorchTensorData(coeff_grad))

coeff_grad = coeff_grad.sum(axis=1) # pytorch expects total gradient not per-sample gradient

if logdet_sens is not None:
grad2 = torch.zeros(f.numCoeffs, input.shape[1], dtype=torch.double)

f.LogDeterminantCoeffGradImpl(ExtractTorchTensorData(input_dbl),
ExtractTorchTensorData(grad2))

coeff_grad += torch.sum(grad2*logdet_sens[None,:],axis=1)

if coeff_grad is not None:
coeff_grad = coeff_grad.type(input.dtype)

if grad is not None:
grad = grad.type(input.dtype)

return grad, coeff_grad, None, None



class TorchParameterizedFunctionBase(torch.nn.Module):
""" Defines a wrapper around the MParT ParameterizedFunctionBase class that
can be used with pytorch.
"""

def __init__(self, f, store_coeffs=True, dtype=torch.double):
def __init__(self, f=None, store_coeffs=True, dtype=torch.double):
super().__init__()

self.f = f
@@ -129,7 +18,7 @@ def __init__(self, f, store_coeffs=True, dtype=torch.double):
self.coeffs = torch.nn.Parameter(coeff_tensor)
else:
self.coeffs = None

def forward(self, x, coeffs=None):

if coeffs is None:
@@ -148,7 +37,7 @@ class TorchConditionalMapBase(torch.nn.Module):
This can be done either in the constructor or afterwards.
"""

def __init__(self, f, store_coeffs=True, return_logdet=False, dtype=torch.double):
def __init__(self, f=None, store_coeffs=True, return_logdet=False, dtype=torch.double):
super().__init__()

self.return_logdet = return_logdet
@@ -159,7 +48,7 @@ def __init__(self, f, store_coeffs=True, return_logdet=False, dtype=torch.double
self.coeffs = torch.nn.Parameter(coeff_tensor)
else:
self.coeffs = None

def forward(self, x, coeffs=None):

if coeffs is None:
115 changes: 115 additions & 0 deletions bindings/python/package/torch_helpers.py
@@ -0,0 +1,115 @@
import torch

def ExtractTorchTensorData(tensor):
""" Extracts the pointer, shape, and stride from a pytorch tensor and returns a tuple
that can be passed to MParT functions that have been overloaded to accept
(double*, std::tuple<int,int>, std::tuple<int,int>) instead of a Kokkos::View.

Arguments:
------------
tensor: pytorch.Tensor
The pytorch tensor we want to eventually wrap with a Kokkos view.

Returns:
------------
Tuple[int, Tuple[int,int], Tuple[int,int]]
A python tuple that contains all information needed to construct a Kokkos::View.
After casting to c++ types using pybind, this output can be passed to the
mpart::ConstructViewFromPointer function.
"""

# Make sure the tensor has double data type
if tensor.dtype != torch.float64:
raise ValueError(f'Currently only tensors with float64 datatype can be converted. Current dtype is {tensor.dtype}')

if len(tensor.shape)==1:
return tensor.data_ptr(), tensor.shape[0], tensor.stride()[0]
elif len(tensor.shape)==2:
return tensor.data_ptr(), tuple(tensor.shape), tuple(tensor.stride())
else:
raise ValueError(f'Currently only 1d and 2d tensors can be converted.')


class MpartTorchAutograd(torch.autograd.Function):

def __reduce__(self):
return (self.__class__, (None,))

@staticmethod
def forward(ctx, input, coeffs, f, return_logdet):
ctx.save_for_backward(input, coeffs)
ctx.f = f

coeffs_dbl = None
if coeffs is not None:
coeffs_dbl = coeffs.double()
f.WrapCoeffs(ExtractTorchTensorData(coeffs_dbl))
input_dbl = input.double()

output = torch.zeros(f.outputDim, input.shape[1], dtype=torch.double)
f.EvaluateImpl(ExtractTorchTensorData(input_dbl), ExtractTorchTensorData(output))

if return_logdet:
logdet = torch.zeros(input.shape[1], dtype=torch.double)
f.LogDeterminantImpl(ExtractTorchTensorData(input_dbl), ExtractTorchTensorData(logdet))
return output.type(input.dtype), logdet.type(input.dtype)
else:
return output.type(input.dtype)

@staticmethod
def backward(ctx, output_sens, logdet_sens=None):
input, coeffs = ctx.saved_tensors
f = ctx.f

coeffs_dbl = None
if coeffs is not None:
coeffs_dbl = coeffs.double()
f.WrapCoeffs(ExtractTorchTensorData(coeffs_dbl))
input_dbl = input.double()
output_sens_dbl = output_sens.double()

logdet_sens_dbl = None
if logdet_sens is not None:
logdet_sens_dbl = logdet_sens.double()

# Get the gradient wrt input
grad = None
if input.requires_grad:
grad = torch.zeros(f.inputDim, input.shape[1], dtype=torch.double)

f.GradientImpl(ExtractTorchTensorData(input_dbl),
ExtractTorchTensorData(output_sens_dbl),
ExtractTorchTensorData(grad))

if logdet_sens is not None:
grad2 = torch.zeros(f.inputDim, input.shape[1], dtype=torch.double)

f.LogDeterminantInputGradImpl(ExtractTorchTensorData(input_dbl),
ExtractTorchTensorData(grad2))
grad += grad2*logdet_sens_dbl[None,:]

coeff_grad = None
if coeffs is not None:
if coeffs.requires_grad:
coeff_grad = torch.zeros(f.numCoeffs, input.shape[1], dtype=torch.double)
f.CoeffGradImpl(ExtractTorchTensorData(input_dbl),
ExtractTorchTensorData(output_sens_dbl),
ExtractTorchTensorData(coeff_grad))

coeff_grad = coeff_grad.sum(axis=1) # pytorch expects total gradient not per-sample gradient

if logdet_sens is not None:
grad2 = torch.zeros(f.numCoeffs, input.shape[1], dtype=torch.double)

f.LogDeterminantCoeffGradImpl(ExtractTorchTensorData(input_dbl),
ExtractTorchTensorData(grad2))

coeff_grad += torch.sum(grad2*logdet_sens[None,:],axis=1)

if coeff_grad is not None:
coeff_grad = coeff_grad.type(input.dtype)

if grad is not None:
grad = grad.type(input.dtype)

return grad, coeff_grad, None, None
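
For reference, a short sketch of the (pointer, shape, stride) contract that the ExtractTorchTensorData docstring above describes. Only standard torch calls are used; the values shown are what a contiguous row-major tensor reports:

import torch

t = torch.zeros(3, 4, dtype=torch.float64)

# The same three pieces the helper extracts for a 2d tensor:
ptr = t.data_ptr()          # integer address of the first element
shape = tuple(t.shape)      # (3, 4)
stride = tuple(t.stride())  # (4, 1): step 4 elements per row, 1 per column

assert shape == (3, 4) and stride == (4, 1)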
35 changes: 26 additions & 9 deletions bindings/python/tests/test_TorchWrapper.py
@@ -7,6 +7,7 @@

import mpart as mt
import numpy as np
import dill

if haveTorch:

@@ -81,6 +82,21 @@ def test_AutogradCoeffs():
loss.backward()
assert tmap2.coeffs.grad is not None

def test_AutogradCoeffAsInput():

opts = mt.MapOptions()
tmap = mt.CreateTriangular(dim,dim,3,opts) # Simple third order map

tmap2 = tmap.torch(store_coeffs=False)

x = torch.randn(numSamps, dim, dtype=torch.double)
coeffs = torch.randn(tmap.numCoeffs, dtype=torch.double)

y = tmap2(x,coeffs)
assert y.shape[0] == numSamps
assert y.shape[1] == dim
assert not y.isnan().any()

def test_TorchMethod():
opts = mt.MapOptions()
tmap = mt.CreateTriangular(dim,dim,3,opts) # Simple third order map
@@ -96,20 +112,20 @@ def test_TorchMethod():
assert np.all(y.detach().numpy() == tmap.Evaluate(x.T.detach().numpy()).T)
assert np.all(logdet.detach().numpy() == tmap.LogDeterminant(x.T.detach().numpy()))

def test_AutogradCoeffAsInput():

def test_TorchPickle():
opts = mt.MapOptions()
tmap = mt.CreateTriangular(dim,dim,3,opts) # Simple third order map

tmap2 = tmap.torch(store_coeffs=False)

x = torch.randn(numSamps, dim, dtype=torch.double)
coeffs = torch.randn(tmap.numCoeffs, dtype=torch.double)
tmap2 = tmap.torch(store_coeffs=True)
y = tmap2.forward(x)


y = tmap2(x,coeffs)
assert y.shape[0] == numSamps
assert y.shape[1] == dim
assert not y.isnan().any()
map_bytes = dill.dumps(tmap2, dill.HIGHEST_PROTOCOL)
tmap3 = dill.loads(map_bytes)

y2 = tmap3.forward(x)
assert (y2-y).abs().max() < 1e-8


if __name__=='__main__':
@@ -118,4 +134,5 @@ def test_AutogradCoeffAsInput():
test_Autograd()
test_AutogradCoeffs()
test_TorchMethod()
test_TorchPickle()
test_AutogradCoeffAsInput()
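
The mechanism behind test_TorchPickle is the __reduce__ hook added in torch_helpers.py: pickle and dill call __reduce__ to learn how to rebuild an object, and a return value of (callable, args) means "reconstruct me as callable(*args)". A standalone sketch of that protocol, where Wrapper is a hypothetical stand-in rather than MParT code:

import pickle

class Wrapper:
    def __init__(self, f=None):
        self.f = f  # in the real class, f would hold an unpicklable C++ handle

    def __reduce__(self):
        # Rebuild as Wrapper(None), dropping the unpicklable member.
        # This mirrors MpartTorchAutograd.__reduce__ above, which
        # returns (self.__class__, (None,)).
        return (self.__class__, (None,))

w = Wrapper(f=object())
w2 = pickle.loads(pickle.dumps(w))
assert isinstance(w2, Wrapper) and w2.f is None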
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -16,7 +16,7 @@ license={file="LICENSE.txt"}
readme="README.md"
requires-python = ">=3.7"
description="A Monotone Parameterization Toolkit"
version="2.2.1"
version="2.2.2"
keywords=["Measure Transport", "Monotone", "Transport Map", "Isotonic Regression", "Triangular", "Knothe-Rosenblatt"]

[project.urls]