From 3e417ed27c18f6db2830030c40cbe0eaf6c77eea Mon Sep 17 00:00:00 2001 From: Guillaume Ausset Date: Thu, 18 Mar 2021 11:38:22 +0100 Subject: [PATCH] Adapt wrappers for batch wrappers --- Project.toml | 1 + src/batched/batchedadjtrans.jl | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/Project.toml b/Project.toml index 014cd1969..d8871dad2 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,7 @@ uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" version = "0.7.16" [deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/batched/batchedadjtrans.jl b/src/batched/batchedadjtrans.jl index 9623d33d5..f9cc8773d 100644 --- a/src/batched/batchedadjtrans.jl +++ b/src/batched/batchedadjtrans.jl @@ -1,6 +1,7 @@ using LinearAlgebra import Base: - +import Adapt: adapt_structure, adapt _batched_doc = """ batched_transpose(A::AbstractArray{T,3}) @@ -100,3 +101,6 @@ function rrule(::typeof(batched_adjoint), A::AbstractArray{<:Any,3}) b_adjoint_back(Δ) = (NO_FIELDS, batched_adjoint(Δ)) batched_adjoint(A), b_adjoint_back end + +adapt_structure(to, x::BatchedAdjoint) = BatchedAdjoint(adapt(to, parent(x))) +adapt_structure(to, x::BatchedTranspose) = BatchedTranspose(adapt(to, parent(x))) \ No newline at end of file