Skip to content

Commit b8cf4aa

Browse files
Merge pull request #1110 from NiklasGustafsson/bugs
Re-introduce `forward` to ScriptModule
2 parents 824cc06 + 93b3507 commit b8cf4aa

File tree

3 files changed

+208
-14
lines changed

3 files changed

+208
-14
lines changed

RELEASENOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ Releases, starting with 9/2/2021, are listed with the most recent release at the
66

77
__Bug Fixes__:
88

9+
ScriptModule: adding `forward` and the ability to hook.<br/>
910
Update to SkiaSharp 2.88.6 to avoid the libwebp vulnerability.<br/>
1011
#1105: Dataset files get written to the wrong directory<br/>
1112

src/TorchSharp/JIT/ScriptModule.cs

Lines changed: 152 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -220,8 +220,8 @@ private Type GetType(Type type)
220220
/// <summary>
221221
/// Invoke the 'forward' function of the script with any number of arguments.
222222
/// </summary>
223-
/// <param name="objs"></param>
224-
/// <returns></returns>
223+
/// <param name="input">Any number of parameters for the forward function.</param>
224+
/// <returns>An object.</returns>
225225
/// <remarks>
226226
/// Only certain types can currently be passed:
227227
/// 1. Tensor
@@ -238,15 +238,14 @@ private Type GetType(Type type)
238238
/// For returned types, if the number of values returned in a tuple is greaterh than 5, it is returned as an array, instead.
239239
/// If a tuple contains both tensors and scalars, it is returned as an object[].
240240
/// </remarks>
241-
/// <exception cref="NotImplementedException"></exception>
242-
public object call(params object[] objs)
241+
public object forward(params object[] input)
243242
{
244243
TensorOrScalar[] ptrArray = null;
245244
sbyte typeCode = 0;
246245

247246
using (var parray = new IndexedPinnedArrays<TensorOrScalar>()) {
248247

249-
var tRefsHandle = DetermineArgumentTypeRefs(objs, out var count, parray);
248+
var tRefsHandle = DetermineArgumentTypeRefs(input, out var count, parray);
250249

251250
var allocated = parray.Count;
252251

@@ -258,6 +257,40 @@ public object call(params object[] objs)
258257
}
259258
}
260259

