Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] more general indexing for KnetArray #229

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions src/karray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -494,9 +494,6 @@ end
# function _getindex(l::LinearIndexing, A::AbstractArray, I::Union{Real, AbstractArray, Colon}...)
# in abstractarray.jl:487,multidimensional.jl:184.

if VERSION < v"0.5.0"
@typealias6 AbstractUnitRange UnitRange
end

function getindex{T}(A::KnetArray{T}, I::AbstractUnitRange)
if !(1 <= first(I) <= last(I) <= length(A)); throw(BoundsError(A,I)); end
Expand Down Expand Up @@ -552,6 +549,24 @@ function setindex!{T}(A::KnetArray{T}, v, I::Colon)
unsafe_copy!(A,1,v,1,length(A))
end

## General Indexing Fallback to linear indexing

function getindex(a::KnetArray, I...)
indx = to_indices(a, I)
crange = CartesianRange(indx)
linind = [sub2ind(size(a), t.I...) for t in crange]
b = getindex(a, vec(linind))
reshape(b, length.(Base.index_shape(indx...)))
end

getindex(a::KnetArray, ::Colon, ::Colon, ::Colon) = a # fix ambiguity with method in rnn.jl

function setindex!(a::KnetArray, v, I...)
crange = CartesianRange(to_indices(a, I))
linind = [sub2ind(size(a), t.I...) for t in crange]
setindex!(a, v, vec(linind))
end

for F in (32,64); T=Symbol("Float$F"); @eval begin

## Indexing with KnetArray{Int32}: low level, only Int32 supported, no bounds checking
Expand Down
186 changes: 112 additions & 74 deletions test/karray.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
include("header.jl")

# http://docs.julialang.org/en/latest/manual/arrays.html#man-supported-index-types-1
if VERSION < v"0.5.0"
Base.IteratorsMD.CartesianIndex(i::Int...)=CartesianIndex(i)
end

# Test KnetArray operations: cat, convert, copy, display, eachindex,
# eltype, endof, fill!, first, getindex, hcat, isempty, length,
Expand All @@ -12,15 +9,17 @@ end

if gpu() >= 0
@testset "karray" begin
a = rand(3,4)
k = KnetArray(a)

a2 = rand(3,4)
k2 = KnetArray(a2)
a3 = rand(3,4,5)
k3 = KnetArray(a3)

# getindex, setindex!
# Index types: Integer, CartesianIndex, Vector{Int}, Array{Int}, EmptyArray, a:c, a:b:c, Colon, Bool
# See http://docs.julialang.org/en/latest/manual/arrays.html#man-supported-index-types-1
# check out http://docs.julialang.org/en/latest/manual/arrays.html#Cartesian-indices-1
@testset "indexing" begin
@test a == k # Supported index types:
@testset "indexing 2d" begin
@test a2 == k2 # Supported index types:
for i in ((:,), (:,:), # Colon, Tuple{Colon}
(3,), (2,3), # Int, Tuple{Int}
(3:5,), (1:2,3:4), # UnitRange, Tuple{UnitRange}
Expand All @@ -33,96 +32,135 @@ if gpu() >= 0
([1,3],:), (:,[1,3]), # Vector{Int},Colon
([2,2],:), (:,[2,2]), # Repeated index
([],), # Empty Array
((a.>0.5),), # BitArray
((a2.>0.5),), # BitArray
([1 3; 2 4],), # Array{Int}
(CartesianIndex(3,),), (CartesianIndex(2,3),), # CartesianIndex
(if VERSION >= v"0.5.0"
[(:,a[1,:].>0.5),(a[:,1].>0.5,:), # BitArray2 # FAIL for julia4
([CartesianIndex(2,2), CartesianIndex(2,1)],)] # Array{CartesianIndex} # FAIL for julia4
else [] end)...
(CartesianIndex(3,),), # CartesianIndex
(CartesianIndex(2,3),),
(:,a2[1,:].>0.5), # BitArray2
(a2[:,1].>0.5,:),
([CartesianIndex(2,2), CartesianIndex(2,1)],) # Array{CartesianIndex}
)
# @show i
@test a[i...] == k[i...]
ai = a[i...]
a[i...] = 0
k[i...] = 0
@test a == k
a[i...] = ai
k[i...] = ai
@test a == k
@test gradcheck(getindex, a, i...)
@test gradcheck(getindex, k, i...)
@test a2[i...] == k2[i...]
ai = a2[i...]
a2[i...] = 0
k2[i...] = 0
@test a2 == k2
a2[i...] = ai
k2[i...] = ai
@test a2 == k2
@test gradcheck(getindex, a2, i...)
@test gradcheck(getindex, k2, i...)
end
# make sure end works
@test a[2:end] == k[2:end]
@test a[2:end,2:end] == k[2:end,2:end]
# k.>0.5 returns KnetArray{T}, no Knet BitArrays yet
@test a[a.>0.5] == k[k.>0.5]
end
@test a2[2:end] == k2[2:end]
@test a2[2:end,2:end] == k2[2:end,2:end]
# k2.>0.5 returns KnetArray{T}, no Knet BitArrays yet
@test a2[a2.>0.5] == k2[k2.>0.5]

