Skip to content

Commit

Permalink
recover lost functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
dsyme committed May 15, 2020
1 parent d6eb784 commit 32b9559
Show file tree
Hide file tree
Showing 10 changed files with 106 additions and 50 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -340,3 +340,4 @@ ASALocalRun/
/src/Examples/Data/data_batch_3.bin
/src/Examples/Data/data_batch_2.bin
/src/Examples/Data/data_batch_1.bin
packages/
2 changes: 1 addition & 1 deletion Directory.Build.targets
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
BeforeTargets="PrepareForRun">

<PropertyGroup>
<LibPrefix Condition="'$(OS)' != 'Windows_NT' OR '$(AssumeOS)' != 'windows'">lib</LibPrefix>
<LibPrefix Condition="'$(OS)' != 'Windows_NT' OR '$(AssumeOS)' == 'linux' OR '$(AssumeOS)' == 'macos'">lib</LibPrefix>
<LibExtension Condition="'$(OS)' == 'Windows_NT' OR '$(AssumeOS)' == 'windows'">.dll</LibExtension>
<LibExtension Condition="'$(OS)' != 'Windows_NT' OR '$(AssumeOS)' == 'linux'">.so</LibExtension>
<LibExtension Condition="$([MSBuild]::IsOSPlatform('osx')) OR '$(AssumeOS)' == 'mac'">.dylib</LibExtension>
Expand Down
31 changes: 14 additions & 17 deletions TorchSharp.sln
Original file line number Diff line number Diff line change
Expand Up @@ -94,23 +94,20 @@ Global
{42B45168-476D-4BFA-87B8-81A34E6295CD}.Release|Any CPU.Build.0 = Release|Any CPU
{42B45168-476D-4BFA-87B8-81A34E6295CD}.Release|x64.ActiveCfg = Release|Any CPU
{42B45168-476D-4BFA-87B8-81A34E6295CD}.Release|x64.Build.0 = Release|Any CPU
{42B45168-476D-4BFA-87B8-81A34E6295CD}.RelWithDebInfo|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU
{42B45168-476D-4BFA-87B8-81A34E6295CD}.RelWithDebInfo|Any CPU.Build.0 = Release-Intrinsics|Any CPU
{42B45168-476D-4BFA-87B8-81A34E6295CD}.RelWithDebInfo|x64.ActiveCfg = Release-Intrinsics|Any CPU
{42B45168-476D-4BFA-87B8-81A34E6295CD}.RelWithDebInfo|x64.Build.0 = Release-Intrinsics|Any CPU
{C52CFB19-5D0D-3D9C-9259-5D5A287177C2}.Debug|Any CPU.ActiveCfg = Debug|x64
{C52CFB19-5D0D-3D9C-9259-5D5A287177C2}.Debug|x64.ActiveCfg = Debug|x64
{C52CFB19-5D0D-3D9C-9259-5D5A287177C2}.Debug|x64.Build.0 = Debug|x64
{C52CFB19-5D0D-3D9C-9259-5D5A287177C2}.MinSizeRel|Any CPU.ActiveCfg = MinSizeRel|x64
{C52CFB19-5D0D-3D9C-9259-5D5A287177C2}.MinSizeRel|x64.ActiveCfg = MinSizeRel|x64
{C52CFB19-5D0D-3D9C-9259-5D5A287177C2}.MinSizeRel|x64.Build.0 = MinSizeRel|x64
{C52CFB19-5D0D-3D9C-9259-5D5A287177C2}.Release|Any CPU.ActiveCfg = Release|x64
{C52CFB19-5D0D-3D9C-9259-5D5A287177C2}.Release|x64.ActiveCfg = Release|x64
{C52CFB19-5D0D-3D9C-9259-5D5A287177C2}.Release|x64.Build.0 = Release|x64
{C52CFB19-5D0D-3D9C-9259-5D5A287177C2}.RelWithDebInfo|Any CPU.ActiveCfg = RelWithDebInfo|x64
{C52CFB19-5D0D-3D9C-9259-5D5A287177C2}.RelWithDebInfo|x64.ActiveCfg = RelWithDebInfo|x64
{C52CFB19-5D0D-3D9C-9259-5D5A287177C2}.RelWithDebInfo|x64.Build.0 = RelWithDebInfo|x64
EndGlobalSection
{42B45168-476D-4BFA-87B8-81A34E6295CD}.RelWithDebInfo|Any CPU.ActiveCfg = Release|Any CPU
{42B45168-476D-4BFA-87B8-81A34E6295CD}.RelWithDebInfo|Any CPU.Build.0 = Release|Any CPU
{42B45168-476D-4BFA-87B8-81A34E6295CD}.RelWithDebInfo|x64.ActiveCfg = Release|Any CPU
{42B45168-476D-4BFA-87B8-81A34E6295CD}.RelWithDebInfo|x64.Build.0 = Release|Any CPU
{912E4543-FB5A-3172-BAFE-B3F15EE0A723}.Debug|Any CPU.ActiveCfg = Debug|x64
{912E4543-FB5A-3172-BAFE-B3F15EE0A723}.Debug|x64.ActiveCfg = Debug|x64
{912E4543-FB5A-3172-BAFE-B3F15EE0A723}.Debug|x64.Build.0 = Debug|x64
{912E4543-FB5A-3172-BAFE-B3F15EE0A723}.Release|Any CPU.ActiveCfg = Release|x64
{912E4543-FB5A-3172-BAFE-B3F15EE0A723}.Release|x64.ActiveCfg = Release|x64
{912E4543-FB5A-3172-BAFE-B3F15EE0A723}.Release|x64.Build.0 = Release|x64
{912E4543-FB5A-3172-BAFE-B3F15EE0A723}.RelWithDebInfo|Any CPU.ActiveCfg = RelWithDebInfo|x64
{912E4543-FB5A-3172-BAFE-B3F15EE0A723}.RelWithDebInfo|x64.ActiveCfg = RelWithDebInfo|x64
{912E4543-FB5A-3172-BAFE-B3F15EE0A723}.RelWithDebInfo|x64.Build.0 = RelWithDebInfo|x64
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
EndGlobalSection
Expand Down
56 changes: 55 additions & 1 deletion src/Native/LibTorchSharp/THSJIT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,14 @@ long THSJIT_getNumModules(const JITModule module)
return (*module)->get_modules().size();
}

