-
Notifications
You must be signed in to change notification settings - Fork 152
Closed
Labels
Description
using StaticArrays: @SArray
@SArray Float32[1, 2]
gives
ERROR: UndefVarError: SArray not defined
indeed:
julia> macroexpand(Main, :(@SArray [1, 2]))
:((StaticArrays.SArray{Tuple{2},T,N,L} where L where N where T)((1, 2)))
julia> macroexpand(Main, :(@SArray Float32[1, 2]))
:(SArray{Tuple{2}, Float32}((1, 2)))
The definition
Lines 82 to 235 in a3bca35
macro SArray(ex) | |
if !isa(ex, Expr) | |
error("Bad input for @SArray") | |
end | |
if ex.head == :vect # vector | |
return esc(Expr(:call, SArray{Tuple{length(ex.args)}}, Expr(:tuple, ex.args...))) | |
elseif ex.head == :ref # typed, vector | |
return esc(Expr(:call, Expr(:curly, :SArray, Tuple{length(ex.args)-1}, ex.args[1]), Expr(:tuple, ex.args[2:end]...))) | |
elseif ex.head == :hcat # 1 x n | |
s1 = 1 | |
s2 = length(ex.args) | |
return esc(Expr(:call, SArray{Tuple{s1, s2}}, Expr(:tuple, ex.args...))) | |
elseif ex.head == :typed_hcat # typed, 1 x n | |
s1 = 1 | |
s2 = length(ex.args) - 1 | |
return esc(Expr(:call, Expr(:curly, :SArray, Tuple{s1, s2}, ex.args[1]), Expr(:tuple, ex.args[2:end]...))) | |
elseif ex.head == :vcat | |
if isa(ex.args[1], Expr) && ex.args[1].head == :row # n x m | |
# Validate | |
s1 = length(ex.args) | |
s2s = map(i -> ((isa(ex.args[i], Expr) && ex.args[i].head == :row) ? length(ex.args[i].args) : 1), 1:s1) | |
s2 = minimum(s2s) | |
if maximum(s2s) != s2 | |
error("Rows must be of matching lengths") | |
end | |
exprs = [ex.args[i].args[j] for i = 1:s1, j = 1:s2] | |
return esc(Expr(:call, SArray{Tuple{s1, s2}}, Expr(:tuple, exprs...))) | |
else # n x 1 | |
return esc(Expr(:call, SArray{Tuple{length(ex.args), 1}}, Expr(:tuple, ex.args...))) | |
end | |
elseif ex.head == :typed_vcat | |
if isa(ex.args[2], Expr) && ex.args[2].head == :row # typed, n x m | |
# Validate | |
s1 = length(ex.args) - 1 | |
s2s = map(i -> ((isa(ex.args[i+1], Expr) && ex.args[i+1].head == :row) ? length(ex.args[i+1].args) : 1), 1:s1) | |
s2 = minimum(s2s) | |
if maximum(s2s) != s2 | |
error("Rows must be of matching lengths") | |
end | |
exprs = [ex.args[i+1].args[j] for i = 1:s1, j = 1:s2] | |
return esc(Expr(:call, Expr(:curly, :SArray, Tuple{s1, s2}, ex.args[1]), Expr(:tuple, exprs...))) | |
else # typed, n x 1 | |
return esc(Expr(:call, Expr(:curly, :SArray, Tuple{length(ex.args)-1, 1}, ex.args[1]), Expr(:tuple, ex.args[2:end]...))) | |
end | |
elseif isa(ex, Expr) && ex.head == :comprehension | |
if length(ex.args) != 1 || !isa(ex.args[1], Expr) || ex.args[1].head != :generator | |
error("Expected generator in comprehension, e.g. [f(i,j) for i = 1:3, j = 1:3]") | |
end | |
ex = ex.args[1] | |
n_rng = length(ex.args) - 1 | |
rng_args = [ex.args[i+1].args[1] for i = 1:n_rng] | |
rngs = Any[Core.eval(__module__, ex.args[i+1].args[2]) for i = 1:n_rng] | |
rng_lengths = map(length, rngs) | |
f = gensym() | |
f_expr = :($f = ($(Expr(:tuple, rng_args...)) -> $(ex.args[1]))) | |
# TODO figure out a generic way of doing this... | |
if n_rng == 1 | |
exprs = [:($f($j1)) for j1 in rngs[1]] | |
elseif n_rng == 2 | |
exprs = [:($f($j1, $j2)) for j1 in rngs[1], j2 in rngs[2]] | |
elseif n_rng == 3 | |
exprs = [:($f($j1, $j2, $j3)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3]] | |
elseif n_rng == 4 | |
exprs = [:($f($j1, $j2, $j3, $j4)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4]] | |
elseif n_rng == 5 | |
exprs = [:($f($j1, $j2, $j3, $j4, $j5)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4], j5 in rngs[5]] | |
elseif n_rng == 6 | |
exprs = [:($f($j1, $j2, $j3, $j4, $j5, $j6)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4], j5 in rngs[5], j6 in rngs[6]] | |
elseif n_rng == 7 | |
exprs = [:($f($j1, $j2, $j3, $j4, $j5, $j6, $j7)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4], j5 in rngs[5], j6 in rngs[6], j7 in rngs[7]] | |
elseif n_rng == 8 | |
exprs = [:($f($j1, $j2, $j3, $j4, $j5, $j6, $j7, $j8)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4], j5 in rngs[5], j6 in rngs[6], j7 in rngs[7], j8 in rngs[8]] | |
else | |
error("@SArray only supports up to 8-dimensional comprehensions") | |
end | |
return quote | |
$(esc(f_expr)) | |
$(esc(Expr(:call, Expr(:curly, :SArray, Tuple{rng_lengths...}), Expr(:tuple, exprs...)))) | |
end | |
elseif isa(ex, Expr) && ex.head == :typed_comprehension | |
if length(ex.args) != 2 || !isa(ex.args[2], Expr) || ex.args[2].head != :generator | |
error("Expected generator in typed comprehension, e.g. Float64[f(i,j) for i = 1:3, j = 1:3]") | |
end | |
T = ex.args[1] | |
ex = ex.args[2] | |
n_rng = length(ex.args) - 1 | |
rng_args = [ex.args[i+1].args[1] for i = 1:n_rng] | |
rngs = [Core.eval(__module__, ex.args[i+1].args[2]) for i = 1:n_rng] | |
rng_lengths = map(length, rngs) | |
f = gensym() | |
f_expr = :($f = ($(Expr(:tuple, rng_args...)) -> $(ex.args[1]))) | |
# TODO figure out a generic way of doing this... | |
if n_rng == 1 | |
exprs = [:($f($j1)) for j1 in rngs[1]] | |
elseif n_rng == 2 | |
exprs = [:($f($j1, $j2)) for j1 in rngs[1], j2 in rngs[2]] | |
elseif n_rng == 3 | |
exprs = [:($f($j1, $j2, $j3)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3]] | |
elseif n_rng == 4 | |
exprs = [:($f($j1, $j2, $j3, $j4)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4]] | |
elseif n_rng == 5 | |
exprs = [:($f($j1, $j2, $j3, $j4, $j5)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4], j5 in rngs[5]] | |
elseif n_rng == 6 | |
exprs = [:($f($j1, $j2, $j3, $j4, $j5, $j6)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4], j5 in rngs[5], j6 in rngs[6]] | |
elseif n_rng == 7 | |
exprs = [:($f($j1, $j2, $j3, $j4, $j5, $j6, $j7)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4], j5 in rngs[5], j6 in rngs[6], j7 in rngs[7]] | |
elseif n_rng == 8 | |
exprs = [:($f($j1, $j2, $j3, $j4, $j5, $j6, $j7, $j8)) for j1 in rngs[1], j2 in rngs[2], j3 in rngs[3], j4 in rngs[4], j5 in rngs[5], j6 in rngs[6], j7 in rngs[7], j8 in rngs[8]] | |
else | |
error("@SArray only supports up to 8-dimensional comprehensions") | |
end | |
return quote | |
$(esc(f_expr)) | |
$(esc(Expr(:call, Expr(:curly, :SArray, Tuple{rng_lengths...}, T), Expr(:tuple, exprs...)))) | |
end | |
elseif isa(ex, Expr) && ex.head == :call | |
if ex.args[1] == :zeros || ex.args[1] == :ones || ex.args[1] == :rand || ex.args[1] == :randn || ex.args[1] == :randexp | |
if length(ex.args) == 1 | |
error("@SArray got bad expression: $(ex.args[1])()") | |
else | |
return quote | |
if isa($(esc(ex.args[2])), DataType) | |
$(ex.args[1])($(esc(Expr(:curly, SArray, Expr(:curly, Tuple, ex.args[3:end]...), ex.args[2])))) | |
else | |
$(ex.args[1])($(esc(Expr(:curly, SArray, Expr(:curly, Tuple, ex.args[2:end]...))))) | |
end | |
end | |
end | |
elseif ex.args[1] == :fill | |
if length(ex.args) == 1 | |
error("@SArray got bad expression: $(ex.args[1])()") | |
elseif length(ex.args) == 2 | |
error("@SArray got bad expression: $(ex.args[1])($(ex.args[2]))") | |
else | |
return quote | |
$(esc(ex.args[1]))($(esc(ex.args[2])), SArray{$(esc(Expr(:curly, Tuple, ex.args[3:end]...)))}) | |
end | |
end | |
else | |
error("@SArray only supports the zeros(), ones(), rand(), randn(), and randexp() functions.") | |
end | |
else | |
error("Bad input for @SArray") | |
end | |
end |