Load TorchScript modules #644

Merged (12 commits, Jul 2, 2022)
2 changes: 2 additions & 0 deletions RELEASENOTES.md
@@ -10,6 +10,8 @@ __API Changes__:

Add torch.utils.rnn<br/>
Add torchvision.io<br/>
Add Tensor.trace() and torch.trace() (unrelated to torch.jit.trace)<br/>
Add ability to load and save TorchScript modules created using PyTorch<br/>

## NuGet Version 0.96.8

25 changes: 25 additions & 0 deletions docfx/articles/torchscript.md
@@ -0,0 +1,25 @@
# Loading TorchScript Modules

Starting with release 0.96.9, you can load TorchScript modules and functions that have been either traced or scripted in PyTorch. It is, however, not yet possible to create a TorchScript module from scratch using TorchSharp. Refer to the [PyTorch JIT](https://pytorch.org/docs/stable/jit.html) docs for information on how to create such a file.

TorchScript is very powerful, because it allows you to save both the logic and the weights of a model together, and the resulting module can be loaded into another program __without any dependency on the Python runtime.__ Thus, you can load a model that has been serialized using TorchScript and have it behave like any TorchScript module -- you can use it for training, or you can use it for inference.

Once you have a TorchScript file, you can load it into TorchSharp using:

```C#
var m = torch.jit.load("file-name");
```

It returns a ScriptModule, which behaves just like any other TorchSharp module. Whether the original script came from a module or a function, it is deserialized as a module. You can use it for training or inference by calling either `train()` or `eval()`. ScriptModules always start out on the CPU, so you have to call `cuda()` to move one to a GPU.
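
For example, continuing with the module `m` loaded above, a minimal sketch of preparing it for inference might look like this (the CUDA availability check is simply a guard around the `cuda()` call):

```C#
// Put the loaded ScriptModule in evaluation mode for inference.
m.eval();

// ScriptModules start out on the CPU; move to the GPU only if one is available.
if (torch.cuda.is_available()) {
    m.cuda();
}
```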

Note that if you used __tracing__ to create the TorchScript file in Pytorch, submodules that behave differently in training and eval modes will behave according to the mode they were traced in.

If you use the script module for training, you may want or need to save it afterwards.

That is easily done using `save()`:

```C#
torch.jit.save(m, "file-name");
```

While it is possible to save a modified ScriptModule from TorchSharp, it is not (yet) possible to create one _from scratch_ using either tracing or scripting. Another limitation is that the TorchSharp code assumes that the `forward()` function takes only tensors as its arguments and returns a single tensor, a limitation it shares with other TorchSharp modules.
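
To illustrate that constraint, a hypothetical call into the loaded module `m` might look like the following sketch (the input shape is made up for illustration):

```C#
// forward() is assumed to take tensors as arguments and return a single tensor,
// per the limitation described above.
var input = torch.rand(1, 3, 224, 224);
var output = m.forward(input);

Console.WriteLine(string.Join("x", output.shape));
```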
10 changes: 3 additions & 7 deletions docfx/index.md
@@ -1,10 +1,9 @@

TorchSharp are .NET bindings to the Torch library published
here:
TorchSharp are .NET bindings to the Torch library published here:

https://pytorch.org/get-started/locally/

This surfaces the C API as a strongly-typed C# API.
This surfaces the C++ library as a strongly-typed .NET API.

## Getting Started

@@ -18,7 +17,4 @@ Then, start by reading up on [creating your own modules](articles/modules.md).

Instructions on how to [share model](articles/saveload.md) weights between applications, whether in Python or .NET.


## API documentation

The [API Documentation](api/TorchSharp.html)
Loading existing TorchScript files is now supported and described in [Loading TorchScript](articles/torchscript.md).
197 changes: 161 additions & 36 deletions src/Native/LibTorchSharp/THSJIT.cpp
@@ -3,9 +3,48 @@

JITModule THSJIT_load(const char* filename)
{
auto res = torch::jit::load(filename);
auto copy = new torch::jit::Module(res);
return new std::shared_ptr<torch::jit::Module>(copy);
CATCH(
auto res = torch::jit::load(filename);
auto copy = new torch::jit::Module(res);
return new std::shared_ptr<torch::jit::Module>(copy);
);

return nullptr;
}

void THSJIT_save(JITModule module, const char* filename)
{
CATCH(
(*module)->save(filename);
);
}

int THSJIT_Module_is_training(JITModule module)
{
return (*module)->is_training();
}

void THSJIT_Module_train(JITModule module, bool on)
{
(*module)->train(on);
}

void THSJIT_Module_eval(JITModule module)
{
(*module)->eval();
}

void THSJIT_Module_to_device(JITModule module, int64_t device, int64_t index)
{
c10::DeviceType dev = c10::kCPU;
if (device == 1)
dev = c10::kCUDA;
(*module)->to(torch::Device(dev, index));
}

void THSJIT_Module_to_dtype(JITModule module, int8_t dtype)
{
(*module)->to((at::ScalarType)dtype);
}

void THSJIT_Module_modules(const JITModule module, JITModule* (*allocator)(size_t length))
@@ -35,6 +74,22 @@ void THSJIT_Module_named_modules(const JITModule module,
}
}

void THSJIT_Module_named_children(const JITModule module,
JITModule* (*allocator)(size_t length),
const char** (*allocator2)(size_t length))
{
auto modules = (*module)->named_children();
JITModule* result = allocator(modules.size());
const char** names = allocator2(modules.size());
int i = 0;
for (const auto& child : modules) {
auto copy = new torch::jit::Module(child.value);
result[i] = new std::shared_ptr<torch::jit::Module>(copy);
names[i] = make_sharable_string(child.name);
i++;
}
}

void THSJIT_Module_parameters(const JITModule module, Tensor* (*allocator)(size_t length))
{
auto parameters = (*module)->parameters();
@@ -60,6 +115,21 @@ void THSJIT_Module_named_parameters(const JITModule module,
}
}

void THSJIT_Module_named_buffers(const JITModule module,
Tensor* (*allocator)(size_t length),
const char** (*allocator2)(size_t length))
{
auto parameters = (*module)->named_buffers();
Tensor* result = allocator(parameters.size());
const char** names = allocator2(parameters.size());
int i = 0;
for (const auto& child : parameters) {
result[i] = new torch::Tensor(child.value);
names[i] = make_sharable_string(child.name);
i++;
}
}

JITMethod THSJIT_Module_get_method(const JITModule module, const char* name)
{
auto method = (*module)->get_method(name);
@@ -69,7 +139,7 @@ JITMethod THSJIT_Module_get_method(const JITModule module, const char* name)

Tensor THSJIT_Module_forward(const JITModule module, const Tensor* tensorPtrs, const int length)
{
return new torch::Tensor((*module)->forward(toTensors<c10::IValue>((torch::Tensor**)tensorPtrs, length)).toTensor());
CATCH_TENSOR((*module)->forward(toTensors<c10::IValue>((torch::Tensor**)tensorPtrs, length)).toTensor());
}

void THSJIT_Module_dispose(const JITModule module)
@@ -87,6 +157,16 @@ int THSJIT_Method_num_inputs(const JITMethod method)
return (int)(*method)->num_inputs();
}

int THSJIT_Module_num_inputs(const JITModule module)
{
return (int)(*module)->get_method("forward").num_inputs() - 1; // Don't count the 'self' argument.
}

int THSJIT_Module_num_outputs(const JITModule module)
{
return (int)(*module)->get_method("forward").function().getSchema().returns().size();
}

JITFunction THSJIT_Method_function(const JITMethod method)
{
return new std::shared_ptr<torch::jit::Function>(&(*method)->function());
@@ -113,32 +193,77 @@ void THSJIT_Function_dispose(const JITFunction function)
delete function;
}

//void* THSJIT_typeCast(const JITType type)
//{
// switch ((*type)->kind())
// {
// case c10::TypeKind::TensorType:
// return new std::shared_ptr<torch::jit::TensorType>((*type)->cast<c10::TensorType>());
// case c10::TypeKind::DimensionedTensorType:
// return new std::shared_ptr<torch::jit::DimensionedTensorType>((*type)->cast<c10::DimensionedTensorType>());
// default:
// return NULL;
// }
//}
//
//int8_t THSJIT_typeKind(const JITType type)
//{
// switch ((*type)->kind())
// {
// case c10::TypeKind::TensorType:
// return (int8_t)TypeKind::TensorType;
// case c10::TypeKind::DimensionedTensorType:
// return (int8_t)TypeKind::DimensionedTensorType;
// default:
// return -1;
// }
//}
//
void THSJIT_Type_dispose(const JITType type)
{
delete type;
}

void THSJIT_TensorType_dispose(const JITTensorType type)
{
delete type;
}

void* THSJIT_Type_cast(const JITType type)
{
switch ((*type)->kind())
{
case c10::TypeKind::TensorType:
return new std::shared_ptr<torch::jit::TensorType>((*type)->cast<c10::TensorType>());
//case c10::TypeKind::DimensionedTensorType:
// return new std::shared_ptr<torch::jit::DimensionedTensorType>((*type)->cast<c10::DimensionedTensorType>());
default:
return NULL;
}
}

int8_t THSJIT_TensorType_dtype(const JITTensorType type)
{
auto scT = (*type)->scalarType();
if (scT.has_value()) {
return (int8_t)scT.value();
}
else {
return -1;
}
}

void THSJIT_TensorType_sizes(const JITTensorType type, int64_t* (*allocator)(int64_t length))
{
//CATCH(
auto& t = *type;
auto dim = t->dim();
auto res = (*type)->sizes().concrete_sizes();
if (res.has_value()) {
const size_t sz = res.value().size();
auto& vec = res.value();
int64_t* result = allocator(sz);
for (size_t i = 0; i < sz; i++)
result[i] = vec[i];
}
//);
}

int8_t THSJIT_Type_kind(const JITType type)
{
switch ((*type)->kind())
{
case c10::TypeKind::TensorType:
return (int8_t)TypeKind::TensorType;
//case c10::TypeKind::DimensionedTensorType:
// return (int8_t)TypeKind::DimensionedTensorType;
default:
return -1;
}
}

JITType THSJIT_Module_getInputType(JITModule module, int8_t index)
{
auto typ = (*module)->type();
c10::TypeKind kind = typ->kind();
auto& schema = typ->getMethod("forward").getSchema();
return new std::shared_ptr<c10::Type>(schema.arguments()[1 + index].type()->cast<c10::TensorType>());
}

//int8_t THSJIT_getScalarFromDimensionedTensorType(const JITDimensionedTensorType type)
//{
// return (int8_t)(*type)->scalarType();
Expand All @@ -159,10 +284,10 @@ void THSJIT_Function_dispose(const JITFunction function)
//
// return make_sharable_string(device_type);
//}
//

//
//void THSJIT_typeDispose(const JITType type)
//{
// delete type;
//}


void THSJIT_typeDispose(const JITType type)
{
delete type;
}
58 changes: 43 additions & 15 deletions src/Native/LibTorchSharp/THSJIT.h
@@ -7,38 +7,66 @@

#include "Utils.h"

//// Copied from libtorch to share the type as an int8_t.
//enum TypeKind : int8_t {
//#define DEFINE_TYPE(T) T,
// C10_FORALL_TYPES(DEFINE_TYPE)
//#undef DEFINE_TYPE
//};
//
//// API.
// Copied from libtorch to share the type as an int8_t.
enum TypeKind : int8_t {
#define DEFINE_TYPE(T) T,
C10_FORALL_TYPES(DEFINE_TYPE)
#undef DEFINE_TYPE
};

// API.


EXPORT_API(JITModule) THSJIT_load(const char* filename);
EXPORT_API(void) THSJIT_save(JITModule module, const char* filename);

EXPORT_API(void) THSJIT_Module_modules(const JITModule module, JITModule* (*allocator)(size_t length));
EXPORT_API(void) THSJIT_Module_dispose(const JITModule module);

EXPORT_API(int) THSJIT_Module_num_inputs(const JITModule method);
EXPORT_API(int) THSJIT_Module_num_outputs(const JITModule method);

EXPORT_API(Tensor) THSJIT_Module_forward(const JITModule module, const Tensor* tensorPtrs, const int length);

EXPORT_API(int) THSJIT_Module_is_training(JITModule module);
EXPORT_API(void) THSJIT_Module_train(JITModule module, bool on);
EXPORT_API(void) THSJIT_Module_eval(JITModule module);

EXPORT_API(void) THSJIT_Module_to_device(JITModule module, int64_t device, int64_t index);
EXPORT_API(void) THSJIT_Module_to_dtype(JITModule module, int8_t dtype);

EXPORT_API(JITType) THSJIT_Module_getInputType(JITModule module, int8_t dtype);

EXPORT_API(int8_t) THSJIT_Type_kind(JITType handle);
EXPORT_API(void*) THSJIT_Type_cast(const JITType type);

EXPORT_API(int8_t) THSJIT_TensorType_dtype(const JITTensorType type);
EXPORT_API(void) THSJIT_TensorType_sizes(const JITTensorType type, int64_t* (*allocator)(int64_t length));

EXPORT_API(void) THSJIT_Type_dispose(const JITType type);
EXPORT_API(void) THSJIT_TensorType_dispose(const JITTensorType type);

EXPORT_API(void) THSJIT_Module_modules(const JITModule module, JITModule* (*allocator)(size_t length));
EXPORT_API(void) THSJIT_Module_named_modules(const JITModule module,
JITModule* (*allocator)(size_t length),
const char** (*allocator2)(size_t length));

EXPORT_API(void) THSJIT_Module_named_children(const JITModule module,
JITModule* (*allocator)(size_t length),
const char** (*allocator2)(size_t length));

EXPORT_API(JITMethod) THSJIT_Module_get_method(const JITModule module, const char* name);

EXPORT_API(void) THSJIT_Module_parameters(const JITModule module, Tensor* (*allocator)(size_t length));

EXPORT_API(void) THSJIT_Module_named_parameters(const JITModule module,
Tensor* (*allocator)(size_t length),
const char** (*allocator2)(size_t length));

EXPORT_API(Tensor) THSJIT_Module_forward(const JITModule module, const Tensor* tensorPtrs, const int length);

EXPORT_API(void) THSJIT_Module_dispose(const JITModule module);

EXPORT_API(const char*) THSJIT_Method_name(const JITMethod method);
EXPORT_API(void) THSJIT_Module_named_buffers(const JITModule module,
Tensor* (*allocator)(size_t length),
const char** (*allocator2)(size_t length));

EXPORT_API(int) THSJIT_Method_num_inputs(const JITMethod method);

EXPORT_API(void) THSJIT_Method_dispose(const JITMethod method);

EXPORT_API(const char*) THSJIT_Method_name(const JITMethod method);
5 changes: 5 additions & 0 deletions src/Native/LibTorchSharp/THSLinearAlgebra.cpp
@@ -293,6 +293,11 @@ Tensor THSTensor_diag(const Tensor tensor, const int64_t diagonal)
CATCH_TENSOR(tensor->diag(diagonal));
}

Tensor THSTensor_trace(const Tensor tensor)
{
CATCH_TENSOR(tensor->trace());
}

Tensor THSTensor_diagflat(const Tensor tensor, const int64_t offset)
{
CATCH_TENSOR(tensor->diagflat(offset));