JITModule THSJIT_getModuleFromName(const JITModule module, const char* name)
JITModule THSJIT_getSubModule(const JITModule module, const int index)
{
auto m = (*module)->get_modules()[index];

return new std::shared_ptr<torch::jit::script::Module>(m);
}

JITModule THSJIT_getSubModuleByName(const JITModule module, const char* name)
{
return new std::shared_ptr<torch::jit::script::Module>((*module)->get_module(name));
}
Expand Down Expand Up @@ -48,6 +55,53 @@ JITType THSJIT_getOutputType(const JITModule module, const int n)
return new std::shared_ptr<c10::Type>(type);
}

void* THSJIT_typeCast(const JITType type)
{
switch ((*type)->kind())
{
case c10::TypeKind::TensorType:
return new std::shared_ptr<torch::jit::TensorType>((*type)->cast<c10::TensorType>());
case c10::TypeKind::DimensionedTensorType:
return new std::shared_ptr<torch::jit::DimensionedTensorType>((*type)->cast<c10::DimensionedTensorType>());
default:
return NULL;
}
}

int8_t THSJIT_typeKind(const JITType type)
{
switch ((*type)->kind())
{
case c10::TypeKind::TensorType:
return (int8_t)TypeKind::TensorType;
case c10::TypeKind::DimensionedTensorType:
return (int8_t)TypeKind::DimensionedTensorType;
default:
return -1;
}
}

int8_t THSJIT_getScalarFromDimensionedTensorType(const JITDimensionedTensorType type)
{
return (int8_t)(*type)->scalarType();
}

int THSJIT_getDimensionedTensorTypeDimensions(const JITDimensionedTensorType type)
{
return (*type)->dim();
}

