Gradient of reshape(::Array{Bool}, ...) does not handle thunks #1567
MWE with only Zygote:

julia> using Zygote

julia> let x = rand(Bool, 12)
           w = rand(Float32, 4, 3)
           gradient(w -> sum(w * reshape(x, 3, 4)), w)
       end
ERROR: MethodError: no method matching reshape(::Nothing, ::Tuple{Int64})
...
Stacktrace:
[1] (::Zygote.var"#617#621"{Vector{Bool}, Tuple{Int64, Int64}})(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/aLHDR/src/lib/array.jl:107
[2] (::Zygote.var"#2783#back#623"{Zygote.var"#617#621"{…}})(Δ::ChainRulesCore.Thunk{ChainRules.var"#546#549"{…}})
@ Zygote ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:72
[3] #43
@ ./REPL[30]:3 [inlined]
[4] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface2.jl:0
[5] (::Zygote.var"#88#89"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface.jl:97
[6] gradient(f::Function, args::Matrix{Float32})
@ Zygote ~/.julia/packages/Zygote/aLHDR/src/compiler/interface.jl:154
[7] top-level scope
@ REPL[30]:3
Some type information was truncated. Use `show(err)` to see complete types.
(@v1.11) pkg> st Zygote
Status `~/.julia/environments/v1.11/Project.toml`
[e88e6eb3] Zygote v0.7.5

It also fails on v0.7.0, but worked before:

julia> let x = rand(Bool, 12)
           w = rand(Float32, 4, 3)
           gradient(w -> sum(w * reshape(x, 3, 4)), w)
       end
(Float32[2.0 2.0 2.0; 2.0 2.0 2.0; 2.0 2.0 2.0; 2.0 2.0 2.0],)
(jl_ZbRV0D) pkg> st Zygote
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_ZbRV0D/Project.toml`
⌅ [e88e6eb3] Zygote v0.6.75

The stacktrace points to this rule (src/lib/array.jl, lines 106 to 107 at commit 1b914d9):

PR #966 introduced:

The relevant code from FluxML/ZygoteRules.jl#17 is in these lines, which produce:

instead of something like:
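One plausible reading of the stacktrace: frame [2] (the ZygoteRules wrapper) receives a ChainRulesCore.Thunk, while frame [1] (the reshape rule's pullback) receives Δ::Nothing. If the wrapper skips `nothing` gradients by dispatching on the incoming value, a thunk that only evaluates to `nothing` is not skipped; it is forced and handed to the rule, which then calls reshape(nothing, ...). The sketch below is a toy model of that ordering problem in plain Julia, assuming this reading is right. It has no Zygote dependency, and `ToyThunk`, `force`, `rule_pullback`, `broken_back` and `fixed_back` are hypothetical names, not the actual ZygoteRules or ChainRulesCore code.

```julia
# Toy model only: every name here is hypothetical, not Zygote/ChainRulesCore API.

# A minimal lazy value, standing in for ChainRulesCore.Thunk.
struct ToyThunk{F}
    f::F
end
force(t::ToyThunk) = t.f()   # evaluate the deferred computation
force(x) = x                 # non-thunks pass through unchanged

# Stand-in for the reshape rule's pullback: it assumes Δ is an array.
const xs = rand(Bool, 12)
rule_pullback(Δ) = reshape(Δ, size(xs))

# Broken ordering: the `nothing` method is chosen only for a literal `nothing`,
# so a thunk that evaluates to `nothing` is forced and passed to the rule.
broken_back(::Nothing) = nothing
broken_back(Δ) = rule_pullback(force(Δ))

# Safer ordering: force first, then short-circuit on `nothing`.
function fixed_back(Δ)
    d = force(Δ)
    return d === nothing ? nothing : rule_pullback(d)
end

# A thunk whose value is `nothing`, like the Δ seen in frames [1] and [2].
t = ToyThunk(() -> nothing)

err = try
    broken_back(t)
    nothing
catch e
    e
end
println(err isa MethodError)   # true: no method matching reshape(::Nothing, ::Tuple{Int64})

fixed_back(t)                                    # returns nothing
fixed_back(ToyThunk(() -> ones(Float32, 3, 4)))  # 12-element Vector{Float32}
```

In the toy, forcing before the `nothing` check makes the Bool case return `nothing` instead of throwing. Whether a real fix belongs in the ZygoteRules wrapper or in individual rules such as reshape is a separate question.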
Title changed from "Embedding applied to 3D array fails" to "Gradient of reshape(::Array{Bool}, ...) does not handle thunks".
Originally:

Edit: see the Zygote-only MWE.

I presume the problem is Zygote 0.7 and thunks, as it works fine on earlier versions:

Edit: with Dense, the problem occurs only with OneHotArray, not with Array: