@@ -220,8 +220,8 @@ private Type GetType(Type type)
220
220
/// <summary>
221
221
/// Invoke the 'forward' function of the script with any number of arguments.
222
222
/// </summary>
223
- /// <param name="objs"> </param>
224
- /// <returns></returns>
223
+ /// <param name="input">Any number of parameters for the forward function. </param>
224
+ /// <returns>An object. </returns>
225
225
/// <remarks>
226
226
/// Only certain types can currently be passed:
227
227
/// 1. Tensor
@@ -238,15 +238,14 @@ private Type GetType(Type type)
238
238
/// For returned types, if the number of values returned in a tuple is greaterh than 5, it is returned as an array, instead.
239
239
/// If a tuple contains both tensors and scalars, it is returned as an object[].
240
240
/// </remarks>
241
- /// <exception cref="NotImplementedException"></exception>
242
- public object call ( params object [ ] objs )
241
+ public object forward ( params object [ ] input )
243
242
{
244
243
TensorOrScalar [ ] ptrArray = null ;
245
244
sbyte typeCode = 0 ;
246
245
247
246
using ( var parray = new IndexedPinnedArrays < TensorOrScalar > ( ) ) {
248
247
249
- var tRefsHandle = DetermineArgumentTypeRefs ( objs , out var count , parray ) ;
248
+ var tRefsHandle = DetermineArgumentTypeRefs ( input , out var count , parray ) ;
250
249
251
250
var allocated = parray . Count ;
252
251
@@ -258,6 +257,40 @@ public object call(params object[] objs)
258
257
}
259
258
}
260
259
260
+ /// <summary>
261
+ /// Synonym for `forward`
262
+ /// </summary>
263
+ /// <param name="input">Any number of parameters for the forward function.</param>
264
+ /// <returns>An object.</returns>
265
+ /// <remarks>
266
+ /// Only certain types can currently be passed:
267
+ /// 1. Tensor
268
+ /// 2. Scalar
269
+ /// 3. int/long
270
+ /// 4. double/float
271
+ /// 5. bool
272
+ ///
273
+ /// Only certain types can currently be returned:
274
+ /// 1. Tensor / Scalar
275
+ /// 2. Tuple of Tensor / Scalar
276
+ /// 3. Array (Python list) of Tensor / Scalar
277
+ ///
278
+ /// For returned types, if the number of values returned in a tuple is greaterh than 5, it is returned as an array, instead.
279
+ /// If a tuple contains both tensors and scalars, it is returned as an object[].
280
+ ///
281
+ /// Note: this currently does not support hooking the module.
282
+ /// </remarks>
283
+ public object call ( params object [ ] input )
284
+ {
285
+ // TODO: Call pre-hooks, if available.
286
+
287
+ var result = forward ( input ) ;
288
+
289
+ // TODO: Call post-hooks, if available.
290
+
291
+ return result ;
292
+ }
293
+
261
294
/// <summary>
262
295
/// Invoke a function from the script module.
263
296
/// </summary>
@@ -437,7 +470,7 @@ internal static object ProcessReturnValue(string name, IndexedPinnedArrays<Tenso
437
470
// List of scalars and tensors
438
471
var result = new object [ ptrArray . Length ] ;
439
472
for ( var i = 0 ; i < ptrArray . Length ; i ++ ) {
440
- switch ( ptrArray [ i ] . TypeCode ) {
473
+ switch ( ptrArray [ i ] . TypeCode ) {
441
474
case 0 :
442
475
result [ i ] = new Tensor ( ptrArray [ i ] . Handle ) ;
443
476
break ;
@@ -566,6 +599,62 @@ internal static object ProcessReturnValue(string name, IndexedPinnedArrays<Tenso
566
599
public TResult invoke < T , TResult > ( string name , params T [ ] inputs ) => ( TResult ) invoke ( name , inputs ) ;
567
600
}
568
601
602
+
603
+
604
+ /// <summary>
605
+ /// Represents a module that accepts 'hook' to the module logic.
606
+ /// </summary>
607
+ public class HookableScriptModule < TPreHook , TPostHook > : ScriptModule
608
+ {
609
+ internal HookableScriptModule ( IntPtr handle ) : base ( handle )
610
+ {
611
+ }
612
+
613
+ public HookRemover register_forward_hook ( TPostHook hook )
614
+ {
615
+ var key = Guid . NewGuid ( ) . ToString ( ) ;
616
+ post_hooks . Add ( key , hook ) ;
617
+ return new HookRemover ( this , key ) ;
618
+ }
619
+
620
+ public HookRemover register_forward_pre_hook ( TPreHook hook )
621
+ {
622
+ var key = Guid . NewGuid ( ) . ToString ( ) ;
623
+ pre_hooks . Add ( key , hook ) ;
624
+ return new HookRemover ( this , key ) ;
625
+ }
626
+
627
+ private void remove ( string key )
628
+ {
629
+ if ( pre_hooks . ContainsKey ( key ) ) pre_hooks . Remove ( key ) ;
630
+ if ( post_hooks . ContainsKey ( key ) ) post_hooks . Remove ( key ) ;
631
+ }
632
+
633
+ protected Dictionary < string , TPreHook > pre_hooks = new Dictionary < string , TPreHook > ( ) ;
634
+ protected Dictionary < string , TPostHook > post_hooks = new Dictionary < string , TPostHook > ( ) ;
635
+
636
+ /// <summary>
637
+ /// Used to remove a specific hook, following the PyTorch API design.
638
+ /// </summary>
639
+ /// <remarks>The name and namespace of this class is not the same as in PyTorch, but serves the same purpose.</remarks>
640
+ public class HookRemover
641
+ {
642
+ public HookRemover ( HookableScriptModule < TPreHook , TPostHook > module , string key )
643
+ {
644
+ this . module = module ;
645
+ this . key = key ;
646
+ }
647
+
648
+ public void remove ( )
649
+ {
650
+ module . remove ( key ) ;
651
+ }
652
+
653
+ private HookableScriptModule < TPreHook , TPostHook > module ;
654
+ private string key ;
655
+ }
656
+ }
657
+
569
658
/// <summary>
570
659
/// A script module taking any number of tensors as input
571
660
/// </summary>
@@ -593,18 +682,26 @@ internal ScriptModule(IntPtr handle) : base(handle) { }
593
682
/// For returned types, if the number of values returned in a tuple is greaterh than 5, it is returned as an array, instead.
594
683
/// If a tuple contains both tensors and scalars, it is returned as an object[].
595
684
/// </remarks>
596
- public TResult call ( params Tensor [ ] tensor )
685
+ public TResult call ( params Tensor [ ] input )
597
686
{
598
- return ( TResult ) base . call ( tensor ) ;
687
+ // TODO: Call pre-hooks, if available.
688
+
689
+ var result = forward ( input ) ;
690
+
691
+ // TODO: Call post-hooks, if available.
692
+
693
+ return result ;
599
694
}
695
+
696
+ public TResult forward ( params Tensor [ ] tensor ) => ( TResult ) base . forward ( tensor ) ;
600
697
}
601
698
602
699
/// <summary>
603
700
/// A script module taking a single argument.
604
701
/// </summary>
605
702
/// <typeparam name="T">The argument type.</typeparam>
606
703
/// <typeparam name="TResult">The return type of the module.</typeparam>
607
- public class ScriptModule < T , TResult > : ScriptModule , torch . nn . IModule < T , TResult >
704
+ public class ScriptModule < T , TResult > : HookableScriptModule < Func < ScriptModule < T , TResult > , T , T > , Func < ScriptModule < T , TResult > , T , TResult , TResult > > , torch . nn . IModule < T , TResult >
608
705
{
609
706
internal ScriptModule ( IntPtr handle ) : base ( handle ) { }
610
707
@@ -627,10 +724,30 @@ internal ScriptModule(IntPtr handle) : base(handle) { }
627
724
/// For returned types, if the number of values returned in a tuple is greaterh than 5, it is returned as an array, instead.
628
725
/// If a tuple contains both tensors and scalars, it is returned as an object[].
629
726
/// </remarks>
630
- public TResult call ( T tensor )
727
+ public TResult call ( T input )
631
728
{
632
- return ( TResult ) base . call ( tensor ) ;
729
+ // Call pre-hooks, if available.
730
+
731
+ foreach ( var hook in pre_hooks . Values ) {
732
+ var modified = hook ( this , input ) ;
733
+ if ( modified is not null )
734
+ input = modified ;
735
+ }
736
+
737
+ var result = forward ( input ) ;
738
+
739
+ // Call post-hooks, if available.
740
+
741
+ foreach ( var hook in post_hooks . Values ) {
742
+ var modified = hook ( this , input , result ) ;
743
+ if ( modified is not null )
744
+ result = modified ;
745
+ }
746
+
747
+ return result ;
633
748
}
749
+
750
+ public TResult forward ( T tensor ) => ( TResult ) base . forward ( tensor ) ;
634
751
}
635
752
636
753
/// <summary>
@@ -639,7 +756,7 @@ public TResult call(T tensor)
639
756
/// <typeparam name="T1">The first argument type.</typeparam>
640
757
/// <typeparam name="T2">The second argument type.</typeparam>
641
758
/// <typeparam name="TResult">The return type of the module.</typeparam>
642
- public class ScriptModule < T1 , T2 , TResult > : ScriptModule , torch . nn . IModule < T1 , T2 , TResult >
759
+ public class ScriptModule < T1 , T2 , TResult > : HookableScriptModule < Func < ScriptModule < T1 , T2 , TResult > , T1 , T2 , ( T1 , T2 ) ? > , Func < ScriptModule < T1 , T2 , TResult > , T1 , T2 , TResult , TResult > > , torch . nn . IModule < T1 , T2 , TResult >
643
760
{
644
761
internal ScriptModule ( IntPtr handle ) : base ( handle ) { }
645
762
@@ -664,8 +781,30 @@ internal ScriptModule(IntPtr handle) : base(handle) { }
664
781
/// </remarks>
665
782
public TResult call ( T1 input1 , T2 input2 )
666
783
{
667
- return ( TResult ) base . call ( input1 , input2 ) ;
784
+ // Call pre-hooks, if available.
785
+
786
+ foreach ( var hook in pre_hooks . Values ) {
787
+ var modified = hook ( this , input1 , input2 ) ;
788
+ if ( modified . HasValue ) {
789
+ input1 = modified . Value . Item1 ;
790
+ input2 = modified . Value . Item2 ;
791
+ }
792
+ }
793
+
794
+ var result = forward ( input1 , input2 ) ;
795
+
796
+ // Call post-hooks, if available.
797
+
798
+ foreach ( var hook in post_hooks . Values ) {
799
+ var modified = hook ( this , input1 , input2 , result ) ;
800
+ if ( modified is not null )
801
+ result = modified ;
802
+ }
803
+
804
+ return result ;
668
805
}
806
+
807
+ public TResult forward ( T1 input1 , T2 input2 ) => ( TResult ) base . forward ( input1 , input2 ) ;
669
808
}
670
809
671
810
/// <summary>
0 commit comments