Greatly inspired by Andrej Karpathy's micrograd
-
utils and visualizations
-
Implement all operations:
Relu, Log, Exp # unary ops
Sum, Max # reduce ops (with axis argument)
Add, Sub, Mul, Pow # binary ops (with broadcasting)
Reshape, Transpose, Slice # movement ops
Matmul, Conv2D # processing ops
Let's say we have the expression z = x1*x2 + sin(x1) and want to find the derivatives dz/dx1 and dz/dx2. Reverse-mode AD splits this task into two parts: a forward pass and a reverse pass.
The first step is to decompose the complex expression into a set of primitive ones, i.e. expressions consisting of at most a single operation or a single function call.
w1 = x1
w2 = x2
w3 = w1 * w2
w4 = sin(w1)
w5 = w3 + w4
z = w5
The advantage of this representation is that differentiation rules for each separate expression are already known.
For example, we know that the derivative of sin is cos, and so dw4/dw1 = cos(w1).
We will use this fact in the reverse pass below. Essentially, the forward pass consists of evaluating each of these expressions and saving the results.
Say our inputs are x1 = 2 and x2 = 3. Then we have:
w1 = x1 = 2
w2 = x2 = 3
w3 = w1 * w2 = 6
w4 = sin(w1) ≈ 0.909
w5 = w3 + w4 ≈ 6.909
z = w5 ≈ 6.909
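The forward pass above can be reproduced with a few lines of Python. This is only a minimal sketch: the function name `forward` and the dictionary of saved intermediates are illustrative choices, not part of this project's API.

```python
import math

def forward(x1, x2):
    """Evaluate z = x1*x2 + sin(x1) one primitive at a time.

    Every intermediate value is kept, because the reverse pass
    needs them to evaluate the local derivatives.
    """
    w1 = x1
    w2 = x2
    w3 = w1 * w2
    w4 = math.sin(w1)
    w5 = w3 + w4
    z = w5
    return {"w1": w1, "w2": w2, "w3": w3, "w4": w4, "w5": w5, "z": z}

values = forward(2.0, 3.0)
for name, value in values.items():
    print(f"{name} = {value:.3f}")
```

Note that sin is evaluated in radians here, which is what gives sin(2) ≈ 0.909.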
The reverse pass is the main part, and it uses the chain rule.
In its basic form, the chain rule states that if you have a variable t(u(v)) which depends on u which, in turn, depends on v, then:
dt/dv = dt/du * du/dv
or, if t depends on v via several paths / variables u_i, e.g.:
u1 = f(v)
u2 = g(v)
t = h(u1, u2)
then:
dt/dv = Σ_i dt/du_i * du_i/dv
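The multi-path form of the chain rule is easy to check numerically. The sketch below picks arbitrary concrete functions for illustration (f(v) = v^2, g(v) = sin(v), h(u1, u2) = u1 * u2; none of these come from the text above) and compares the chain-rule sum against a central finite difference.

```python
import math

# Hypothetical concrete choices for f, g and h, used only for this check.
def f(v):
    return v ** 2

def g(v):
    return math.sin(v)

def h(u1, u2):
    return u1 * u2

def dt_dv_chain(v):
    """dt/dv as the sum over both paths: dt/du1 * du1/dv + dt/du2 * du2/dv."""
    u1, u2 = f(v), g(v)
    dt_du1 = u2            # d(u1*u2)/du1
    dt_du2 = u1            # d(u1*u2)/du2
    du1_dv = 2 * v         # d(v^2)/dv
    du2_dv = math.cos(v)   # d(sin v)/dv
    return dt_du1 * du1_dv + dt_du2 * du2_dv

def dt_dv_numeric(v, eps=1e-6):
    """Central finite-difference approximation of dt/dv."""
    t_plus = h(f(v + eps), g(v + eps))
    t_minus = h(f(v - eps), g(v - eps))
    return (t_plus - t_minus) / (2 * eps)

v = 1.3
print(dt_dv_chain(v), dt_dv_numeric(v))
```

The two printed values should agree to several decimal places; dropping either term of the sum breaks the agreement.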
In terms of the expression graph, if we have a final node z and input nodes w_i, and the path from z to w_i goes through intermediate nodes w_p (i.e. z = g(w_p) where w_p = f(w_i)), we can find the derivative dz/dw_i as
dz/dw_i = Σ_{p ∈ Parents(i)} dz/dw_p * dw_p/dw_i
In other words, to calculate the derivative of the output variable z w.r.t. any intermediate or input variable w_i, we only need to know the derivatives of its parents and the formula to calculate the derivative of the primitive expression w_p = f(w_i).
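The parents formula can be turned almost literally into a table-driven loop. The sketch below assumes a hand-written description of the example graph (the `parents` and `local_grad` tables and the `reverse_pass` function are hypothetical names used for illustration). As in the text, a "parent" of w_i is a node that consumes w_i, i.e. one step closer to z.

```python
import math

def reverse_pass(values, parents, local_grad, order, output="z"):
    """Accumulate dz/dw_i = sum over parents p of dz/dw_p * dw_p/dw_i."""
    grad = {name: 0.0 for name in order}
    grad[output] = 1.0  # dz/dz
    # Walk from the output back to the inputs, so every parent's
    # derivative is complete before it is used.
    for node in reversed(order):
        for p in parents.get(node, []):
            grad[node] += grad[p] * local_grad[(p, node)](values)
    return grad

# Saved forward-pass values for x1 = 2, x2 = 3.
values = {"w1": 2.0, "w2": 3.0}
values["w3"] = values["w1"] * values["w2"]
values["w4"] = math.sin(values["w1"])
values["w5"] = values["w3"] + values["w4"]
values["z"] = values["w5"]

order = ["w1", "w2", "w3", "w4", "w5", "z"]  # topological order, inputs first
parents = {"w1": ["w3", "w4"], "w2": ["w3"], "w3": ["w5"], "w4": ["w5"], "w5": ["z"]}
local_grad = {  # (parent, node) -> d parent / d node
    ("z", "w5"): lambda v: 1.0,
    ("w5", "w3"): lambda v: 1.0,
    ("w5", "w4"): lambda v: 1.0,
    ("w3", "w1"): lambda v: v["w2"],
    ("w3", "w2"): lambda v: v["w1"],
    ("w4", "w1"): lambda v: math.cos(v["w1"]),
}

grad = reverse_pass(values, parents, local_grad, order)
print(grad["w1"], grad["w2"])
```

Each node is visited once and each edge contributes one multiply-add, which is why a single reverse pass yields the derivatives with respect to all inputs.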
The reverse pass starts at the end (i.e. dz/dz) and propagates backward to all dependencies.
dz/dz = 1
Then we know that z = w5 and so:
dz/dw5 = 1
w5 depends linearly on w3 and w4, so dw5/dw3 = 1 and dw5/dw4 = 1. Using the chain rule we find:
dz/dw3 = dz/dw5 × dw5/dw3 = 1 × 1 = 1
dz/dw4 = dz/dw5 × dw5/dw4 = 1 × 1 = 1
From the definition w3 = w1 * w2 and the rules of partial derivatives, we find that dw3/dw2 = w1. Thus:
dz/dw2 = dz/dw3 × dw3/dw2 = 1 × w1 = w1
Which, as we already know from the forward pass, is:
dz/dw2 = w1 = 2
Finally, w1 contributes to z via w3 and w4. Once again, from the rules of partial derivatives we know that dw3/dw1 = w2 and dw4/dw1 = cos(w1). Thus:
dz/dw1 = dz/dw3 * dw3/dw1 + dz/dw4 * dw4/dw1 = w2 + cos(w1)
And again, given the known inputs, we can calculate it:
dz/dw1 = w2 + cos(w1) = 3 + cos(2) ≈ 2.58
Since w1 and w2 are just aliases for x1 and x2, we get our answer:
dz/dx1 ≈ 2.58
dz/dx2 = 2
And all is done for the given expression!
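The whole procedure can be automated with operator overloading, in the spirit of micrograd. The class below is a minimal scalar sketch written for this example only: the names `Value`, `_children`, `_backward` and the `sin` method are assumptions for illustration and are not claimed to match micrograd's or this project's actual API.

```python
import math

class Value:
    """A scalar that records how it was computed so gradients can flow back."""

    def __init__(self, data, children=()):
        self.data = data
        self.grad = 0.0
        self._children = children
        self._backward = lambda: None  # leaf nodes have nothing to propagate

    def __add__(self, other):
        out = Value(self.data + other.data, (self, other))
        def _backward():
            self.grad += out.grad   # d(a+b)/da = 1
            other.grad += out.grad  # d(a+b)/db = 1
        out._backward = _backward
        return out

    def __mul__(self, other):
        out = Value(self.data * other.data, (self, other))
        def _backward():
            self.grad += other.data * out.grad  # d(a*b)/da = b
            other.grad += self.data * out.grad  # d(a*b)/db = a
        out._backward = _backward
        return out

    def sin(self):
        out = Value(math.sin(self.data), (self,))
        def _backward():
            self.grad += math.cos(self.data) * out.grad  # d(sin a)/da = cos a
        out._backward = _backward
        return out

    def backward(self):
        # Build a topological order, then apply the chain rule in reverse.
        order, seen = [], set()
        def visit(node):
            if node not in seen:
                seen.add(node)
                for child in node._children:
                    visit(child)
                order.append(node)
        visit(self)
        self.grad = 1.0  # dz/dz
        for node in reversed(order):
            node._backward()

x1, x2 = Value(2.0), Value(3.0)
z = x1 * x2 + x1.sin()
z.backward()
print(round(x1.grad, 2), round(x2.grad, 2))  # 2.58 2.0
```

Gradients are accumulated with `+=` rather than assigned, which is what lets x1 collect contributions from both the product and the sine.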