Related: https://github.com/JuliaLang/julia/issues/9551
Unfortunately, as you've seen, type-variadic keyword arguments can really
mess up type inference. It appears that keyword argument types are taken
from the default values rather than from the arguments actually passed in
at runtime:
julia> f(x; a=1, b=2) = a*x^b
f (generic function with 1 method)
julia> f(1)
1
julia> f(1, a=(3+im), b=5.15)
3.0 + 1.0im
julia> @code_typed f(1, a=(3+im), b=5.15)
1-element Array{Any,1}:
 :($(Expr(:lambda, Any[:x], Any[Any[Any[:x,Int64,0]],Any[],Any[Int64],Any[]], :(begin $(Expr(:line, 1, :none, symbol("")))
        GenSym(0) = (Base.power_by_squaring)(x::Int64,2)::Int64
        return (Base.box)(Int64,(Base.mul_int)(1,GenSym(0)))::Int64
    end::Int64))))
Obviously, that specific call to f does NOT return an Int64.
I know of only two reasonable ways to handle it at the moment:
1. If you're the method author: Restrict every keyword argument to a
declared, concrete type, which ensures that the argument isn't
type-variadic. Yichao basically gave an example of this.
2. If you're the method caller: Manually assert the return type. You can do
this pretty easily in most cases using a wrapper function.
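Option 1 can be sketched in current (1.x) Julia syntax roughly as follows. `f_typed` and `try_bad_call` are hypothetical names used only for illustration, and pinning both keyword arguments to `Int` is an assumption chosen to match the defaults of `f`; a real method would pick whatever concrete types suit its callers.

```julia
# Hypothetical variant of f whose keyword arguments are pinned to concrete types.
# Because a, b and x can only be Int, every call that succeeds returns an Int,
# so the result type does not depend on which keyword values are passed.
f_typed(x::Int; a::Int=1, b::Int=2) = a * x^b

# A keyword value of the wrong type is rejected with a TypeError
# instead of silently producing a different return type.
function try_bad_call()
    try
        f_typed(1; a=3.0)
        return :no_error
    catch err
        return err isa TypeError ? :type_error : :other_error
    end
end
```

The trade-off is flexibility: callers who want to pass, say, a complex `a` now need a separate method.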
For option 2, using `f` from above as an example:
julia> g{X,A,B}(x::X, a::A, b::B) = f(x, a=a, b=b)::promote_type(X, A, B)
g (generic function with 2 methods)
julia> @code_typed g(1,2,3)
1-element Array{Any,1}:
 :($(Expr(:lambda, Any[:x,:a,:b], Any[Any[Any[:x,Int64,0],Any[:a,Int64,0],Any[:b,Int64,0]],Any[],Any[Int64],Any[:X,:A,:B]], :(begin  # none, line 1:
        return (top(typeassert))((top(kwcall))((top(getfield))(Main,:call)::F,2,:a,a::Int64,:b,b::Int64,Main.f,(top(ccall))(:jl_alloc_array_1d,(top(apply_type))(Base.Array,Any,1)::Type{Array{Any,1}},(top(svec))(Base.Any,Base.Int)::SimpleVector,Array{Any,1},0,4,0)::Array{Any,1},x::Int64),Int64)::Int64
    end::Int64))))
julia> @code_typed g(1,2,3.0)
1-element Array{Any,1}:
 :($(Expr(:lambda, Any[:x,:a,:b], Any[Any[Any[:x,Int64,0],Any[:a,Int64,0],Any[:b,Float64,0]],Any[],Any[Int64],Any[:X,:A,:B]], :(begin  # none, line 1:
        return (top(typeassert))((top(kwcall))((top(getfield))(Main,:call)::F,2,:a,a::Int64,:b,b::Float64,Main.f,(top(ccall))(:jl_alloc_array_1d,(top(apply_type))(Base.Array,Any,1)::Type{Array{Any,1}},(top(svec))(Base.Any,Base.Int)::SimpleVector,Array{Any,1},0,4,0)::Array{Any,1},x::Int64),Float64)::Float64
    end::Float64))))
julia> @code_typed g(1,2,3.0+im)
1-element Array{Any,1}:
 :($(Expr(:lambda, Any[:x,:a,:b], Any[Any[Any[:x,Int64,0],Any[:a,Int64,0],Any[:b,Complex{Float64},0]],Any[],Any[Int64],Any[:X,:A,:B]], :(begin  # none, line 1:
        return (top(typeassert))((top(kwcall))((top(getfield))(Main,:call)::F,2,:a,a::Int64,:b,b::Complex{Float64},Main.f,(top(ccall))(:jl_alloc_array_1d,(top(apply_type))(Base.Array,Any,1)::Type{Array{Any,1}},(top(svec))(Base.Any,Base.Int)::SimpleVector,Array{Any,1},0,4,0)::Array{Any,1},x::Int64),Complex{Float64})::Complex{Float64}
    end::Complex{Float64}))))
Thus, downstream functions can call `f` through `g`, preventing type
instability from "bubbling up" to the calling methods (as it would if
they called `f` directly).
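For reference, the same caller-side wrapper written in Julia 1.x syntax (a `where` clause replaces the `g{X,A,B}` form) might look like the sketch below; `g_modern` is a hypothetical name. The assertion encodes an assumption, namely that `f` returns a value of type `promote_type(X, A, B)`. That holds for the calls shown above but need not hold in general, and when it does not, the `::` assertion throws a `TypeError` rather than letting an unexpected type through.

```julia
# Same f as in the transcript above.
f(x; a=1, b=2) = a * x^b

# Caller-side wrapper: the return-type assertion gives downstream code a
# concrete, inferable result type, assuming f's result really has type
# promote_type(X, A, B). If it does not, the assertion throws a TypeError.
g_modern(x::X, a::A, b::B) where {X,A,B} = f(x; a=a, b=b)::promote_type(X, A, B)

g_modern(1, 2, 3)            # Int result
g_modern(2, 2, 3.0)          # Float64 result
g_modern(1, 3 + im, 5.15)    # ComplexF64 result, matching f(1, a=(3+im), b=5.15)
```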
Best,
Jarrett
On Tuesday, September 1, 2015 at 8:39:11 AM UTC-4, Michael Francis wrote:
>
> 2) The underlying functions are only stable if the mean passed to them is
> of the correct type, e.g. a number. Essentially this is a type inference
> issue: if the compiler were able to optimize the branches, it would
> likely be OK, but from the LLVM code it looks like that is not the case today.
>
> FWIW, using a type-stable version (e.g. directly calling covm) looks to be
> about 18% faster for small (100-element) AbstractArray pairs.
>
> On Monday, August 31, 2015 at 9:06:58 PM UTC-4, Sisyphuss wrote:
>>
>> IMO:
>> 1) This is called a keyword argument (not a named optional argument).
>> 2) The returned value depends only on `corzm` and `corm`. If these two
>> functions are type-stable, then `cor` is type-stable.
>> 3) I'm not sure whether this is the "correct" way to write this function.
>>
>> On Monday, August 31, 2015 at 11:48:37 PM UTC+2, Michael Francis wrote:
>>>
>>> The following is taken from statistics.jl, line 428:
>>>
>>> function cor(x::AbstractVector, y::AbstractVector; mean=nothing)
>>> mean == 0 ? corzm(x, y) :
>>> mean == nothing ? corm(x, Base.mean(x), y, Base.mean(y)) :
>>> isa(mean, (Number,Number)) ? corm(x, mean[1], y, mean[2]) :
>>> error("Invalid value of mean.")
>>> end
>>>
>>> Because `mean` initially has type `Nothing`, I am unable to infer the
>>> return type of the function; the following returns `Any` as the
>>> return type:
>>>
>>> rt = {}
>>> for x in Base._methods(f,types,-1)
>>> linfo = x[3].func.code
>>> (tree, ty) = Base.typeinf(linfo, x[1], x[2])
>>> push!(rt, ty)
>>> end
>>>
>>> Each of the underlying functions is type-stable when called directly.
>>>
>>> The lowered code doesn't give much of a hint about what will actually
>>> happen here:
>>>
>>> julia> code_lowered( cor, ( Vector{Float64}, Vector{Float64} ) )
>>> 1-element Array{Any,1}:
>>> :($(Expr(:lambda, {:x,:y}, {{},{{:x,:Any,0},{:y,:Any,0}},{}}, :(begin $(Expr(:line, 429, symbol("statistics.jl"), symbol("")))
>>> return __cor#195__(nothing,x,y)
>>> end))))
>>>
>>>
>>> If I rewrite it with a regular optional argument for the mean:
>>>
>>> code_lowered( cordf, ( Vector{Float64}, Vector{Float64}, Nothing ) )
>>> 1-element Array{Any,1}:
>>> :($(Expr(:lambda, {:x,:y,:mean}, {{},{{:x,:Any,0},{:y,:Any,0},{:mean,:Any,0}},{}}, :(begin # none, line 2:
>>> unless mean == 0 goto 0
>>> return corzm(x,y)
>>> 0:
>>> unless mean == nothing goto 1
>>> return corm(x,((top(getfield))(Base,:mean))(x),y,((top(getfield))(Base,:mean))(y))
>>> 1:
>>> unless isa(mean,(top(tuple))(Number,Number)) goto 2
>>> return corm(x,getindex(mean,1),y,getindex(mean,2))
>>> 2:
>>> return error("Invalid value of mean.")
>>> end))))
>>>
>>> The LLVM code does not look very clean. If I use a concrete type for the
>>> mean (say Float64), it looks better: 88 lines vs. 140.
>>>
>>> julia> code_llvm( cordf, ( Vector{Float64}, Vector{Float64}, Nothing ) )
>>>
>>>
>>> define %jl_value_t* @julia_cordf_20322(%jl_value_t*, %jl_value_t*, %jl_value_t*) {
>>> top:
>>> %3 = alloca [7 x %jl_value_t*], align 8
>>> %.sub = getelementptr inbounds [7 x %jl_value_t*]* %3, i64 0, i64 0
>>> %4 = getelementptr [7 x %jl_value_t*]* %3, i64 0, i64 2, !dbg !949
>>> store %jl_value_t* inttoptr (i64 10 to %jl_value_t*), %jl_value_t** %.sub, align 8
>>> %5 = getelementptr [7 x %jl_value_t*]* %3, i64 0, i64 1, !dbg !949
>>> %6 = load %jl_value_t*** @jl_pgcstack, align 8, !
>>> ...
>>
>>