260+
/// <summary>
261+
/// Synonym for `forward`
262+
/// </summary>
263+
/// <param name="input">Any number of parameters for the forward function.</param>
264+
/// <returns>An object.</returns>
265+
/// <remarks>
266+
/// Only certain types can currently be passed:
267+
/// 1. Tensor
268+
/// 2. Scalar
269+
/// 3. int/long
270+
/// 4. double/float
271+
/// 5. bool
272+
///
273+
/// Only certain types can currently be returned:
274+
/// 1. Tensor / Scalar
275+
/// 2. Tuple of Tensor / Scalar
276+
/// 3. Array (Python list) of Tensor / Scalar
277+
///
278+
/// For returned types, if the number of values returned in a tuple is greaterh than 5, it is returned as an array, instead.
279+
/// If a tuple contains both tensors and scalars, it is returned as an object[].
280+
///
281+
/// Note: this currently does not support hooking the module.
282+
/// </remarks>
283+
public object call(params object[] input)
284+
{
285+
// TODO: Call pre-hooks, if available.
286+
287+
var result = forward(input);
288+
289+
// TODO: Call post-hooks, if available.
290+
291+
return result;
292+
}
293+
261294
/// <summary>
262295
/// Invoke a function from the script module.
263296
/// </summary>
@@ -437,7 +470,7 @@ internal static object ProcessReturnValue(string name, IndexedPinnedArrays<Tenso
437470
// List of scalars and tensors
438471
var result = new object[ptrArray.Length];
439472
for (var i = 0; i < ptrArray.Length; i++) {
440-
switch(ptrArray[i].TypeCode) {
473+
switch (ptrArray[i].TypeCode) {
441474
case 0:
442475
result[i] = new Tensor(ptrArray[i].Handle);
443476
break;
@@ -566,6 +599,62 @@ internal static object ProcessReturnValue(string name, IndexedPinnedArrays<Tenso
566599
public TResult invoke<T, TResult>(string name, params T[] inputs) => (TResult)invoke(name, inputs);
567600
}
568601

602+
603+
604+
/// <summary>
605+
/// Represents a module that accepts 'hook' to the module logic.
606+
/// </summary>
607+
public class HookableScriptModule<TPreHook, TPostHook> : ScriptModule
608+
{
609+
internal HookableScriptModule(IntPtr handle) : base(handle)
610+
{
611+
}
612+
613+
public HookRemover register_forward_hook(TPostHook hook)
614+
{
615+
var key = Guid.NewGuid().ToString();
616+
post_hooks.Add(key, hook);
617+
return new HookRemover(this, key);
618+
}
619+
620+
public HookRemover register_forward_pre_hook(TPreHook hook)
621+
{
622+
var key = Guid.NewGuid().ToString();
623+
pre_hooks.Add(key, hook);
624+
return new HookRemover(this, key);
625+
}
626+
627+
private void remove(string key)
628+
{
629+
if (pre_hooks.ContainsKey(key)) pre_hooks.Remove(key);
630+
if (post_hooks.ContainsKey(key)) post_hooks.Remove(key);
631+
}
632+
633+
protected Dictionary<string, TPreHook> pre_hooks = new Dictionary<string, TPreHook>();
634+
protected Dictionary<string, TPostHook> post_hooks = new Dictionary<string, TPostHook>();
635+
636+
/// <summary>
637+
/// Used to remove a specific hook, following the PyTorch API design.
638+
/// </summary>
639+
/// <remarks>The name and namespace of this class is not the same as in PyTorch, but serves the same purpose.</remarks>
640+
public class HookRemover
641+
{
642+
public HookRemover(HookableScriptModule<TPreHook, TPostHook> module, string key)
643+
{
644+
this.module = module;
645+
this.key = key;
646+
}
647+
648+
public void remove()
649+
{
650+
module.remove(key);
651+
}
652+
653+
private HookableScriptModule<TPreHook, TPostHook> module;
654+
private string key;
655+
}
656+
}
657+
569658
/// <summary>
570659
/// A script module taking any number of tensors as input
571660
/// </summary>
@@ -593,18 +682,26 @@ internal ScriptModule(IntPtr handle) : base(handle) { }
593682
/// For returned types, if the number of values returned in a tuple is greaterh than 5, it is returned as an array, instead.
594683
/// If a tuple contains both tensors and scalars, it is returned as an object[].
595684
/// </remarks>
596-
public TResult call(params Tensor[] tensor)
685+
public TResult call(params Tensor[] input)
597686
{
598-
return (TResult)base.call(tensor);
687+
// TODO: Call pre-hooks, if available.
688+
689+
var result = forward(input);
690+
691+
// TODO: Call post-hooks, if available.
692+
693+
return result;
599694
}
695+
696+
public TResult forward(params Tensor[] tensor) => (TResult)base.forward(tensor);
600697
}
601698

602699
/// <summary>
603700
/// A script module taking a single argument.
604701
/// </summary>
605702
/// <typeparam name="T">The argument type.</typeparam>
606703
/// <typeparam name="TResult">The return type of the module.</typeparam>
607-
public class ScriptModule<T, TResult> : ScriptModule, torch.nn.IModule<T, TResult>
704+
public class ScriptModule<T, TResult> : HookableScriptModule<Func<ScriptModule<T, TResult>, T, T>, Func<ScriptModule<T, TResult>, T, TResult, TResult>>, torch.nn.IModule<T, TResult>
608705
{
609706
internal ScriptModule(IntPtr handle) : base(handle) { }
610707

@@ -627,10 +724,30 @@ internal ScriptModule(IntPtr handle) : base(handle) { }
627724
/// For returned types, if the number of values returned in a tuple is greaterh than 5, it is returned as an array, instead.
628725
/// If a tuple contains both tensors and scalars, it is returned as an object[].
629726
/// </remarks>
630-
public TResult call(T tensor)
727+
public TResult call(T input)
631728
{
632-
return (TResult)base.call(tensor);
729+
// Call pre-hooks, if available.
730+
731+
foreach (var hook in pre_hooks.Values) {
732+
var modified = hook(this, input);
733+
if (modified is not null)
734+
input = modified;
735+
}
736+
737+
var result = forward(input);
738+
739+
// Call post-hooks, if available.
740+
741+
foreach (var hook in post_hooks.Values) {
742+
var modified = hook(this, input, result);
743+
if (modified is not null)
744+
result = modified;
745+
}
746+
747+
return result;
633748
}
749+
750+
public TResult forward(T tensor) => (TResult)base.forward(tensor);
634751
}
635752

636753
/// <summary>
@@ -639,7 +756,7 @@ public TResult call(T tensor)
639756
/// <typeparam name="T1">The first argument type.</typeparam>
640757
/// <typeparam name="T2">The second argument type.</typeparam>
641758
/// <typeparam name="TResult">The return type of the module.</typeparam>
642-
public class ScriptModule<T1, T2, TResult> : ScriptModule, torch.nn.IModule<T1, T2, TResult>
759+
public class ScriptModule<T1, T2, TResult> : HookableScriptModule<Func<ScriptModule<T1, T2, TResult>, T1, T2, (T1, T2)?>, Func<ScriptModule<T1, T2, TResult>, T1, T2, TResult, TResult>>, torch.nn.IModule<T1, T2, TResult>
643760
{
644761
internal ScriptModule(IntPtr handle) : base(handle) { }
645762

@@ -664,8 +781,30 @@ internal ScriptModule(IntPtr handle) : base(handle) { }
664781
/// </remarks>
665782
public TResult call(T1 input1, T2 input2)
666783
{
667-
return (TResult)base.call(input1, input2);
784+
// Call pre-hooks, if available.
785+
786+
foreach (var hook in pre_hooks.Values) {
787+
var modified = hook(this, input1, input2);
788+
if (modified.HasValue) {
789+
input1 = modified.Value.Item1;
790+
input2 = modified.Value.Item2;
791+
}
792+
}
793+
794+
var result = forward(input1, input2);
795+
796+
// Call post-hooks, if available.
797+
798+
foreach (var hook in post_hooks.Values) {
799+
var modified = hook(this, input1, input2, result);
800+
if (modified is not null)
801+
result = modified;
802+
}
803+
804+
return result;
668805
}
806+
807+
public TResult forward(T1 input1, T2 input2) => (TResult)base.forward(input1, input2);
669808
}
670809

671810
/// <summary>

test/TorchSharpTest/TestJIT.cs

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,27 @@ namespace TorchSharp
1414
public class TestJIT
1515
{
1616
#if false
17+
[Fact]
18+
public void TestLoadJIT_1()
19+
{
20+
var input = torch.ones(10);
21+
var expected = torch.tensor(new float[] { 0.313458264f, 0, 0.9996568f, 0, 0, 0 });
22+
23+
// One linear layer followed by ReLU.
24+
var m = torch.jit.load<Tensor, Tensor>(@"linrelu.script.dat");
25+
if (torch.cuda.is_available()) {
26+
m = m.to(torch.CUDA);
27+
input = input.to(torch.CUDA);
28+
expected = expected.to(torch.CUDA);
29+
}
30+
31+
var t = m.forward(input);
32+
33+
Assert.Equal(new long[] { 6 }, t.shape);
34+
Assert.Equal(torch.float32, t.dtype);
35+
Assert.True(expected.allclose(t));
36+
}
37+
1738
[Fact]
1839
public void TestLoadJIT_Func()
1940
{
@@ -34,7 +55,7 @@ public void TestLoadJIT_Func()
3455
}
3556

3657
[Fact]
37-
public void TestLoadJIT_1()
58+
public void TestLoadJIT_5()
3859
{
3960
var input = torch.ones(10);
4061
var expected = torch.tensor(new float[] { 0.313458264f, 0, 0.9996568f, 0, 0, 0 });
@@ -54,6 +75,39 @@ public void TestLoadJIT_1()
5475
Assert.True(expected.allclose(t));
5576
}
5677

78+
[Fact]
79+
public void TestLoadJIT_6()
80+
{
81+
var input = torch.ones(10);
82+
var expected = torch.tensor(new float[] { 0.313458264f, 0, 0.9996568f, 0, 0, 0 });
83+
84+
// One linear layer followed by ReLU.
85+
var m = torch.jit.load<Tensor, Tensor>(@"linrelu.script.dat");
86+
if (torch.cuda.is_available()) {
87+
m = m.to(torch.CUDA);
88+
input = input.to(torch.CUDA);
89+
expected = expected.to(torch.CUDA);
90+
}
91+
92+
int i = 0;
93+
m.register_forward_pre_hook((m,t) => { i += 1; return t; });
94+
m.register_forward_pre_hook((m, t) => { i += 2; return t; });
95+
m.register_forward_hook((m, t1, t2) => { i += 4; return t2; });
96+
m.register_forward_hook((m, t1, t2) => { i += 8; return t2; });
97+
98+
var t = m.forward(input);
99+
100+
Assert.Equal(0, i);
101+
102+
t = m.call(input);
103+
104+
Assert.Equal(15, i);
105+
106+
Assert.Equal(new long[] { 6 }, t.shape);
107+
Assert.Equal(torch.float32, t.dtype);
108+
Assert.True(expected.allclose(t));
109+
}
110+
57111
[Fact]
58112
public void TestSaveJIT()
59113
{

0 commit comments

Comments
 (0)