const char* THSJIT_getDimensionedTensorDevice(const JITDimensionedTensorType type)
{
auto device = (*type)->device();

auto device_type = DeviceTypeName(device.type());

std::transform(device_type.begin(), device_type.end(), device_type.begin(), ::tolower);

return make_sharable_string(device_type);
}

Tensor THSJIT_forward(const JITModule module, const Tensor* tensorPtrs, const int length)
{
return new torch::Tensor((*module)->forward(toTensors<c10::IValue>((torch::Tensor**)tensorPtrs, length)).toTensor());
Expand Down
20 changes: 19 additions & 1 deletion src/Native/LibTorchSharp/THSJIT.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@ EXPORT_API(JITModule) THSJIT_loadModule(const char* filename);
// Gets the number of submodules contained into the source module.
EXPORT_API(long) THSJIT_getNumModules(const JITModule module);

// Gets the sub-module contained into the input wrapper at the given index.
EXPORT_API(JITModule) THSJIT_getSubModule(const JITModule module, const int index);

// Gets the sub-module contained into the input wrapper with the given name.
EXPORT_API(JITModule) THSJIT_getModuleFromName(const JITModule module, const char* name);
EXPORT_API(JITModule) THSJIT_getSubModuleByName(const JITModule module, const char* name);

// Returns the number of inputs expected by the input module.
EXPORT_API(int) THSJIT_getNumberOfInputs(const JITModule module);
Expand All @@ -37,6 +40,21 @@ EXPORT_API(JITType) THSJIT_getInputType(const JITModule module, const int n);
// Returns the type of the nth-output.
EXPORT_API(JITType) THSJIT_getOutputType(const JITModule module, const int n);

// Cast the input type to the proper type.
EXPORT_API(void*) THSJIT_typeCast(const JITType type);

// Returns the int8_t code for the input type.
EXPORT_API(int8_t) THSJIT_typeKind(const JITType ttype);

// Returns the int8_t code for the raw type of the tensor.
EXPORT_API(int8_t) THSJIT_getScalarFromDimensionedTensorType(const JITDimensionedTensorType type);

// Gets the number of dimensions of the input tensor type.
EXPORT_API(int) THSJIT_getDimensionedTensorTypeDimensions(const JITDimensionedTensorType type);

// Gets the number of device of the input tensor type.
EXPORT_API(const char*) THSJIT_getDimensionedTensorDevice(const JITDimensionedTensorType type);

// Forward pass over the input module using the input tensor.
EXPORT_API(Tensor) THSJIT_forward(const JITModule module, const Tensor * tensorPtrs, const int length);

Expand Down
2 changes: 2 additions & 0 deletions src/Native/LibTorchSharp/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ typedef std::shared_ptr<torch::nn::Module> * NNModule;
typedef std::shared_ptr<torch::optim::Optimizer> * Optimizer;
typedef std::shared_ptr<torch::jit::script::Module> * JITModule;
typedef std::shared_ptr<c10::Type> * JITType;
typedef std::shared_ptr<torch::jit::TensorType>* JITTensorType;
typedef std::shared_ptr<torch::jit::DimensionedTensorType>* JITDimensionedTensorType;

#define THS_API TH_API

Expand Down
2 changes: 1 addition & 1 deletion src/Native/build.proj
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
</PropertyGroup>

<PropertyGroup>
<NativeLibPrefix Condition="'$(OS)' != 'Windows_NT' OR '$(AssumeOS)' != 'windows'">lib</NativeLibPrefix>
<NativeLibPrefix Condition="'$(OS)' != 'Windows_NT' OR '$(AssumeOS)' == 'linux' OR '$(AssumeOS)' == 'macos'">lib</NativeLibPrefix>
<NativeLibExtension Condition="'$(OS)' == 'Windows_NT' OR '$(AssumeOS)' == 'windows'">.dll</NativeLibExtension>
<NativeLibExtension Condition="'$(OS)' != 'Windows_NT' OR '$(AssumeOS)' == 'linux'">.so</NativeLibExtension>
<NativeLibExtension Condition="$([MSBuild]::IsOSPlatform('osx')) OR '$(AssumeOS)' == 'mac'">.dylib</NativeLibExtension>
Expand Down
22 changes: 3 additions & 19 deletions src/TorchSharp/JIT/Module.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) Microsoft Corporation and contributors. All Rights Reserved. See License.txt in the project root for license information.
// Copyright (c) Microsoft Corporation and contributors. All Rights Reserved. See License.txt in the project root for license information.
using System;
using System.Linq;
using System.Runtime.InteropServices;
Expand Down Expand Up @@ -87,22 +87,6 @@ static public Module Load(string filename)
[DllImport("LibTorchSharp")]
private static extern long THSJIT_getNumModules(HType module);

