Skip to content

Commit

Permalink
Safer, extensible ﹫inbounds
Browse files Browse the repository at this point in the history
  • Loading branch information
simonster committed Sep 5, 2014
1 parent 49d4132 commit 7cb11d5
Show file tree
Hide file tree
Showing 9 changed files with 122 additions and 79 deletions.
4 changes: 4 additions & 0 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,8 @@ imag{T<:Real}(x::AbstractArray{T}) = zero(x)

getindex(t::AbstractArray, i::Real) = error("indexing not defined for ", typeof(t))

unsafe_getindex(args...) = getindex(args...)

# linear indexing with a single multi-dimensional index
function getindex(A::AbstractArray, I::AbstractArray)
x = similar(A, size(I))
Expand Down Expand Up @@ -441,6 +443,8 @@ setindex!(t::AbstractArray, x, i::Real) =
error("setindex! not defined for ",typeof(t))
setindex!(t::AbstractArray, x) = throw(MethodError(setindex!, (t, x)))

unsafe_setindex!(args...) = setindex!(args...)

## Indexing: handle more indices than dimensions if "extra" indices are 1

# Don't require vector/matrix subclasses to implement more than 1/2 indices,
Expand Down
69 changes: 39 additions & 30 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -241,21 +241,25 @@ collect(itr) = collect(eltype(itr), itr)

## Indexing: getindex ##

getindex(a::Array) = arrayref(a,1)

getindex(A::Array, i0::Real) = arrayref(A,to_index(i0))
getindex(A::Array, i0::Real, i1::Real) = arrayref(A,to_index(i0),to_index(i1))
getindex(A::Array, i0::Real, i1::Real, i2::Real) =
arrayref(A,to_index(i0),to_index(i1),to_index(i2))
getindex(A::Array, i0::Real, i1::Real, i2::Real, i3::Real) =
arrayref(A,to_index(i0),to_index(i1),to_index(i2),to_index(i3))
getindex(A::Array, i0::Real, i1::Real, i2::Real, i3::Real, i4::Real) =
arrayref(A,to_index(i0),to_index(i1),to_index(i2),to_index(i3),to_index(i4))
getindex(A::Array, i0::Real, i1::Real, i2::Real, i3::Real, i4::Real, i5::Real) =
arrayref(A,to_index(i0),to_index(i1),to_index(i2),to_index(i3),to_index(i4),to_index(i5))

getindex(A::Array, i0::Real, i1::Real, i2::Real, i3::Real, i4::Real, i5::Real, I::Real...) =
arrayref(A,to_index(i0),to_index(i1),to_index(i2),to_index(i3),to_index(i4),to_index(i5),to_index(I)...)
for (getindexfn, transform) in ((:getindex, x->x), (:unsafe_getindex, x->:(@boundscheck false return $x)))
@eval begin
$getindexfn(a::Array) = $(transform(:(arrayref(a,1))))

$getindexfn(A::Array, i0::Real) = $(transform(:(arrayref(A,to_index(i0)))))
$getindexfn(A::Array, i0::Real, i1::Real) = $(transform(:(arrayref(A,to_index(i0),to_index(i1)))))
$getindexfn(A::Array, i0::Real, i1::Real, i2::Real) =
$(transform(:(arrayref(A,to_index(i0),to_index(i1),to_index(i2)))))
$getindexfn(A::Array, i0::Real, i1::Real, i2::Real, i3::Real) =
$(transform(:(arrayref(A,to_index(i0),to_index(i1),to_index(i2),to_index(i3)))))
$getindexfn(A::Array, i0::Real, i1::Real, i2::Real, i3::Real, i4::Real) =
$(transform(:(arrayref(A,to_index(i0),to_index(i1),to_index(i2),to_index(i3),to_index(i4)))))
$getindexfn(A::Array, i0::Real, i1::Real, i2::Real, i3::Real, i4::Real, i5::Real) =
$(transform(:(arrayref(A,to_index(i0),to_index(i1),to_index(i2),to_index(i3),to_index(i4),to_index(i5)))))

$getindexfn(A::Array, i0::Real, i1::Real, i2::Real, i3::Real, i4::Real, i5::Real, I::Real...) =
$(transform(:(arrayref(A,to_index(i0),to_index(i1),to_index(i2),to_index(i3),to_index(i4),to_index(i5),to_index(I)...))))
end
end

# Fast copy using copy! for UnitRange
function getindex(A::Array, I::UnitRange{Int})
Expand Down Expand Up @@ -302,21 +306,26 @@ getindex(A::Array, I::AbstractArray{Bool}) = getindex_bool_1d(A, I)


## Indexing: setindex! ##
setindex!{T}(A::Array{T}, x) = arrayset(A, convert(T,x), 1)