# Unsupported indexing etc.:
# @test_broken a[1:2:3,1:3:4] == Array(k[1:2:3,1:3:4]) # MethodError: no method matching getindex(::Knet.KnetArray{Float64,2}, ::StepRange{Int64,Int64}, ::StepRange{Int64,Int64})
# @test_broken a[[3,1],[4,2]] == Array(k[[3,1],[4,2]]) # MethodError: no method matching getindex(::Knet.KnetArray{Float64,2}, ::Array{Int64,1}, ::Array{Int64,1})
# @test_broken cat((1,2),a,a) == Array(cat((1,2),k,k)) # cat only impl for i=1,2
# Unsupported indexing etc.:
# @test_broken a2[1:2:3,1:3:4] == Array(k2[1:2:3,1:3:4]) # MethodError: no method matching getindex(::Knet.KnetArray{Float64,2}, ::StepRange{Int64,Int64}, ::StepRange{Int64,Int64})
# @test_broken a2[[3,1],[4,2]] == Array(k2[[3,1],[4,2]]) # MethodError: no method matching getindex(::Knet.KnetArray{Float64,2}, ::Array{Int64,1}, ::Array{Int64,1})
# @test_broken cat((1,2),a2,a2) == Array(cat((1,2),k2,k2)) # cat only impl for i=1,2
end


@testset "indexing 3d" begin
@test a3 == k3 # Supported index types:
for i in (
(:,),
(:,:,:), # Colon, Tuple{Colon}
(3,), (2,3,2), # Int, Tuple{Int}
(3:5,),
(1:2,3:4,3), # UnitRange, Tuple{UnitRange}
(2,:,:), (1,:,2), # Int, Colon
(1:2,:,1), (1:3,:,1:2),(:,1:2,2), # UnitRange,Colon
(1:2,2,2), (2,2:2,1:2), # Int, UnitRange
# (1:2:3,), # StepRange
# (1:2:3,:), (:,1:2:3), # StepRange,Colon
# ([1,3],), ([2,2],1,1), # Vector{Int}
# ([1,3],:), (:,[1,3]), # Vector{Int},Colon
# ([2,2],:), (:,[2,2]), # Repeated index
# ([],), # Empty Array
((a3.>0.5),), # BitArray
# ([1 3; 2 4],), # Array{Int}
(CartesianIndex(3,),), # CartesianIndex
(CartesianIndex(2,3,4),),
# (:,a3[1,:].>0.5), # BitArray2
# (a3[:,1].>0.5,:),
# ([CartesianIndex(2,2), CartesianIndex(2,1)],) # Array{CartesianIndex}
)
# @show i
@test a3[i...] == k3[i...]
ai = a3[i...]
a3[i...] = 0
k3[i...] = 0
@test a3 == k3
a3[i...] = ai
k3[i...] = ai
@test a3 == k3
@test gradcheck(getindex, a3, i...)
@test gradcheck(getindex, k3, i...)
end
# make sure end works
@test a3[2:end] == k3[2:end]
@test a3[2:end,2:end,2:end] == k3[2:end,2:end,2:end]
# k2.>0.5 returns KnetArray{T}, no Knet BitArrays yet
@test a3[a3.>0.5] == k3[k3.>0.5]
end
# AbstractArray interface
@testset "abstractarray" begin

