Skip to content

Commit

Permalink
Improve compute unit concept
Browse files Browse the repository at this point in the history
* Rename to AbstractComputeUnit and get_compute_unit (suggested by
  @vchuravy)

* Add AbstractComputeAccelerator (suggested by @ChrisRackauckas)

* Bottom value instead of exception if compute device can't be resolved

* Rename select_computing_device and make dispatch more robust
  (suggestions by @tkf and @jpsamaroo)

* Defend against reference cycle in generic implementation of
  get_compute_unit (pointed out by @tkf)
  • Loading branch information
oschulz committed May 26, 2022
1 parent 490e047 commit 76c686d
Showing 1 changed file with 112 additions and 32 deletions.
144 changes: 112 additions & 32 deletions src/computedevs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,79 +2,159 @@


"""
abstract type AbstractComputingDevice
abstract type AbstractComputeUnit
Supertype for arbitrary computing devices (CPU, GPU, etc.).
`adapt(dev::AbstractComputingDevice, x)` adapts `x` for `dev`.
`adapt(dev::AbstractComputeUnit, x)` adapts `x` for `dev`.
`Sys.total_memory(dev)` and `Sys.free_memory(dev)` return the total and free
memory on the device.
"""
abstract type AbstractComputingDevice end
export AbstractComputingDevice
abstract type AbstractComputeUnit end
export AbstractComputeUnit


"""
struct ComputingDeviceIndependent <: AbstractComputingDevice
struct ComputingDeviceIndependent
`get_computing_device(x) === ComputingDeviceIndependent()` indicates
`get_compute_unit(x) === ComputingDeviceIndependent()` indicates
that `x` is not tied to a specific computing device. This typically
means that x is a statically allocated object.
"""
struct ComputingDeviceIndependent <: AbstractComputingDevice end
struct ComputingDeviceIndependent end
export ComputingDeviceIndependent


"""
struct CPUDevice <: AbstractComputingDevice
UnknownComputeUnitOf(x)
`get_compute_unit(x) === UnknownComputeUnitOf(x)` indicates
that the computing device for `x` cannot be determined.
"""
struct UnknownComputeUnitOf{T}
    # The object whose compute unit could not be resolved.
    x::T
end


"""
struct MixedComputeSystem <: AbstractComputeUnit
A (possibly heterogeneous) system of multiple compute units.
"""
struct MixedComputeSystem <: AbstractComputeUnit end
export MixedComputeSystem


"""
struct CPUDevice <: AbstractComputeUnit
`CPUDevice()` is the default CPU device.
"""
struct CPUDevice <: AbstractComputingDevice end
struct CPUDevice <: AbstractComputeUnit end
export CPUDevice

# Adapting storage for the CPU device is delegated to adapting it to `Array`.
adapt_storage(::CPUDevice, x) = adapt_storage(Array, x)

# Memory queries on the CPU device forward to the zero-argument `Sys` functions.
Sys.total_memory(::CPUDevice) = Sys.total_memory()
Sys.free_memory(::CPUDevice) = Sys.free_memory()


"""
abstract type AbstractGPUDevice <: AbstractComputingDevice
abstract type AbstractComputeAccelerator <: AbstractComputeUnit
Supertype for compute accelerators (GPUs and other accelerator devices).
"""
abstract type AbstractGPUDevice <: AbstractComputingDevice end
abstract type AbstractComputeAccelerator <: AbstractComputeUnit end
export AbstractComputeAccelerator


"""
abstract type AbstractGPUDevice <: AbstractComputeAccelerator
Supertype for GPU computing devices.
"""
abstract type AbstractGPUDevice <: AbstractComputeAccelerator end
export AbstractGPUDevice


merge_compute_units() = ComputingDeviceIndependent()

const _incompatible_devs = ArgumentError("Incompatible computing devices")
@inline function merge_compute_units(a, b, c, ds::Vararg{Any,N}) where N
    # Left fold: merge the first two units, then merge that result with the
    # remaining arguments until only the two-argument method is left.
    return merge_compute_units(merge_compute_units(a, b), c, ds...)
end

select_computing_device(a::ComputingDeviceIndependent, ::ComputingDeviceIndependent) = a
select_computing_device(a::ComputingDeviceIndependent, b::AbstractComputingDevice) = b
select_computing_device(a::AbstractComputingDevice, b::ComputingDeviceIndependent) = a
# An unresolved compute unit dominates every merge. The method taking two
# `UnknownComputeUnitOf` arguments keeps the first one and also disambiguates
# the two single-sided methods below.
@inline merge_compute_units(a::UnknownComputeUnitOf, ::UnknownComputeUnitOf) = a
@inline merge_compute_units(a::UnknownComputeUnitOf, ::Any) = a
@inline merge_compute_units(::Any, b::UnknownComputeUnitOf) = b

@inline function merge_compute_units(a, b)
    # Identical units need no merge rule at all.
    a === b && return a

    # Query the merge rule in both argument orders and reconcile the answers.
    rule_ab = compute_unit_mergerule(a, b)
    rule_ba = compute_unit_mergerule(b, a)
    return compute_unit_mergeresult(rule_ab, rule_ba)
end

select_computing_device(a::CPUDevice, ::CPUDevice) = a
select_computing_device(a::CPUDevice, b::AbstractGPUDevice) = a
select_computing_device(a::AbstractGPUDevice, b::CPUDevice) = b
select_computing_device(a::AbstractGPUDevice, b::AbstractGPUDevice) = (a === b) ? a : throw(_incompatible_devs)
# Marker returned by `compute_unit_mergerule` when no rule exists for a pair.
struct NoCUnitMergeRule end

