From 4e7e55e54e9519183b3eda28ac9e26fc7e587b8a Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Thu, 20 Jun 2024 10:35:35 -0400 Subject: [PATCH] Fix expand on GPU (#85) * Fix expand on GPU * Bump to v0.4.7 --- Project.toml | 6 +++++- src/expand.jl | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 1feaf04..1bf431a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,13 +1,15 @@ name = "ITensorTDVP" uuid = "25707e16-a4db-4a07-99d9-4d67b7af0342" authors = ["Matthew Fishman and contributors"] -version = "0.4.6" +version = "0.4.7" [deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ITensors = "9136182c-28ba-11e9-034c-db9fb085ebd5" KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf" PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" @@ -19,9 +21,11 @@ Observers = "338f10d5-c7f1-4033-a7d1-f9dec39bcaa0" ITensorTDVPObserversExt = "Observers" [compat] +Adapt = "3, 4" Compat = "4" ITensors = "0.6.10" KrylovKit = "0.6, 0.7, 0.8" +NDTensors = "0.3.31" Observers = "0.2" PackageExtensionCompat = "1" TimerOutputs = "0.5" diff --git a/src/expand.jl b/src/expand.jl index aabfa51..8a93b63 100644 --- a/src/expand.jl +++ b/src/expand.jl @@ -1,3 +1,4 @@ +using Adapt: adapt using ITensors: ITensors, Algorithm, @@ -15,6 +16,7 @@ using ITensors: uniqueinds using ITensors.ITensorMPS: MPO, MPS, apply, dim, linkind, maxlinkdim, orthogonalize using LinearAlgebra: normalize, svd, tr +using NDTensors: unwrap_array_type # Possible improvements: # - Allow a maxdim argument to be passed to `expand`. @@ -86,7 +88,9 @@ function expand( _, λⱼ, basisⱼ = svd(state[j], linds; righttags="bψ_$j,Link") rinds = uniqueinds(basisⱼ, λⱼ) # Make projectorⱼ - idⱼ = prod(r -> denseblocks(δ(scalartype(state), r', dag(r))), rinds) + idⱼ = prod(rinds) do r + return adapt(unwrap_array_type(basisⱼ), denseblocks(δ(scalartype(state), r', dag(r)))) + end projectorⱼ = idⱼ - prime(basisⱼ, rinds) * dag(basisⱼ) # Sum reference density matrices ρⱼ = sum(reference -> prime(reference[j], rinds) * dag(reference[j]), references)