Skip to content

Commit

Permalink
Mnist dataset wrapping success! Not sure how to do other datasets wit…
Browse files Browse the repository at this point in the history
  • Loading branch information
mratsim committed Jan 9, 2021
1 parent ec43ff1 commit 45a2687
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 13 deletions.
20 changes: 17 additions & 3 deletions flambeau/raw_bindings/data_api.nim
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@ import
# This should ease searching PyTorch and libtorch documentation,
# and make C++ tutorials easily applicable.

# Headers
# -----------------------------------------------------------------------

{.passC: "-I" & headersPath.}
{.passC: "-I" & torchHeadersPath.}

{.push header: torchHeader.}

# #######################################################################
#
# Datasets
Expand Down Expand Up @@ -47,21 +55,27 @@ type
## with T being Example[Data, Target]
## BatchRequest is by default ArrayRef[csize_t]

# TODO: https://github.com/nim-lang/Nim/issues/16653
# generics + {.inheritable.} doesn't work
# TODO: https://github.com/nim-lang/Nim/issues/16655
# CRTP + importcpp don't work
Dataset*
{.bycopy, pure,
importcpp: "torch::data::datasets::Dataset".}
[Self, Batch]
# [Self, Batch]
= object of BatchDataset # [Self, Batch, ArrayRef[csize_t]]
## A Dataset type
## Self: is the class type that implements the Dataset API
## (using the Curious Recurring Template Pattern in underlying C++)
## Batch is by default the type CppVector[T]
## with T being Example[Data, Target]

# TODO: https://github.com/nim-lang/Nim/issues/16655
# CRTP + importcpp don't work
Mnist*
{.bycopy, pure,
importcpp: "torch::data::datasets::MNIST".}
= object of Dataset[Mnist, CppVector[Example[Tensor, Tensor]]]
= object of Dataset # [Mnist, CppVector[Example[Tensor, Tensor]]]
## The MNIST dataset
## http://yann.lecun.com/exdb/mnist

Expand All @@ -71,7 +85,7 @@ type
kTrain = 0
kTest = 1

func mnist*(rootPath: cstring, mode = kTrain): Mnist {.constructor, importcpp:"MNIST(@)".}
func mnist*(rootPath: cstring, mode = kTrain): Mnist {.constructor, importcpp:"torch::data::datasets::MNIST(@)".}
## Loads the MNIST dataset from the `root` path
## The supplied `rootpath` should contain the *content* of the unzipped
## MNIST dataset, available from http://yann.lecun.com/exdb/mnist.
Expand Down
6 changes: 3 additions & 3 deletions flambeau/raw_bindings/tensors.nim
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ else:
# Headers
# -----------------------------------------------------------------------

const headersPath = libTorchPath & "/include"
const torchHeadersPath = headersPath / "torch/csrc/api/include"
const torchHeader = torchHeadersPath / "torch/torch.h"
const headersPath* = libTorchPath & "/include"
const torchHeadersPath* = headersPath / "torch/csrc/api/include"
const torchHeader* = torchHeadersPath / "torch/torch.h"

{.passC: "-I" & headersPath.}
{.passC: "-I" & torchHeadersPath.}
Expand Down
17 changes: 10 additions & 7 deletions proof_of_concepts/poc07_datasets.nim
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@ import
data_api, tensors
]

let mnist = mnist("build/mnist")
proc main() =
let mnist = mnist("build/mnist")

echo "Data"
# mnist.get(0).data.print()
# echo "\n-----------------------"
# echo "Target"
# mnist.get(0).target.print()
# echo "\n-----------------------"
echo "Data"
mnist.get(0).data.print()
echo "\n-----------------------"
echo "Target"
mnist.get(0).target.print()
echo "\n-----------------------"

main()

0 comments on commit 45a2687

Please sign in to comment.