[DllImport("LibTorchSharp")]
private static extern string THSJIT_getModuleName(HType module, int index);

public string[] GetSubModulesNames()
{
var numModules = THSJIT_getNumModules(handle);
string[] result = new string[numModules];

for (int i = 0; i < numModules; i++)
{
result[i] = THSJIT_getModuleName(handle, i);
}

return result;
}

[DllImport("LibTorchSharp")]
private static extern int THSJIT_getNumberOfInputs(HType module);

Expand Down Expand Up @@ -143,11 +127,11 @@ private Type GetType(Type type)
{
switch (type.Kind)
{
case Type.TypeKind.DynamicType:
case Type.TypeKind.TensorType:
var dynamic = type.AsDynamicType();
type.Dispose();
return dynamic;
case Type.TypeKind.TensorType:
case Type.TypeKind.DimensionedTensorType:
var tensor = type.AsTensorType();
type.Dispose();
return tensor;
Expand Down
14 changes: 7 additions & 7 deletions src/TorchSharp/JIT/Type/TensorType.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) Microsoft Corporation and contributors. All Rights Reserved. See License.txt in the project root for license information.
// Copyright (c) Microsoft Corporation and contributors. All Rights Reserved. See License.txt in the project root for license information.
using System;
using System.Runtime.InteropServices;

Expand All @@ -19,27 +19,27 @@ internal TensorType(Type type) : base()
}

[DllImport("LibTorchSharp")]
private static extern short THSJIT_getScalarFromTensorType(HType handle);
private static extern short THSJIT_getScalarFromDimensionedTensorType(HType handle);

public Tensor.ATenScalarMapping GetScalarType()
{
return (Tensor.ATenScalarMapping)THSJIT_getScalarFromTensorType(handle);
return (Tensor.ATenScalarMapping)THSJIT_getScalarFromDimensionedTensorType(handle);
}

[DllImport("LibTorchSharp")]
private static extern int THSJIT_getTensorTypeDimensions(HType handle);
private static extern int THSJIT_getDimensionedTensorTypeDimensions(HType handle);

public int GetDimensions()
{
return THSJIT_getTensorTypeDimensions(handle);
return THSJIT_getDimensionedTensorTypeDimensions(handle);
}

[DllImport("LibTorchSharp")]
private static extern string THSJIT_getTensorDevice(HType handle);
private static extern string THSJIT_getDimensionedTensorDevice(HType handle);

public string GetDevice()
{
return THSJIT_getTensorDevice(handle);
return THSJIT_getDimensionedTensorDevice(handle);
}
}
}
6 changes: 3 additions & 3 deletions src/TorchSharp/JIT/Type/Type.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) Microsoft Corporation and contributors. All Rights Reserved. See License.txt in the project root for license information.
// Copyright (c) Microsoft Corporation and contributors. All Rights Reserved. See License.txt in the project root for license information.
using System;
using System.Runtime.InteropServices;

Expand Down Expand Up @@ -101,8 +101,8 @@ internal DynamicType AsDynamicType()

internal enum TypeKind : sbyte
{
DynamicType = 0,
TensorType = 1
TensorType = 0,
DimensionedTensorType = 1
}
}
}

0 comments on commit 32b9559

Please sign in to comment.