Skip to content

Commit

Permalink
Correctly name complex data types. (#1091)
Browse files Browse the repository at this point in the history
  • Loading branch information
dfalbel authored Aug 16, 2023
1 parent 0abadc2 commit 64d8f0e
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 3 deletions.
4 changes: 4 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,10 @@ cpp_torch_qint32 <- function() {
.Call(`_torch_cpp_torch_qint32`)
}

cpp_torch_chalf <- function() {
.Call(`_torch_cpp_torch_chalf`)
}

cpp_torch_cfloat <- function() {
.Call(`_torch_cpp_torch_cfloat`)
}
Expand Down
17 changes: 14 additions & 3 deletions R/dtype.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ dtype_from_string <- function(str) {
"bool" = torch_bool(),
"quint8" = torch_quint8(),
"qint8" = torch_qint8(),
"qint32" = torch_qint32()
"qint32" = torch_qint32(),
"chalf" = torch_chalf(),
"cfloat" = torch_cfloat(),
"cdouble" = torch_cdouble()
)
}

Expand Down Expand Up @@ -79,18 +82,26 @@ torch_float64 <- function() torch_dtype$new(cpp_torch_float64())
torch_double <- function() torch_dtype$new(cpp_torch_float64())


#' @rdname torch_dtype
#' @export
torch_cfloat32 <- function() torch_dtype$new(cpp_torch_chalf())
#' @rdname torch_dtype
#' @export
torch_chalf <- function() torch_dtype$new(cpp_torch_chalf())

#' @rdname torch_dtype
#' @export
torch_cfloat <- function() torch_dtype$new(cpp_torch_cfloat())
#' @rdname torch_dtype
#' @export
torch_cfloat32 <- function() torch_dtype$new(cpp_torch_cfloat())
torch_cfloat64 <- function() torch_dtype$new(cpp_torch_cfloat())

#' @rdname torch_dtype
#' @export
torch_cdouble <- function() torch_dtype$new(cpp_torch_cdouble())
#' @rdname torch_dtype
#' @export
torch_cfloat64 <- function() torch_dtype$new(cpp_torch_cdouble())
torch_cfloat128 <- function() torch_dtype$new(cpp_torch_cdouble())

