denizyuret / AutoGrad.jl

Julia port of the Python autograd package.


Gradient of std(Array{Float32,N}) throws an error

rened opened this issue

For example, when the following lines are added to test/statistics.jl, 6 of these tests fail:

```julia
@test gradcheck(mean, randn(Float32,2,3))
@test gradcheck(mean, randn(Float32,2,3), kwargs=[:dims=>1])
@test gradcheck(mean, randn(Float32,2,3), kwargs=[:dims=>(1,2)])
@test gradcheck(meanabs, randn(Float32,2,3))
@test gradcheck(meanabs2, randn(Float32,2,3))
@test gradcheck(var, randn(Float32,2,3))
@test gradcheck(var, randn(Float32,2,3), kwargs=[:dims=>1])
@test gradcheck(var, randn(Float32,2,3), kwargs=[:dims=>(1,2)])
@test gradcheck(std, randn(Float32,2,3))
@test gradcheck(std, randn(Float32,2,3), kwargs=[:dims=>1])
@test gradcheck(std, randn(Float32,2,3), kwargs=[:dims=>(1,2)])
```

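It seems that the eltype of the input data is not taken into account when the gradient is allocated: a Float32 input ends up with a Float64 gradient array, which `sum_outgrads` cannot add to the Float32 one: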

```
  Expression: gradcheck(var, randn(Float32, 2, 3), kwargs=[:dims => (1, 2)])
  MethodError: no method matching sum_outgrads(::Array{Float32,2}, ::Array{Float64,2})
  Closest candidates are:
    sum_outgrads(!Matched::Nothing, ::Any) at /home/rene/.julia/dev/AutoGrad/src/core.jl:499
    sum_outgrads(::AbstractArray{T,N} where N, !Matched::AbstractArray{T,N} where N) where T at /home/rene/.julia/dev/AutoGrad/src/core.jl:486
    sum_outgrads(!Matched::Rec, ::Any) at /home/rene/.julia/dev/AutoGrad/src/core.jl:490
    ...
  Stacktrace:
   [1] backward_pass(::Rec{Array{Float32,2}}, ::Rec{Float32}, ::Array{AutoGrad.Node,1}) at /home/rene/.julia/dev/AutoGrad/src/core.jl:252
   [2] (::getfield(AutoGrad, Symbol("##gradfun#1#2")){getfield(Main, Symbol("#g#54")){getfield(Main, Symbol("##g#52#53")){typeof(var)}},Int64})(::Base.Iterators.Pairs{Symbol,Tuple{Int64,Int64},Tuple{Symbol},NamedTuple{(:dims,),Tuple{Tuple{Int64,Int64}}}}, ::Function, ::Array{Float32,2}) at /home/rene/.julia/dev/AutoGrad/src/core.jl:41
```
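The mismatch can be reproduced with plain Julia, outside AutoGrad. This is a minimal sketch using only Base: `add_grads` is a hypothetical function whose signature mirrors the `sum_outgrads(::AbstractArray{T,N}, ::AbstractArray{T,N}) where T` candidate in the error, and `ones(size(x))` stands in for a gradient allocated without looking at `eltype(x)`.

```julia
# Hypothetical stand-in for the sum_outgrads candidate in the error above:
# it only accepts two arrays that share the same element type T.
add_grads(a::AbstractArray{T}, b::AbstractArray{T}) where {T} = a .+ b

x = randn(Float32, 2, 3)

g_in   = ones(Float32, 2, 3)       # an incoming Float32 gradient
g_bad  = ones(size(x))             # Float64: the default eltype of ones/zeros
g_good = fill!(similar(x), 1)      # Float32: similar(x) reuses eltype(x)

add_grads(g_in, g_good)            # works, returns a Matrix{Float32}

try
    add_grads(g_in, g_bad)         # Float32 and Float64 arrays: no matching method
catch err
    println(typeof(err))           # prints MethodError
end
```

Allocating with the input's element type (`zeros(eltype(x), size(x))`, `similar(x)` or `zero(x)`) keeps the gradient in Float32, so both arguments share the same `T`.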

Fixed in #91

It fails if the dims argument is given.

Yes, with #91

```julia
@test gradcheck(std, randn(Float64,2,3))
@test gradcheck(std, randn(Float32,2,3))
```

works now, but std/var still do not work with dims, for either Float32 or Float64:

```julia
@test gradcheck(std, randn(Float64,2,3), kwargs=[:dims=>1])  # fails
@test gradcheck(std, randn(Float32,2,3), kwargs=[:dims=>1])  # fails
```
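For reference, the gradient that has to be propagated when `dims` is given can be sketched directly. This is a minimal sketch, not AutoGrad's implementation: `var_grad` and `std_grad` are hypothetical helpers, and they assume the default corrected estimator of `Statistics.var` (division by n - 1, where n is the number of reduced elements). The upstream gradient `dy` has the shape of `var(x; dims)` and is broadcast back over the reduced dimensions.

```julia
using Statistics

# Hypothetical helper (not AutoGrad's code): gradient of var(x; dims) w.r.t. x,
# given the upstream gradient dy shaped like var(x; dims). Assumes the default
# corrected estimator, i.e. division by n - 1 where n is the reduced length.
function var_grad(x::AbstractArray{T}, dy; dims=:) where {T<:AbstractFloat}
    n = dims isa Colon ? length(x) : prod(size(x, d) for d in dims)
    μ = mean(x; dims=dims)
    # d var / d x_i = 2 (x_i - μ) / (n - 1); broadcasting expands dy and μ.
    return T(2) .* dy .* (x .- μ) ./ T(n - 1)
end

# Hypothetical helper: std = sqrt(var), so by the chain rule the upstream
# gradient is divided by 2 * std before reusing var_grad.
function std_grad(x::AbstractArray{T}, dy; dims=:) where {T<:AbstractFloat}
    s = std(x; dims=dims)
    return var_grad(x, dy ./ (T(2) .* s); dims=dims)
end

# Usage: a Float32 input with dims=1 keeps the Float32 eltype in the gradient.
x32 = randn(Float32, 2, 3)
g32 = std_grad(x32, ones(Float32, 1, 3); dims=1)

# Sanity check in Float64: central finite differences of sum(var(x; dims=1)).
x = randn(2, 3)
f(z) = sum(var(z; dims=1))
h = 1e-6
fd = similar(x)
for i in eachindex(x)
    e = zeros(size(x)); e[i] = h
    fd[i] = (f(x .+ e) - f(x .- e)) / (2h)
end
gv = var_grad(x, ones(1, 3); dims=1)
```

Because every constant is converted with `T(...)`, the result stays in the input's element type. With `corrected=false` the divisor would be `n` instead of `n - 1`.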

Fixed this in the latest master. Please check, and if it works, add these tests to test/statistics.jl.

Thanks, this works! Will add tests.