Skip to content

Commit

Permalink
And the end-to-end example compiles!
Browse files Browse the repository at this point in the history
  • Loading branch information
mratsim committed Jan 10, 2021
1 parent 9e000fa commit 10236e9
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 20 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ build/*
vendor/libtorch
*.zip
nimcache/*
*.pt
7 changes: 5 additions & 2 deletions flambeau/raw_bindings/serialize.nim
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.

import ./tensors, ./neural_nets
import
../cpp/std_cpp,
./tensors,
./neural_nets

# (Almost) raw bindings to PyTorch serialization
# -----------------------------------------------------------------------
Expand All @@ -31,4 +34,4 @@ import ./tensors, ./neural_nets
# #######################################################################
# libtorch/include/torch/csrc/api/include/torch/optim/optimizer.h

proc save*(module: Module, path: cstring){.sideeffect, importcpp:"torch::save(@)".}
proc save*[T](module: CppSharedPtr[T], path: cstring){.sideeffect, importcpp:"torch::save(@)".}
26 changes: 8 additions & 18 deletions proof_of_concepts/poc09_end_to_end.nim
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,14 @@ import
../flambeau,
std/[enumerate, strformat]

# Argh, need Linear{nullptr} in the codegen
# so we cheat by inlining C++
#
# type Net {.pure.} = object of Module
#
# fc1: Linear
# fc2: Linear
# fc3: Linear
# Net is defined in poc_09_end_to_end_types.nim.hpp
# to work around https://github.com/nim-lang/Nim/issues/16664
# which workarounds https://github.com/nim-lang/Nim/issues/4687

{.emit:["""
struct Net: public torch::nn::Module {
torch::nn::Linear fc1{nullptr};
torch::nn::Linear fc2{nullptr};
torch::nn::Linear fc3{nullptr};
};
"""].}

type Net{.pure, importcpp.} = object of Module
type Net
{.pure, importcpp,
header:"poc09_end_to_end_types.nim.hpp".}
= object of Module
fc1: Linear
fc2: Linear
fc3: Linear
Expand Down Expand Up @@ -72,6 +62,6 @@ proc main() =
if batch_index mod 100 == 0:
echo &"Epoch: {epoch} | Batch: {batch_index} | Loss: {loss.item(float32)}"
# Serialize your model periodically as a checkpoint.
net.save("net.pt")
save(net, "net.pt")

main()
22 changes: 22 additions & 0 deletions proof_of_concepts/poc09_end_to_end_types.nim.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// We need Linear{nullptr} in the codegen
// so we would like to cheat by inlining C++
//
// type Net {.pure.} = object of Module
//
// fc1: Linear
// fc2: Linear
// fc3: Linear
//
// https://github.com/nim-lang/Nim/issues/4687
//
// However
// due to https://github.com/nim-lang/Nim/issues/16664
// it needs to be in its own file

#include <torch/torch.h>

struct Net: public torch::nn::Module {
torch::nn::Linear fc1{nullptr};
torch::nn::Linear fc2{nullptr};
torch::nn::Linear fc3{nullptr};
};

0 comments on commit 10236e9

Please sign in to comment.