#' @rdname torch_dtype
#' @export
Expand Down
3 changes: 3 additions & 0 deletions inst/include/lantern/lantern.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@ LANTERN_OPTIONAL_DECLS(string_view)
HOST_API void * lantern_TensorOptions_pinned_memory(void *self, bool pinned_memory) {LANTERN_CHECK_LOADED void * ret = _lantern_TensorOptions_pinned_memory(self, pinned_memory); LANTERN_HOST_HANDLER return ret;}
LANTERN_API void *(LANTERN_PTR _lantern_Dtype_float32)();
HOST_API void * lantern_Dtype_float32() {LANTERN_CHECK_LOADED void * ret = _lantern_Dtype_float32(); LANTERN_HOST_HANDLER return ret;}
LANTERN_API void *(LANTERN_PTR _lantern_Dtype_chalf)();
HOST_API void * lantern_Dtype_chalf() {LANTERN_CHECK_LOADED void * ret = _lantern_Dtype_chalf(); LANTERN_HOST_HANDLER return ret;}
LANTERN_API void *(LANTERN_PTR _lantern_Dtype_cfloat)();
HOST_API void * lantern_Dtype_cfloat() {LANTERN_CHECK_LOADED void * ret = _lantern_Dtype_cfloat(); LANTERN_HOST_HANDLER return ret;}
LANTERN_API void *(LANTERN_PTR _lantern_Dtype_cdouble)();
Expand Down Expand Up @@ -9813,6 +9815,7 @@ bool lanternInit(const std::string &libPath, std::string *pError)
LOAD_SYMBOL(_lantern_Dtype_int16);
LOAD_SYMBOL(_lantern_Dtype_int32);
LOAD_SYMBOL(_lantern_Dtype_int64);
LOAD_SYMBOL(_lantern_Dtype_chalf);
LOAD_SYMBOL(_lantern_Dtype_cfloat);
LOAD_SYMBOL(_lantern_Dtype_cdouble);
LOAD_SYMBOL(_lantern_Dtype_bool);
Expand Down
11 changes: 11 additions & 0 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -976,6 +976,16 @@ BEGIN_RCPP
return rcpp_result_gen;
END_RCPP
}
// cpp_torch_chalf
torch::Dtype cpp_torch_chalf();
RcppExport SEXP _torch_cpp_torch_chalf() {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
rcpp_result_gen = Rcpp::wrap(cpp_torch_chalf());
return rcpp_result_gen;
END_RCPP
}
// cpp_torch_cfloat
torch::Dtype cpp_torch_cfloat();
RcppExport SEXP _torch_cpp_torch_cfloat() {
Expand Down Expand Up @@ -49151,6 +49161,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_torch_cpp_torch_quint8", (DL_FUNC) &_torch_cpp_torch_quint8, 0},
{"_torch_cpp_torch_qint8", (DL_FUNC) &_torch_cpp_torch_qint8, 0},
{"_torch_cpp_torch_qint32", (DL_FUNC) &_torch_cpp_torch_qint32, 0},
{"_torch_cpp_torch_chalf", (DL_FUNC) &_torch_cpp_torch_chalf, 0},
{"_torch_cpp_torch_cfloat", (DL_FUNC) &_torch_cpp_torch_cfloat, 0},
{"_torch_cpp_torch_cdouble", (DL_FUNC) &_torch_cpp_torch_cdouble, 0},
{"_torch_cpp_set_default_dtype", (DL_FUNC) &_torch_cpp_set_default_dtype, 1},
Expand Down
5 changes: 5 additions & 0 deletions src/dtype.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ XPtrTorchDtype cpp_torch_qint32() {
return XPtrTorchDtype(lantern_Dtype_qint32());
}

// [[Rcpp::export]]
torch::Dtype cpp_torch_chalf() {
return torch::Dtype(lantern_Dtype_chalf());
}

// [[Rcpp::export]]
torch::Dtype cpp_torch_cfloat() {
return torch::Dtype(lantern_Dtype_cfloat());
Expand Down
3 changes: 3 additions & 0 deletions src/lantern/include/lantern/lantern.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@ LANTERN_OPTIONAL_DECLS(string_view)
HOST_API void * lantern_TensorOptions_pinned_memory(void *self, bool pinned_memory) {LANTERN_CHECK_LOADED void * ret = _lantern_TensorOptions_pinned_memory(self, pinned_memory); LANTERN_HOST_HANDLER return ret;}
LANTERN_API void *(LANTERN_PTR _lantern_Dtype_float32)();
HOST_API void * lantern_Dtype_float32() {LANTERN_CHECK_LOADED void * ret = _lantern_Dtype_float32(); LANTERN_HOST_HANDLER return ret;}
LANTERN_API void *(LANTERN_PTR _lantern_Dtype_chalf)();
HOST_API void * lantern_Dtype_chalf() {LANTERN_CHECK_LOADED void * ret = _lantern_Dtype_chalf(); LANTERN_HOST_HANDLER return ret;}
LANTERN_API void *(LANTERN_PTR _lantern_Dtype_cfloat)();
HOST_API void * lantern_Dtype_cfloat() {LANTERN_CHECK_LOADED void * ret = _lantern_Dtype_cfloat(); LANTERN_HOST_HANDLER return ret;}
LANTERN_API void *(LANTERN_PTR _lantern_Dtype_cdouble)();
Expand Down Expand Up @@ -9813,6 +9815,7 @@ bool lanternInit(const std::string &libPath, std::string *pError)
LOAD_SYMBOL(_lantern_Dtype_int16);
LOAD_SYMBOL(_lantern_Dtype_int32);
LOAD_SYMBOL(_lantern_Dtype_int64);
LOAD_SYMBOL(_lantern_Dtype_chalf);
LOAD_SYMBOL(_lantern_Dtype_cfloat);
LOAD_SYMBOL(_lantern_Dtype_cdouble);
LOAD_SYMBOL(_lantern_Dtype_bool);
Expand Down
1 change: 1 addition & 0 deletions src/lantern/src/Dtype.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ LANTERN_DTYPE_FUN(bool, kBool)
LANTERN_DTYPE_FUN(quint8, kQUInt8)
LANTERN_DTYPE_FUN(qint8, kQInt8)
LANTERN_DTYPE_FUN(qint32, kQInt32)
LANTERN_DTYPE_FUN(chalf, kComplexHalf)
LANTERN_DTYPE_FUN(cfloat, kComplexFloat)
LANTERN_DTYPE_FUN(cdouble, kComplexDouble)
LANTERN_DTYPE_FUN(byte, kByte)
Expand Down

0 comments on commit 64d8f0e

Please sign in to comment.