# Fallback: no rule is known for this ordered pair of compute units.
@inline compute_unit_mergerule(::Any, ::Any) = NoCUnitMergeRule()

# An unresolved compute unit yields itself, whatever it is paired with.
@inline compute_unit_mergerule(a::UnknownComputeUnitOf, ::Any) = a
@inline compute_unit_mergerule(a::UnknownComputeUnitOf, ::UnknownComputeUnitOf) = a

# A device-independent object defers to the other operand.
@inline compute_unit_mergerule(::ComputingDeviceIndependent, b::Any) = b

# Reconcile the rule results obtained for the argument orders (a, b) and (b, a).

# Neither order has a rule: the units are treated as a mixed system.
@inline function compute_unit_mergeresult(::NoCUnitMergeRule, ::NoCUnitMergeRule)
    return MixedComputeSystem()
end

# Exactly one order has a rule: use its result.
@inline compute_unit_mergeresult(fwd, ::NoCUnitMergeRule) = fwd
@inline compute_unit_mergeresult(::NoCUnitMergeRule, rev) = rev

# Both orders have a rule: they must agree, otherwise the result is mixed.
@inline function compute_unit_mergeresult(fwd, rev)
    if fwd === rev
        return fwd
    else
        return MixedComputeSystem()
    end
end


"""
get_computing_device(x)::AbstractComputingDevice
get_compute_unit(x)::Union{
AbstractComputeUnit,
ComputingDeviceIndependent,
UnknownComputeUnitOf
}
Get the computing device backing object `x`.
Don't specialize `get_compute_unit`, specialize
[`Adapt.get_compute_unit_impl`](@ref) instead.
"""
function get_computing_device end
export get_computing_device
function get_compute_unit(x)
    # Start the lookup with an empty type history (`Union{}`).
    return get_compute_unit_impl(Union{}, x)
end
export get_compute_unit


@inline get_computing_device(::Array) = CPUDevice()
"""
get_compute_unit_impl(::Type{TypeHistory}, x)::Union{
AbstractComputeUnit,
ComputingDeviceIndependent,
UnknownComputeUnitOf
}
# ToDo: Utilize `ArrayInterfaceCore.buffer(A)`? Would require Adapt to depend
# on ArrayInterfaceCore.
See [`get_compute_unit`](@ref).
@generated function get_computing_device(x)
impl = :(begin dev_0 = ComputingDeviceIndependent() end)
append!(impl.args, [:($(Symbol(:dev_, i)) = select_computing_device(get_computing_device(getfield(x, $i)), $(Symbol(:dev_, i-1)))) for i in 1:fieldcount(x)])
push!(impl.args, :(return $(Symbol(:dev_, fieldcount(x)))))
impl
end
Specializations that directly resolve the compute unit based on `x` can
ignore `TypeHistory`:
```julia
Adapt.get_compute_unit_impl(@nospecialize(TypeHistory::Type), x::SomeType) = ...
```
"""
function get_compute_unit_impl end


adapt_storage(::CPUDevice, x) = adapt_storage(Array, x)
# An `Array` resolves directly to the default CPU device; the type history
# is not consulted.
@inline function get_compute_unit_impl(@nospecialize(TypeHistory::Type), ::Array)
    return CPUDevice()
end

# Guard against object reference loops: when the type of `x` is already part
# of `TypeHistory`, stop and report the compute unit of `x` as unknown.
@inline function get_compute_unit_impl(
    ::Type{TypeHistory}, x::T
) where {TypeHistory,T<:TypeHistory}
    return UnknownComputeUnitOf(x)
end

# Generic fallback: resolve the compute unit of `x` by merging the compute
# units of all of its fields.
#
# * Values of an isbits type resolve to `ComputingDeviceIndependent()`.
# * Otherwise each field is resolved via `get_compute_unit_impl`, with the type
#   of `x` added to the type history, and the results are combined with
#   `merge_compute_units`. A type without fields therefore also resolves to
#   `ComputingDeviceIndependent()`.
# * Fields that are not initialized (possible with incompletely initialized
#   structs) are skipped, since `getfield` would throw an `UndefRefError` for
#   them.
@generated function get_compute_unit_impl(::Type{TypeHistory}, x) where TypeHistory
    # Inside the generator `x` is bound to the *type* of the argument.
    if isbitstype(x)
        return :(ComputingDeviceIndependent())
    end

    # Record the current type so a nested value of the same type can be
    # recognized as a reference loop.
    NewTypeHistory = Union{TypeHistory, x}

    stmts = Any[:(dev_0 = ComputingDeviceIndependent())]
    for i in 1:fieldcount(x)
        prev = Symbol(:dev_, i - 1)
        cur = Symbol(:dev_, i)
        push!(stmts, quote
            $cur = if isdefined(x, $i)
                merge_compute_units(
                    get_compute_unit_impl($NewTypeHistory, getfield(x, $i)), $prev
                )
            else
                # Uninitialized field: nothing to resolve, keep the result so far.
                $prev
            end
        end)
    end
    push!(stmts, :(return $(Symbol(:dev_, fieldcount(x)))))
    return Expr(:block, stmts...)
end

0 comments on commit 76c686d

Please sign in to comment.