dfdx / Espresso.jl

Expression transformation package


Size hints for derivatives with undefined size

dfdx opened this issue:

ds = rdiff(:(sum(W * x + b)); ctx=[:outfmt => :ein], W = rand(2,2), x=rand(2), b=rand(2))
ds[:W] 

Currently, this produces:

dtmp6_dW[m,n] = x[n]

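This expression correctly gives the derivative of the output with respect to each component of W. However, it doesn't define the size of dtmp6_dW: the index m appears only on the left-hand side, so nothing in the expression determines its range. We therefore need some way to pass size hints and use them when converting to vectorized form (in this case, to repmat).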
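To see what the size hint supplies, here is a small check in plain Julia, independent of Espresso.jl. It assumes the shapes from the call above (W is 2×2, x and b have length 2) and treats `size(W)` as the hint. The names `size_hint` and `unit` are hypothetical. The code builds dtmp6_dW elementwise from the Einstein expression and builds the vectorized form with `repeat`, which is the current name for `repmat` (deprecated in Julia 0.7). It then compares both against a finite-difference derivative of `sum(W * x + b)`.

```julia
# Shapes assumed from the rdiff call in the issue.
W = rand(2, 2)
x = rand(2)
b = rand(2)

# Hypothetical size hint: the derivative of a scalar w.r.t. W has W's shape.
size_hint = size(W)

# Elementwise meaning of dtmp6_dW[m,n] = x[n]; the range of m comes only
# from the hint, not from the expression itself.
dW_ein = [x[n] for m in 1:size_hint[1], n in 1:size_hint[2]]

# Vectorized form: stack x' once per row (older Julia: repmat(x', rows, 1)).
dW = repeat(permutedims(x), size_hint[1], 1)

# Finite-difference check of y = sum(W * x + b) w.r.t. each W[m, n].
f(Wm) = sum(Wm * x + b)
h = 1e-6
function unit(m, n)
    E = zeros(size(W))   # hypothetical helper: basis matrix with a 1 at (m, n)
    E[m, n] = 1.0
    return E
end
dW_fd = [(f(W + h * unit(m, n)) - f(W)) / h for m in 1:2, n in 1:2]
```

Because `sum(W * x + b)` is linear in W, the forward difference matches `x[n]` up to rounding error.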

The current plan is:

  1. Create a new function transfer_size(primitive_expression) -> size_expression.
  2. Call it during the forward pass and save the results into rdiff's context.
  3. Use the saved sizes in from_einstein (perhaps with special syntax in templates?).
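Steps 1 and 2 can be sketched as follows. This sketch rests on several assumptions. The size rules inside `transfer_size` are illustrative only, and the `ctx[:sizes]` dictionary merely stands in for rdiff's context. The intermediate names `tmp4` and `tmp5` are also hypothetical. None of this is Espresso.jl's actual implementation. The forward pass is written out by hand rather than produced by rdiff, and step 3 is only indicated.

```julia
# Hypothetical sketch of step 1: return an expression that computes the size
# of a primitive's result. These rules are illustrative, not Espresso.jl's API.
function transfer_size(ex::Expr)
    ex.head == :call || error("expected a call expression, got $ex")
    op, args = ex.args[1], ex.args[2:end]
    if op == :* && length(args) == 2
        A, B = args
        # Matrix product: rows of A, then the trailing dims of B
        # (none for a vector B, the column count for a matrix B).
        return :((size($A, 1), size($B)[2:end]...))
    elseif op in (:+, :-) && length(args) == 2
        # Elementwise op on equal shapes: same size as the first operand.
        return :(size($(args[1])))
    elseif op == :sum && length(args) == 1
        # Full reduction to a scalar, whose size is the empty tuple.
        return :(())
    end
    error("no size rule for $ex")
end

# Step 2 (sketch): walk a hand-written forward pass, evaluate each assignment
# and record its size expression in a context dictionary.
W = rand(2, 2)
x = rand(2)
b = rand(2)
forward = [:(tmp4 = W * x), :(tmp5 = tmp4 + b), :(tmp6 = sum(tmp5))]
ctx = Dict{Symbol,Any}(:sizes => Dict{Symbol,Any}())
for assign in forward
    var, rhs = assign.args
    ctx[:sizes][var] = transfer_size(rhs)
    eval(assign)  # defines tmp4, tmp5, tmp6 as globals in Main
end

# Step 3 would consult ctx[:sizes] inside from_einstein; here we only
# evaluate the recorded size expressions to show what they yield.
for (var, sz) in ctx[:sizes]
    println(var, " => ", sz, " = ", eval(sz))
end
```

Storing size *expressions* rather than concrete tuples keeps the generated vectorized code valid for inputs of other shapes. For example, the hint `size(W, 1)` could become the row count argument of a `repeat` call.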