for f in (copy, endof, first, isempty, length, ndims, ones, vec, zeros,
a->(eachindex(a);0), a->(eltype(a);0), # a->(Base.linearindexing(a);0),
a->collect(Float64,size(a)), a->collect(Float64,strides(a)),
a->cat(1,a,a), a->cat(2,a,a), a->hcat(a,a), a->vcat(a,a),
a->reshape(a,2,6), a->reshape(a,(2,6)),
a->size(a,1), a->size(a,2),
a->stride(a,1), a->stride(a,2), )
a2->(eachindex(a2);0), a2->(eltype(a2);0), # a2->(Base.linearindexing(a2);0),
a2->collect(Float64,size(a2)), a2->collect(Float64,strides(a2)),
a2->cat(1,a2,a2), a2->cat(2,a2,a2), a2->hcat(a2,a2), a2->vcat(a2,a2),
a2->reshape(a2,2,6), a2->reshape(a2,(2,6)),
a2->size(a2,1), a2->size(a2,2),
a2->stride(a2,1), a2->stride(a2,2), )

# @show f
@test f(a) == f(k)
@test gradcheck(f, a)
@test gradcheck(f, k)
@test f(a2) == f(k2)
@test gradcheck(f, a2)
@test gradcheck(f, k2)
end

@test convert(Array{Float32},a) == convert(KnetArray{Float32},k)
@test fill!(similar(a),pi) == fill!(similar(k),pi)
@test fill!(similar(a,(2,6)),pi) == fill!(similar(k,(2,6)),pi)
@test fill!(similar(a,2,6),pi) == fill!(similar(k,2,6),pi)
@test isa(pointer(k), Ptr{Float64})
@test isa(pointer(k,3), Ptr{Float64})
@test convert(Array{Float32},a2) == convert(KnetArray{Float32},k2)
@test fill!(similar(a2),pi) == fill!(similar(k2),pi)
@test fill!(similar(a2,(2,6)),pi) == fill!(similar(k2,(2,6)),pi)
@test fill!(similar(a2,2,6),pi) == fill!(similar(k2,2,6),pi)
@test isa(pointer(k2), Ptr{Float64})
@test isa(pointer(k2,3), Ptr{Float64})
@test isempty(KnetArray(Float32,0))
@test rand!(copy(a)) != rand!(copy(k))
@test k == k
@test a == k
@test k == a
@test isapprox(k,k)
@test isapprox(a,k)
@test isapprox(k,a)
@test a == copy!(similar(a),k)
@test k == copy!(similar(k),a)
@test k == copy!(similar(k),k)
@test k == copy(k)
@test pointer(k) != pointer(copy(k))
@test k == deepcopy(k)
@test pointer(k) != pointer(deepcopy(k))
@test rand!(copy(a2)) != rand!(copy(k2))
@test k2 == k2
@test a2 == k2
@test k2 == a2
@test isapprox(k2,k2)
@test isapprox(a2,k2)
@test isapprox(k2,a2)
@test a2 == copy!(similar(a2),k2)
@test k2 == copy!(similar(k2),a2)
@test k2 == copy!(similar(k2),k2)
@test k2 == copy(k2)
@test pointer(k2) != pointer(copy(k2))
@test k2 == deepcopy(k2)
@test pointer(k2) != pointer(deepcopy(k2))
end

@testset "cpu2gpu" begin
# cpu/gpu xfer with grad support
if VERSION >= v"0.6.0"
@test gradcheck(x->Array(sin.(KnetArray(x))),a)
@test gradcheck(x->KnetArray(sin.(Array(x))),k)
else
@test gradcheck(x->Array(sin(KnetArray(x))),a)
@test gradcheck(x->KnetArray(sin(Array(x))),k)
end
@test gradcheck(x->Array(sin.(KnetArray(x))),a2)
@test gradcheck(x->KnetArray(sin.(Array(x))),k2)
end

@testset "reshape" begin
a = KnetArray(Float32, 2, 2, 2)
a2 = KnetArray(Float32, 2, 2, 2)

@test size(reshape(a, 4, :)) == size(reshape(a, (4, :))) == (4, 2)
@test size(reshape(a, :, 4)) == size(reshape(a, (:, 4))) == (2, 4)
@test size(reshape(a, :, 1, 4)) == (2, 1, 4)
end
@test size(reshape(a2, 4, :)) == size(reshape(a2, (4, :))) == (4, 2)
@test size(reshape(a2, :, 4)) == size(reshape(a2, (:, 4))) == (2, 4)
@test size(reshape(a2, :, 1, 4)) == (2, 1, 4)
end
end
end

Expand Down