Commit 0ccdf08

Merge pull request #644 from NiklasGustafsson/jit
Load TorchScript modules
2 parents 2e631a8 + 401d074

31 files changed: +1169 −427 lines

RELEASENOTES.md

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,8 @@ __API Changes__:
 
 Add torch.utils.rnn<br/>
 Add torchvision.io<br/>
+Add Tensor.trace() and torch.trace() (unrelated to torch.jit.trace)<br/>
+Add ability to load and save TorchScript modules created using Pytorch<br/>
 
 ## NuGet Version 0.96.8

docfx/articles/torchscript.md

Lines changed: 25 additions & 0 deletions
# Loading TorchScript Modules

Starting with release 0.96.9, you can load TorchScript modules and functions that have been either traced or scripted in Pytorch. It is, however, not yet possible to create a TorchScript module from scratch using TorchSharp. Refer to the [Pytorch JIT](https://pytorch.org/docs/stable/jit.html) docs for information on how to create such a file.

TorchScript is very powerful, because it allows you to save the logic and the weights of a model together, and it furthermore allows the module to be loaded into another program, __without any dependencies on the Python runtime.__ Thus, you can load a model that has been serialized using TorchScript and have it behave like any TorchScript module -- you can use it for training, or you can use it for inference.

Once you have a TorchScript file, you can load it into TorchSharp using:

```C#
var m = torch.jit.load("file-name");
```

It returns a ScriptModule, which behaves just like any other TorchSharp module. Whether the original script came from a module or a function, it is deserialized as a module. You can use it for training or inference by calling either `train()` or `eval()`. ScriptModules always start out on the CPU, so you have to call `cuda()` in order to move one to a GPU.
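For example, a minimal sketch of running inference with a loaded module (the input shape here is hypothetical -- use whatever the original model was traced or scripted with):

```C#
var m = torch.jit.load("file-name");
m.eval();                                // inference mode

var input = torch.rand(1, 3, 224, 224);  // input shape assumed for this sketch
var output = m.forward(input);
```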
Note that if you used __tracing__ to create the TorchScript file in Pytorch, submodules that behave differently in training and eval modes will behave according to the mode they were traced in.

If you use the script module to train, you may want or need to save it afterwards.

That is easily done using `save()`:

```C#
torch.jit.save(m, "file-name");
```
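Putting the pieces together, a save round trip might look like this sketch (the training loop itself is elided, since it depends entirely on your scenario):

```C#
m.train();                        // training mode
// ... run your training loop, updating the module's weights ...
m.eval();                         // back to inference mode
torch.jit.save(m, "file-name");   // persist the updated module
```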
While it is possible to save a modified ScriptModule from TorchSharp, it is not (yet) possible to create one _from scratch_ using either tracing or scripting. Another limitation is that the TorchSharp code assumes that the `forward()` function takes only tensors as its arguments and returns a single tensor, a limitation it shares with other TorchSharp modules.

docfx/index.md

Lines changed: 3 additions & 7 deletions
@@ -1,10 +1,9 @@
 
-TorchSharp are .NET bindings to the Torch library published
-here:
+TorchSharp are .NET bindings to the Torch library published here:
 
 https://pytorch.org/get-started/locally/
 
-This surfaces the C API as a strongly-typed C# API.
+This surfaces the C++ library as a strongly-typed .NET API.
 
 ## Getting Started
 
@@ -18,7 +17,4 @@ Then, start by reading up on [creating your own modules](articles/modules.md).
 
 An intruction on how to [share model](articles/saveload.md) weights between applications, whether in Python or .NET.
 
-
-## API documentation
-
-The [API Documentation](api/TorchSharp.html)
+Loading existing TorchScript files is now supported and described in [Loading TorchScript](articles/torchscript.md).

src/Native/LibTorchSharp/THSJIT.cpp

Lines changed: 161 additions & 36 deletions
@@ -3,9 +3,48 @@
 
 JITModule THSJIT_load(const char* filename)
 {
-    auto res = torch::jit::load(filename);
-    auto copy = new torch::jit::Module(res);
-    return new std::shared_ptr<torch::jit::Module>(copy);
+    CATCH(
+        auto res = torch::jit::load(filename);
+        auto copy = new torch::jit::Module(res);
+        return new std::shared_ptr<torch::jit::Module>(copy);
+    );
+
+    return nullptr;
+}
+
+void THSJIT_save(JITModule module, const char* filename)
+{
+    CATCH(
+        (*module)->save(filename);
+    );
+}
+
+int THSJIT_Module_is_training(JITModule module)
+{
+    return (*module)->is_training();
+}
+
+void THSJIT_Module_train(JITModule module, bool on)
+{
+    (*module)->train(on);
+}
+
+void THSJIT_Module_eval(JITModule module)
+{
+    (*module)->eval();
+}
+
+void THSJIT_Module_to_device(JITModule module, int64_t device, int64_t index)
+{
+    c10::DeviceType dev = c10::kCPU;
+    if (device == 1)
+        dev = c10::kCUDA;
+    (*module)->to(torch::Device(dev, index));
+}
+
+void THSJIT_Module_to_dtype(JITModule module, int8_t dtype)
+{
+    (*module)->to((at::ScalarType)dtype);
 }
 
 void THSJIT_Module_modules(const JITModule module, JITModule* (*allocator)(size_t length))
@@ -35,6 +74,22 @@ void THSJIT_Module_named_modules(const JITModule module,
     }
 }
 
+void THSJIT_Module_named_children(const JITModule module,
+    JITModule* (*allocator)(size_t length),
+    const char** (*allocator2)(size_t length))
+{
+    auto modules = (*module)->named_children();
+    JITModule* result = allocator(modules.size());
+    const char** names = allocator2(modules.size());
+    int i = 0;
+    for (const auto& child : modules) {
+        auto copy = new torch::jit::Module(child.value);
+        result[i] = new std::shared_ptr<torch::jit::Module>(copy);
+        names[i] = make_sharable_string(child.name);
+        i++;
+    }
+}
+
 void THSJIT_Module_parameters(const JITModule module, Tensor* (*allocator)(size_t length))
 {
     auto parameters = (*module)->parameters();
@@ -60,6 +115,21 @@ void THSJIT_Module_named_parameters(const JITModule module,
     }
 }
 
+void THSJIT_Module_named_buffers(const JITModule module,
+    Tensor* (*allocator)(size_t length),
+    const char** (*allocator2)(size_t length))
+{
+    auto parameters = (*module)->named_buffers();
+    Tensor* result = allocator(parameters.size());
+    const char** names = allocator2(parameters.size());
+    int i = 0;
+    for (const auto& child : parameters) {
+        result[i] = new torch::Tensor(child.value);
+        names[i] = make_sharable_string(child.name);
+        i++;
+    }
+}
+
 JITMethod THSJIT_Module_get_method(const JITModule module, const char* name)
 {
     auto method = (*module)->get_method(name);
@@ -69,7 +139,7 @@ JITMethod THSJIT_Module_get_method(const JITModule module, const char* name)
 
 Tensor THSJIT_Module_forward(const JITModule module, const Tensor* tensorPtrs, const int length)
 {
-    return new torch::Tensor((*module)->forward(toTensors<c10::IValue>((torch::Tensor**)tensorPtrs, length)).toTensor());
+    CATCH_TENSOR((*module)->forward(toTensors<c10::IValue>((torch::Tensor**)tensorPtrs, length)).toTensor());
 }
 
 void THSJIT_Module_dispose(const JITModule module)
@@ -87,6 +157,16 @@ int THSJIT_Method_num_inputs(const JITMethod method)
     return (int)(*method)->num_inputs();
 }
 
+int THSJIT_Module_num_inputs(const JITModule module)
+{
+    return (int)(*module)->get_method("forward").num_inputs() - 1; // Don't count the 'self' argument.
+}
+
+int THSJIT_Module_num_outputs(const JITModule module)
+{
+    return (int)(*module)->get_method("forward").function().getSchema().returns().size();
+}
+
 JITFunction THSJIT_Method_function(const JITMethod method)
 {
     return new std::shared_ptr<torch::jit::Function>(&(*method)->function());
@@ -113,32 +193,77 @@ void THSJIT_Function_dispose(const JITFunction function)
     delete function;
 }
 
-//void* THSJIT_typeCast(const JITType type)
-//{
-//    switch ((*type)->kind())
-//    {
-//    case c10::TypeKind::TensorType:
-//        return new std::shared_ptr<torch::jit::TensorType>((*type)->cast<c10::TensorType>());
-//    case c10::TypeKind::DimensionedTensorType:
-//        return new std::shared_ptr<torch::jit::DimensionedTensorType>((*type)->cast<c10::DimensionedTensorType>());
-//    default:
-//        return NULL;
-//    }
-//}
-//
-//int8_t THSJIT_typeKind(const JITType type)
-//{
-//    switch ((*type)->kind())
-//    {
-//    case c10::TypeKind::TensorType:
-//        return (int8_t)TypeKind::TensorType;
-//    case c10::TypeKind::DimensionedTensorType:
-//        return (int8_t)TypeKind::DimensionedTensorType;
-//    default:
-//        return -1;
-//    }
-//}
-//
+void THSJIT_Type_dispose(const JITType type)
+{
+    delete type;
+}
+
+void THSJIT_TensorType_dispose(const JITTensorType type)
+{
+    delete type;
+}
+
+void* THSJIT_Type_cast(const JITType type)
+{
+    switch ((*type)->kind())
+    {
+    case c10::TypeKind::TensorType:
+        return new std::shared_ptr<torch::jit::TensorType>((*type)->cast<c10::TensorType>());
+    //case c10::TypeKind::DimensionedTensorType:
+    //    return new std::shared_ptr<torch::jit::DimensionedTensorType>((*type)->cast<c10::DimensionedTensorType>());
+    default:
+        return NULL;
+    }
+}
+
+int8_t THSJIT_TensorType_dtype(const JITTensorType type)
+{
+    auto scT = (*type)->scalarType();
+    if (scT.has_value()) {
+        return (int8_t)scT.value();
+    }
+    else {
+        return -1;
+    }
+}
+
+void THSJIT_TensorType_sizes(const JITTensorType type, int64_t* (*allocator)(int64_t length))
+{
+    //CATCH(
+    auto& t = *type;
+    auto dim = t->dim();
+    auto res = (*type)->sizes().concrete_sizes();
+    if (res.has_value()) {
+        const size_t sz = res.value().size();
+        auto& vec = res.value();
+        int64_t* result = allocator(sz);
+        for (size_t i = 0; i < sz; i++)
+            result[i] = vec[i];
+    }
+    //);
+}
+
+int8_t THSJIT_Type_kind(const JITType type)
+{
+    switch ((*type)->kind())
+    {
+    case c10::TypeKind::TensorType:
+        return (int8_t)TypeKind::TensorType;
+    //case c10::TypeKind::DimensionedTensorType:
+    //    return (int8_t)TypeKind::DimensionedTensorType;
+    default:
+        return -1;
+    }
+}
+
+JITType THSJIT_Module_getInputType(JITModule module, int8_t index)
+{
+    auto typ = (*module)->type();
+    c10::TypeKind kind = typ->kind();
+    auto& schema = typ->getMethod("forward").getSchema();
+    return new std::shared_ptr<c10::Type>(schema.arguments()[1 + index].type()->cast<c10::TensorType>());
+}
+
 //int8_t THSJIT_getScalarFromDimensionedTensorType(const JITDimensionedTensorType type)
 //{
 //    return (int8_t)(*type)->scalarType();
@@ -159,10 +284,10 @@ void THSJIT_Function_dispose(const JITFunction function)
 //
 //    return make_sharable_string(device_type);
 //}
-//
 
-//
-//void THSJIT_typeDispose(const JITType type)
-//{
-//    delete type;
-//}
+
+
+void THSJIT_typeDispose(const JITType type)
+{
+    delete type;
+}
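The practical effect of routing these entry points through `CATCH`/`CATCH_TENSOR` is that libtorch errors can surface to .NET as exceptions instead of terminating the process. A hedged sketch of what that enables on the managed side (the exact exception type TorchSharp throws is an assumption here, as is the file name):

```C#
try
{
    // A corrupt or non-TorchScript file should now fail cleanly.
    var m = torch.jit.load("not-a-torchscript-file.bin");
}
catch (Exception e)
{
    Console.WriteLine($"jit.load failed: {e.Message}");
}
```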

src/Native/LibTorchSharp/THSJIT.h

Lines changed: 43 additions & 15 deletions
@@ -7,38 +7,66 @@
 
 #include "Utils.h"
 
-//// Copied from libtorch to share the type as an int8_t.
-//enum TypeKind : int8_t {
-//#define DEFINE_TYPE(T) T,
-//    C10_FORALL_TYPES(DEFINE_TYPE)
-//#undef DEFINE_TYPE
-//};
-//
-//// API.
+// Copied from libtorch to share the type as an int8_t.
+enum TypeKind : int8_t {
+#define DEFINE_TYPE(T) T,
+    C10_FORALL_TYPES(DEFINE_TYPE)
+#undef DEFINE_TYPE
+};
+
+// API.
 
 
 EXPORT_API(JITModule) THSJIT_load(const char* filename);
+EXPORT_API(void) THSJIT_save(JITModule module, const char* filename);
 
-EXPORT_API(void) THSJIT_Module_modules(const JITModule module, JITModule* (*allocator)(size_t length));
+EXPORT_API(void) THSJIT_Module_dispose(const JITModule module);
+
+EXPORT_API(int) THSJIT_Module_num_inputs(const JITModule method);
+EXPORT_API(int) THSJIT_Module_num_outputs(const JITModule method);
+
+EXPORT_API(Tensor) THSJIT_Module_forward(const JITModule module, const Tensor* tensorPtrs, const int length);
+
+EXPORT_API(int) THSJIT_Module_is_training(JITModule module);
+EXPORT_API(void) THSJIT_Module_train(JITModule module, bool on);
+EXPORT_API(void) THSJIT_Module_eval(JITModule module);
 
+EXPORT_API(void) THSJIT_Module_to_device(JITModule module, int64_t device, int64_t index);
+EXPORT_API(void) THSJIT_Module_to_dtype(JITModule module, int8_t dtype);
+
+EXPORT_API(JITType) THSJIT_Module_getInputType(JITModule module, int8_t dtype);
+
+EXPORT_API(int8_t) THSJIT_Type_kind(JITType handle);
+EXPORT_API(void*) THSJIT_Type_cast(const JITType type);
+
+EXPORT_API(int8_t) THSJIT_TensorType_dtype(const JITTensorType type);
+EXPORT_API(void) THSJIT_TensorType_sizes(const JITTensorType type, int64_t* (*allocator)(int64_t length));
+
+EXPORT_API(void) THSJIT_Type_dispose(const JITType type);
+EXPORT_API(void) THSJIT_TensorType_dispose(const JITTensorType type);
+
+EXPORT_API(void) THSJIT_Module_modules(const JITModule module, JITModule* (*allocator)(size_t length));
 EXPORT_API(void) THSJIT_Module_named_modules(const JITModule module,
     JITModule* (*allocator)(size_t length),
    const char** (*allocator2)(size_t length));
 
+EXPORT_API(void) THSJIT_Module_named_children(const JITModule module,
+    JITModule* (*allocator)(size_t length),
+    const char** (*allocator2)(size_t length));
+
 EXPORT_API(JITMethod) THSJIT_Module_get_method(const JITModule module, const char* name);
 
 EXPORT_API(void) THSJIT_Module_parameters(const JITModule module, Tensor* (*allocator)(size_t length));
-
 EXPORT_API(void) THSJIT_Module_named_parameters(const JITModule module,
     Tensor* (*allocator)(size_t length),
     const char** (*allocator2)(size_t length));
 
-EXPORT_API(Tensor) THSJIT_Module_forward(const JITModule module, const Tensor* tensorPtrs, const int length);
-
-EXPORT_API(void) THSJIT_Module_dispose(const JITModule module);
-
-EXPORT_API(const char*) THSJIT_Method_name(const JITMethod method);
+EXPORT_API(void) THSJIT_Module_named_buffers(const JITModule module,
+    Tensor* (*allocator)(size_t length),
+    const char** (*allocator2)(size_t length));
 
 EXPORT_API(int) THSJIT_Method_num_inputs(const JITMethod method);
 
 EXPORT_API(void) THSJIT_Method_dispose(const JITMethod method);
+
+EXPORT_API(const char*) THSJIT_Method_name(const JITMethod method);
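For context, a sketch of how exports like these are typically consumed from C# via P/Invoke; the names mirror the C entry points above, but the binding class is hypothetical and the actual TorchSharp declarations may differ in marshaling details:

```C#
using System;
using System.Runtime.InteropServices;

internal static class NativeJit   // hypothetical binding class
{
    [DllImport("LibTorchSharp")]
    internal static extern IntPtr THSJIT_load([MarshalAs(UnmanagedType.LPStr)] string filename);

    [DllImport("LibTorchSharp")]
    internal static extern void THSJIT_save(IntPtr module, [MarshalAs(UnmanagedType.LPStr)] string filename);

    [DllImport("LibTorchSharp")]
    internal static extern void THSJIT_Module_eval(IntPtr module);

    [DllImport("LibTorchSharp")]
    internal static extern void THSJIT_Module_dispose(IntPtr module);
}
```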

src/Native/LibTorchSharp/THSLinearAlgebra.cpp

Lines changed: 5 additions & 0 deletions
@@ -293,6 +293,11 @@ Tensor THSTensor_diag(const Tensor tensor, const int64_t diagonal)
     CATCH_TENSOR(tensor->diag(diagonal));
 }
 
+Tensor THSTensor_trace(const Tensor tensor)
+{
+    CATCH_TENSOR(tensor->trace());
+}
+
 Tensor THSTensor_diagflat(const Tensor tensor, const int64_t offset)
 {
     CATCH_TENSOR(tensor->diagflat(offset));
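This export backs the new `Tensor.trace()` / `torch.trace()` mentioned in the release notes: the trace of a 2-D tensor is the sum of its main diagonal. A quick sketch of the expected managed-side behavior:

```C#
var t = torch.arange(1, 10).reshape(3, 3);   // [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
var tr = t.trace();                          // 1 + 5 + 9 = 15
```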
