Skip to content

Commit

Permalink
fix: device printer (#1081)
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer authored Jul 14, 2023
1 parent 7b33caa commit af7278f
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 11 deletions.
15 changes: 8 additions & 7 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# torch (development version)

- New `torch_save` serialization format. It's ~10x faster and since it's based on safetensors, files can be read with any safetensors implementation. (#1071)
- fix printer of torch device (add new line at the end)
- `as.array` now moves tensors to the cpu before copying data into R. (#1080)

# torch 0.11.0
Expand Down Expand Up @@ -322,7 +323,7 @@ everytime backward is called. (#873)
- Added `distr_categorical()` (#576)
- Added `distr_mixture_same_family()` (#576)
- Improve handling of optimizers state and implement `load_state_dict()` and `state_dict()` for optimizers. (#585)
- Added the ability to save R `list`s containing `torch_tensor`s using `torch_save`. This allows us to save the state of optimizers and modules using `torch_save()`. (#586)
- Added the ability to save R `list`s containing `torch_tensor`s using `torch_save`. This allows us to save the state of optimizers and modules using `torch_save()`. (#586)

## Bug fixes

Expand Down Expand Up @@ -394,10 +395,10 @@ everytime backward is called. (#873)

- Removed the PerformanceReporter from tests to get easier to read stack traces. (#449)
- Internal change in the R7 classes so R7 objects are simple external pointer instead of environments. This might cause breaking change if you relied on saving any kind of state in the Tensor object. (#452)
- Internal refactoring making Rcpp aware of some XPtrTorch* types so making it simpler to return them from Rcpp code. This might cause a breaking change if you are relying on `torch_dtype()` being an R6 class. (#451)
- Internal refactoring making Rcpp aware of some XPtrTorch* types so making it simpler to return them from Rcpp code. This might cause a breaking change if you are relying on `torch_dtype()` being an R6 class. (#451)
- Internal changes to auto unwrap arguments from SEXP's in Rcpp. This will make easier to move the dispatcher system to C++ in the future, but already allows us to gain ~30% speedups in small operations. (#454)
- Added a Windows GPU CI workflow (#508).
- Update to LibTorch v1.8 (#513)
- Update to LibTorch v1.8 (#513)
- Moved some parts of the dispatcher to C++ to make it faster. (#520)

# torch 0.2.1
Expand Down Expand Up @@ -483,13 +484,13 @@ everytime backward is called. (#873)
- Fixed bug that made `RandomSampler(replacement = TRUE)` to never take the last
element in the dataset. (84861fa)
- Fixed `torch_topk` and `x$topk` so the returned indexes are 1-based (#280)
- Fixed a bug (#275) that would cause `1 - torch_tensor(1, device = "cuda")` to
- Fixed a bug (#275) that would cause `1 - torch_tensor(1, device = "cuda")` to
fail because `1` was created in the CPU. (#279)
- We now preserve names in the `dataloader` output (#286)
- `torch_narrow`, `Tensor$narrow()` and `Tensor$narrow_copy` are now indexed
- `torch_narrow`, `Tensor$narrow()` and `Tensor$narrow_copy` are now indexed
starting at 1. (#294)
- `Tensor$is_leaf` is now an active method. (#295)
- Fixed bug when passing equations to `torch_einsum`. (#296)
- Fixed bug when passing equations to `torch_einsum`. (#296)
- Fixed `nn_module_list()` to correctly name added modules, otherwise they are not
returned when doing `state_dict()` on it. (#300)
- Fixed bug related to random number seeds when using in-place methods. (#303)
Expand All @@ -501,7 +502,7 @@ everytime backward is called. (#873)
## New features

- Expanded the `utils_data_default_collate` to support converting R objects to
torch tensors when needed. (#269)
torch tensors when needed. (#269)
- Added an `as.matrix` method for torch Tensors. (#282)
- By default we now truncate the output of `print(totrch_tensor(1:40))` if it
spans for more than 30 lines. This is useful for not spamming the console or
Expand Down
4 changes: 2 additions & 2 deletions R/device.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ Device <- R7Class(
s <- paste0(s, ", index=", self$index, ")")
}

cat(s)
cat(s, "\n")
}
),
active = list(
Expand Down Expand Up @@ -125,7 +125,7 @@ is_meta_device <- function(x) {
local_device <- function(device, ..., .env = parent.frame()) {
current_device <- cpp_get_current_default_device()
cpp_set_default_device(device)

withr::defer({
cpp_set_default_device(current_device)
}, envir = .env)
Expand Down
8 changes: 6 additions & 2 deletions tests/testthat/test-device.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ test_that("can print meta tensors", {
})

test_that("can modify the device temporarily", {

z <- torch_randn(10, 10)
with_device(device = "meta", {
x <- torch_randn(10, 10)
Expand All @@ -69,10 +69,14 @@ test_that("can modify the device temporarily", {
b <- torch_randn(10, 10)
})
y <- torch_randn(10, 10)

expect_equal(x$device$type, "meta")
expect_equal(y$device$type, "cpu")
expect_equal(z$device$type, "cpu")
expect_equal(a$device$type, "cpu")
expect_equal(b$device$type, "meta")
})

test_that("printer works", {
expect_equal(capture.output(torch_device("cpu")), "torch_device(type='cpu')>\n")
})

0 comments on commit af7278f

Please sign in to comment.