setindex!{T}(A::Array{T}, x, i0::Real) = arrayset(A, convert(T,x), to_index(i0))
setindex!{T}(A::Array{T}, x, i0::Real, i1::Real) =
arrayset(A, convert(T,x), to_index(i0), to_index(i1))
setindex!{T}(A::Array{T}, x, i0::Real, i1::Real, i2::Real) =
arrayset(A, convert(T,x), to_index(i0), to_index(i1), to_index(i2))
setindex!{T}(A::Array{T}, x, i0::Real, i1::Real, i2::Real, i3::Real) =
arrayset(A, convert(T,x), to_index(i0), to_index(i1), to_index(i2), to_index(i3))
setindex!{T}(A::Array{T}, x, i0::Real, i1::Real, i2::Real, i3::Real, i4::Real) =
arrayset(A, convert(T,x), to_index(i0), to_index(i1), to_index(i2), to_index(i3), to_index(i4))
setindex!{T}(A::Array{T}, x, i0::Real, i1::Real, i2::Real, i3::Real, i4::Real, i5::Real) =
arrayset(A, convert(T,x), to_index(i0), to_index(i1), to_index(i2), to_index(i3), to_index(i4), to_index(i5))
setindex!{T}(A::Array{T}, x, i0::Real, i1::Real, i2::Real, i3::Real, i4::Real, i5::Real, I::Real...) =
arrayset(A, convert(T,x), to_index(i0), to_index(i1), to_index(i2), to_index(i3), to_index(i4), to_index(i5), to_index(I)...)

for (setindexfn, transform) in ((:setindex!, x->x), (:unsafe_setindex!, x->:(@boundscheck false return $x)))
@eval begin
$setindexfn{T}(A::Array{T}, x) = $(transform(:(arrayset(A, convert(T,x), 1))))

$setindexfn{T}(A::Array{T}, x, i0::Real) = $(transform(:(arrayset(A, convert(T,x), to_index(i0)))))
$setindexfn{T}(A::Array{T}, x, i0::Real, i1::Real) =
$(transform(:(arrayset(A, convert(T,x), to_index(i0), to_index(i1)))))
$setindexfn{T}(A::Array{T}, x, i0::Real, i1::Real, i2::Real) =
$(transform(:(arrayset(A, convert(T,x), to_index(i0), to_index(i1), to_index(i2)))))
$setindexfn{T}(A::Array{T}, x, i0::Real, i1::Real, i2::Real, i3::Real) =
$(transform(:(arrayset(A, convert(T,x), to_index(i0), to_index(i1), to_index(i2), to_index(i3)))))
$setindexfn{T}(A::Array{T}, x, i0::Real, i1::Real, i2::Real, i3::Real, i4::Real) =
$(transform(:(arrayset(A, convert(T,x), to_index(i0), to_index(i1), to_index(i2), to_index(i3), to_index(i4)))))
$setindexfn{T}(A::Array{T}, x, i0::Real, i1::Real, i2::Real, i3::Real, i4::Real, i5::Real) =
$(transform(:(arrayset(A, convert(T,x), to_index(i0), to_index(i1), to_index(i2), to_index(i3), to_index(i4), to_index(i5)))))
$setindexfn{T}(A::Array{T}, x, i0::Real, i1::Real, i2::Real, i3::Real, i4::Real, i5::Real, I::Real...) =
$(transform(:(arrayset(A, convert(T,x), to_index(i0), to_index(i1), to_index(i2), to_index(i3), to_index(i4), to_index(i5), to_index(I)...))))
end
end

function setindex!{T<:Real}(A::Array, x, I::AbstractVector{T})
for i in I
Expand Down
21 changes: 19 additions & 2 deletions base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,25 @@ macro boundscheck(yesno,blk)
$(Expr(:boundscheck,:pop)))
end

macro inbounds(blk)
:(@boundscheck false $(esc(blk)))
function rewrite_ref(getindexfn, setindexfn, ast::Expr)
if ast.head === :ref
ast = Expr(:custom_ref, getindexfn, setindexfn, ast.args...)
end

args = ast.args
for i = 1:arraylen(args)
arg = arrayref(args, i)
if isa(arg, Expr)
arrayset(args, rewrite_ref(getindexfn, setindexfn, arg), i)
end
end

return ast
end
rewrite_ref(getindexfn, setindexfn, x) = x

macro inbounds(ex)
esc(rewrite_ref(:unsafe_getindex, :unsafe_setindex!, ex))
end

macro label(name::Symbol)
Expand Down
4 changes: 3 additions & 1 deletion base/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,9 @@ end
## Indexing: getindex ##

