Gradient of reshape(::Array{Bool}, ...) does not handle thunks #1567

Open
mcabbott opened this issue Apr 2, 2025 · 1 comment · May be fixed by FluxML/ZygoteRules.jl#35
Labels
bug Something isn't working

Comments

mcabbott (Member) commented Apr 2, 2025

(Edit: see the Zygote-only MWE in the comment below.)

Originally:

julia> using Flux

julia> let e = Embedding(2=>2)
           x = Flux.onehotbatch([1 2; 2 1], 1:2)
           # x = Array(x)  # similar error with Array or OneHotArray
           Flux.gradient(m -> sum(abs2, m(x)), e)
       end
ERROR: MethodError: no method matching reshape(::Nothing, ::Tuple{Int64, Int64, Int64})

Closest candidates are:
  reshape(::ChainRulesCore.AbstractThunk, ::Any...)
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/U6wNx/src/tangent_types/thunks.jl:62
  reshape(::Array{T, M}, ::NTuple{N, Int64}) where {T, N, M}
   @ Base reshapedarray.jl:40
  reshape(::BitArray{N}, ::NTuple{N, Int64}) where N
   @ Base bitarray.jl:479
  ...

Stacktrace:
  [1] (::Zygote.var"#617#621"{OneHotArrays.OneHotArray{UInt32, 2, 3, Matrix{UInt32}}, Tuple{Int64, Colon}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/aLHDR/src/lib/array.jl:107
  [2] (::Zygote.var"#2783#back#623"{Zygote.var"#617#621"{}})(Δ::ChainRulesCore.Thunk{ChainRules.var"#546#549"{…}})
    @ Zygote ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:72
  [3] Embedding
    @ ~/.julia/packages/Flux/3711C/src/layers/basic.jl:776 [inlined]
  [4] (::Zygote.Pullback{…})(Δ::ChainRulesCore.InplaceableThunk{…})
    @ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface2.jl:0
  [5] #197
    @ ./REPL[408]:4 [inlined]
  [6] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface2.jl:0
  [7] (::Zygote.var"#88#89"{Zygote.Pullback{Tuple{}, Tuple{}}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface.jl:97
  [8] gradient(f::Function, args::Embedding{Matrix{Float32}})
    @ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface.jl:154
  [9] #gradient#1
    @ ~/.julia/packages/Flux/3711C/src/gradient.jl:44 [inlined]
 [10] gradient(f::Function, args::Embedding{Matrix{Float32}})
    @ Flux ~/.julia/packages/Flux/3711C/src/gradient.jl:31
 [11] top-level scope
    @ REPL[408]:4
Some type information was truncated. Use `show(err)` to see complete types.

(@v1.11) pkg> st Flux Zygote
Status `~/.julia/environments/v1.11/Project.toml`
  [587475ba] Flux v0.16.3
  [e88e6eb3] Zygote v0.7.5

I presume the problem is thunks in Zygote 0.7, since this works fine on earlier versions:

julia> let e = Embedding(2=>2)
           x = Flux.onehotbatch([1 2; 2 1], 1:2)
           # x = Array(x)
           Flux.gradient(m -> sum(abs2, m(x)), e)
       end
((weight = Float32[6.834647 3.3733022; 5.7237077 0.9229657],),)

julia> let e = Embedding(2=>2)
           x = Flux.onehotbatch([1 2; 2 1], 1:2)
           x = Array(x)
           Flux.gradient(m -> sum(abs2, m(x)), e)
       end
((weight = Float32[1.961737 -1.5491782; -0.6510874 11.824801],),)

(jl_ZbRV0D) pkg> st
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_ZbRV0D/Project.toml`
⌃ [587475ba] Flux v0.14.25
⌅ [e88e6eb3] Zygote v0.6.75

Edit: with Dense, the problem occurs only with OneHotArray, not with Array:

julia> let d = Dense(2=>2)
           x = Flux.onehotbatch([1 2; 2 1], 1:2)
           x = Array(x)
           Flux.gradient(m -> sum(abs2, m(x)), d)
       end
((weight = Float32[0.6652966 -3.0755887; 1.8529012 2.833063], bias = Float32[-2.4102921, 4.685964], σ = nothing),)

julia> let d = Dense(2=>2)
           x = Flux.onehotbatch([1 2; 2 1], 1:2)
           # x = Array(x)
           Flux.gradient(m -> sum(abs2, m(x)), d)
       end
ERROR: MethodError: no method matching reshape(::Nothing, ::Tuple{Int64, Int64, Int64})
The function `reshape` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  reshape(::ChainRulesCore.AbstractThunk, ::Any...)
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/U6wNx/src/tangent_types/thunks.jl:62
  reshape(::Array{T, M}, ::NTuple{N, Int64}) where {T, N, M}
   @ Base reshapedarray.jl:40
  reshape(::BitArray{N}, ::NTuple{N, Int64}) where N
   @ Base bitarray.jl:479
  ...

Stacktrace:
  [1] (::Zygote.var"#617#621"{OneHotArrays.OneHotArray{UInt32, 2, 3, Matrix{UInt32}}, Tuple{Int64, Colon}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/aLHDR/src/lib/array.jl:107
  [2] (::Zygote.var"#2783#back#623"{Zygote.var"#617#621"{}})(Δ::ChainRulesCore.Thunk{ChainRules.var"#546#549"{…}})
    @ Zygote ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:72
  [3] Dense
    @ ~/.julia/packages/Flux/3711C/src/layers/basic.jl:204 [inlined]
  [4] (::Zygote.Pullback{…})(Δ::ChainRulesCore.InplaceableThunk{…})
    @ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface2.jl:0
  [5] #215
    @ ./REPL[419]:4 [inlined]
  [6] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface2.jl:0
@mcabbott mcabbott added the bug Something isn't working label Apr 2, 2025
mcabbott (Member, Author) commented Apr 3, 2025

MWE with only Zygote:

julia> using Zygote

julia> let x = rand(Bool, 12)
           w = rand(Float32, 4, 3)
           gradient(w -> sum(w * reshape(x, 3, 4)), w)
       end
ERROR: MethodError: no method matching reshape(::Nothing, ::Tuple{Int64})
  ...
Stacktrace:
 [1] (::Zygote.var"#617#621"{Vector{Bool}, Tuple{Int64, Int64}})(Δ::Nothing)
   @ Zygote ~/.julia/packages/Zygote/aLHDR/src/lib/array.jl:107
 [2] (::Zygote.var"#2783#back#623"{Zygote.var"#617#621"{}})(Δ::ChainRulesCore.Thunk{ChainRules.var"#546#549"{…}})
   @ Zygote ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:72
 [3] #43
   @ ./REPL[30]:3 [inlined]
 [4] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
   @ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface2.jl:0
 [5] (::Zygote.var"#88#89"{Zygote.Pullback{Tuple{}, Tuple{}}})(Δ::Float32)
   @ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface.jl:97
 [6] gradient(f::Function, args::Matrix{Float32})
   @ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface.jl:154
 [7] top-level scope
   @ REPL[30]:3
Some type information was truncated. Use `show(err)` to see complete types.

(@v1.11) pkg> st Zygote
Status `~/.julia/environments/v1.11/Project.toml`
  [e88e6eb3] Zygote v0.7.5

Fails on 0.7.0, but worked before:

julia> let x = rand(Bool, 12)
           w = rand(Float32, 4, 3)
           gradient(w -> sum(w * reshape(x, 3, 4)), w)
       end
(Float32[2.0 2.0 2.0; 2.0 2.0 2.0; 2.0 2.0 2.0; 2.0 2.0 2.0],)

(jl_ZbRV0D) pkg> st Zygote
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_ZbRV0D/Project.toml`
⌅ [e88e6eb3] Zygote v0.6.75

The stacktrace points to this rule:

Zygote.jl/src/lib/array.jl, lines 106–107 at commit 1b914d9:

@adjoint reshape(xs, dims...) = reshape(xs, dims...),
  Δ -> (reshape(Δ, size(xs)), map(_ -> nothing, dims)...)

PR #966 introduced @_adjoint_keepthunks, after which a plain @adjoint is supposed not to keep thunks. The Thunk is indeed being converted to nothing, but perhaps too late to prevent the backward function from being called on it at all?

The relevant code from FluxML/ZygoteRules.jl#17 is these lines, which generate:

back(::Nothing) = nothing
back(Δ) = $gradtuple(_back(unthunk_tangent(Δ)))

instead of something like:

back(::Nothing) = nothing
back(Δ::AbstractThunk) = back(unthunk_tangent(Δ))
back(Δ) = $gradtuple(_back(Δ))
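To illustrate why re-dispatching matters (a minimal sketch in Python rather than Julia, with all names hypothetical stand-ins for the generated code above): in the current scheme, a thunk that unthunks to nothing is still forwarded to the inner pullback, whereas unthunking first and then re-dispatching lets the nothing short-circuit apply to the unthunked value too.

```python
class Thunk:
    """Minimal stand-in for ChainRulesCore.AbstractThunk: a delayed tangent."""
    def __init__(self, compute):
        self.compute = compute

def unthunk(delta):
    # Force a delayed tangent; pass anything else through unchanged.
    return delta.compute() if isinstance(delta, Thunk) else delta

def _back(delta):
    # Inner pullback: assumes delta is a real array-like tangent.
    return [d * 2 for d in delta]  # placeholder gradient computation

def back_current(delta):
    # Mirrors the current generated code: unthunks, but never re-checks
    # for None, so a thunk that unthunks to None still reaches _back.
    if delta is None:
        return None
    return _back(unthunk(delta))

def back_proposed(delta):
    # Proposed scheme: unwrap thunks first, then re-dispatch, so the
    # None short-circuit also applies to the unthunked value.
    if delta is None:
        return None
    if isinstance(delta, Thunk):
        return back_proposed(unthunk(delta))
    return _back(delta)

# A tangent for a non-differentiable (e.g. Bool) array: a thunk of nothing.
nil_thunk = Thunk(lambda: None)
print(back_proposed(nil_thunk))  # None: short-circuits cleanly
try:
    back_current(nil_thunk)
except TypeError as e:
    print("current scheme raises:", e)
```

The analogue of the MethodError in the issue is the TypeError raised when the inner pullback receives None instead of an array.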

@mcabbott mcabbott changed the title Gradient of Embedding applied to 3D array fails Gradient of reshape(::Array{Bool}, ...) does not handle thunks Apr 3, 2025
@mcabbott mcabbott transferred this issue from FluxML/Flux.jl Apr 3, 2025