function unsafe_bitgetindex(Bc::Vector{Uint64}, i::Int)
return (Bc[@_div64(i-1)+1] & (uint64(1)<<@_mod64(i-1))) != 0
return @inbounds (Bc[@_div64(i-1)+1] & (uint64(1)<<@_mod64(i-1))) != 0
end
unsafe_getindex(v::BitArray, ind::Int) = Base.unsafe_bitgetindex(v.chunks, ind)

function getindex(B::BitArray, i::Int)
1 <= i <= length(B) || throw(BoundsError())
Expand Down Expand Up @@ -408,6 +409,7 @@ function unsafe_bitsetindex!(Bc::Array{Uint64}, x::Bool, i::Int)
end
end
end
unsafe_setindex!(v::BitArray, x::Bool, ind::Int) = (Base.unsafe_bitsetindex!(v.chunks, x, ind); v)

setindex!(B::BitArray, x) = setindex!(B, convert(Bool,x), 1)

Expand Down
2 changes: 2 additions & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -776,6 +776,8 @@ export
union!,
union,
unique,
unsafe_getindex,
unsafe_setindex!,
values,
,
,
Expand Down
10 changes: 0 additions & 10 deletions base/multidimensional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,6 @@
nothing
end

unsafe_getindex(v::Real, ind::Int) = v
unsafe_getindex(v::Range, ind::Int) = first(v) + (ind-1)*step(v)
unsafe_getindex(v::BitArray, ind::Int) = Base.unsafe_bitgetindex(v.chunks, ind)
unsafe_getindex(v::AbstractArray, ind::Int) = v[ind]
unsafe_getindex(v, ind::Real) = unsafe_getindex(v, to_index(ind))

unsafe_setindex!{T}(v::AbstractArray{T}, x::T, ind::Int) = (v[ind] = x; v)
unsafe_setindex!(v::BitArray, x::Bool, ind::Int) = (Base.unsafe_bitsetindex!(v.chunks, x, ind); v)
unsafe_setindex!{T}(v::AbstractArray{T}, x::T, ind::Real) = unsafe_setindex!(v, x, to_index(ind))

# Version that uses cartesian indexing for src
@ngenerate N typeof(dest) function _getindex!(dest::Array, src::AbstractArray, I::NTuple{N,Union(Int,AbstractVector)}...)
checksize(dest, I...)
Expand Down
2 changes: 2 additions & 0 deletions base/number.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ getindex(x::Number) = x
getindex(x::Number, i::Integer) = i == 1 ? x : throw(BoundsError())
getindex(x::Number, I::Integer...) = all([i == 1 for i in I]) ? x : throw(BoundsError())
getindex(x::Number, I::Real...) = getindex(x, to_index(i)...)
unsafe_getindex(x::Number, i::Real) = x
unsafe_getindex(x::Number, i::Real...) = x
first(x::Number) = x
last(x::Number) = x

Expand Down
16 changes: 8 additions & 8 deletions base/range.jl
Original file line number Diff line number Diff line change
Expand Up @@ -252,15 +252,15 @@ done(r::UnitRange, i) = i==oftype(i,r.stop)+1

## indexing

getindex(r::Range, i::Real) = getindex(r, to_index(i))
unsafe_getindex(r::Range, i::Real) = getindex(r, to_index(i))
unsafe_getindex{T}(r::Range{T}, i::Integer) =
oftype(T, first(r) + (i-1)*step(r))
unsafe_getindex{T}(r::FloatRange{T}, i::Integer) =
oftype(T, (r.start + (i-1)*r.step)/r.divisor)

function getindex{T}(r::Range{T}, i::Integer)
1 <= i <= length(r) || error(BoundsError)
oftype(T, first(r) + (i-1)*step(r))
end
function getindex{T}(r::FloatRange{T}, i::Integer)
1 <= i <= length(r) || error(BoundsError)
oftype(T, (r.start + (i-1)*r.step)/r.divisor)
unsafe_getindex(r, i)
end

function getindex(r::UnitRange, s::UnitRange{Int})
Expand Down Expand Up @@ -509,7 +509,7 @@ function vcat{T}(r::Range{T})
a = Array(T,n)
i = 1
for x in r
@inbounds a[i] = x
@boundscheck false a[i] = x
i += 1
end
return a
Expand All @@ -523,7 +523,7 @@ function vcat{T}(rs::Range{T}...)
i = 1
for r in rs
for x in r
@inbounds a[i] = x
@boundscheck false a[i] = x
i += 1
end
end
Expand Down
73 changes: 45 additions & 28 deletions src/julia-syntax.scm
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,13 @@
`(= ,(car e) (call ,op (:: ,(car e) ,(car declT)) ,rhs))))))

(define (expand-update-operator op lhs rhs . declT)
(cond ((and (pair? lhs) (eq? (car lhs) 'ref))
(cond ((and (pair? lhs) (or (eq? (car lhs) 'ref) (eq? (car lhs) 'custom_ref)))
;; expand indexing inside op= first, to remove "end" and ":"
(let* ((ex (partially-expand-ref lhs))
(stmts (butlast (cdr ex)))
(refex (last (cdr ex)))
(nuref `(ref ,(caddr refex) ,@(cdddr refex))))
(nuref `(,@(if (eq? (car lhs) 'custom_ref) `(custom_ref ,(cadr lhs) ,(caddr lhs)) `(ref))
,(caddr refex) ,@(cdddr refex))))
`(block ,@stmts
,(expand-update-operator- op nuref rhs declT))))
((and (pair? lhs) (eq? (car lhs) '|::|))
Expand All @@ -202,8 +203,8 @@
(define (dotop? o) (and (symbol? o) (eqv? (string.char (string o) 0) #\.)))

(define (partially-expand-ref e)
(let ((a (cadr e))
(idxs (cddr e)))
(let ((a (if (eq? (car e) 'custom_ref) (cadddr e) (cadr e)))
(idxs (if (eq? (car e) 'custom_ref) (cddddr e) (cddr e))))
(let* ((reuse (and (pair? a)
(contains (lambda (x)
(or (eq? x 'end)
Expand All @@ -217,7 +218,7 @@
(new-idxs stuff) (process-indexes arr idxs)
`(block
,@(append stmts stuff)
(call getindex ,arr ,@new-idxs))))))
(call ,(if (eq? (car e) 'custom_ref) (cadr e) 'getindex) ,arr ,@new-idxs))))))

;; accumulate a series of comparisons, with the given "and" constructor,
;; exit criteria, and "take" function that consumes part of a list,
Expand Down Expand Up @@ -312,6 +313,11 @@
;; inside ref only replace within the first argument
(list* 'ref (replace-end (cadr ex) a n tuples last)
(cddr ex)))
((eq? (car ex) 'custom_ref)
;; inside custom_ref only replace within the third argument
(list* 'custom_ref (cadr ex) (caddr ex)
(replace-end (cadddr ex) a n tuples last)
(cddddr ex)))
(else
(cons (car ex)
(map (lambda (x) (replace-end x a n tuples last))
Expand Down Expand Up @@ -1534,6 +1540,28 @@
e
((get expand-table (car e) map-expand-forms) e)))

(define (expand-setindex a idxs rhs setindexfn)
(let* ((reuse (and (pair? a)
(contains (lambda (x)
(or (eq? x 'end)
(and (pair? x)
(eq? (car x) ':))))
idxs)))
(arr (if reuse (gensy) a))
(stmts (if reuse `((= ,arr ,(expand-forms a))) '())))
(let* ((rrhs (and (pair? rhs) (not (quoted? rhs))))
(r (if rrhs (gensy) rhs))
(rini (if rrhs `((= ,r ,(expand-forms rhs))) '())))
(receive
(new-idxs stuff) (process-indexes arr idxs)
`(block
,@stmts
,.(map expand-forms stuff)
,@rini
,(expand-forms
`(call ,setindexfn ,arr ,r ,@new-idxs))
,r)))))

(define expand-table
(table
'quote identity
Expand Down Expand Up @@ -1611,29 +1639,14 @@

((ref)
;; (= (ref a . idxs) rhs)
(let ((a (cadr (cadr e)))
(idxs (cddr (cadr e)))
(rhs (caddr e)))
(let* ((reuse (and (pair? a)
(contains (lambda (x)
(or (eq? x 'end)
(and (pair? x)
(eq? (car x) ':))))
idxs)))
(arr (if reuse (gensy) a))
(stmts (if reuse `((= ,arr ,(expand-forms a))) '())))
(let* ((rrhs (and (pair? rhs) (not (quoted? rhs))))
(r (if rrhs (gensy) rhs))
(rini (if rrhs `((= ,r ,(expand-forms rhs))) '())))
(receive
(new-idxs stuff) (process-indexes arr idxs)
`(block
,@stmts
,.(map expand-forms stuff)
,@rini
,(expand-forms
`(call setindex! ,arr ,r ,@new-idxs))
,r))))))
(expand-setindex (cadr (cadr e)) (cddr (cadr e))
(caddr e) 'setindex!)
)

((custom_ref)
(expand-setindex (cadddr (cadr e)) (cddddr (cadr e))
(caddr e) (caddr (cadr e)))
)

((|::|)
;; (= (|::| x T) rhs)
Expand Down Expand Up @@ -1676,6 +1689,10 @@
(lambda (e)
(expand-forms (partially-expand-ref e)))

'custom_ref
(lambda (e)
(expand-forms (partially-expand-ref e)))

'curly
(lambda (e)
(expand-forms
Expand Down

0 comments on commit 7cb11d5

Please sign in to comment.