Reverse Mode Auto Differentiation is Kind of Like a Lens

Edit: A more cogent version is here: http://www.philipzucker.com/reverse-mode-differentiation-is-kind-of-like-a-lens-ii/

Warning: I’m using sketchy uncompiled Haskell pseudocode.

Automatic differentiation means writing a function that computes its derivative alongside its value. When two such functions are composed, the value parts compose as usual while the chain rule is applied to the derivative parts.

One way to do this is to use “dual numbers”: functions now take and return a tuple of a value and its derivative.
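A minimal scalar sketch of the dual number idea. The names Dual, sinD and diff are hypothetical, and this version works on single Doubles rather than the hmatrix vectors used elsewhere in the post:

```haskell
-- Dual x dx carries a value x together with its derivative dx.
data Dual = Dual Double Double deriving Show

instance Num Dual where
  Dual x dx + Dual y dy = Dual (x + y) (dx + dy)
  Dual x dx - Dual y dy = Dual (x - y) (dx - dy)
  Dual x dx * Dual y dy = Dual (x * y) (x * dy + dx * y)  -- product rule
  abs (Dual x dx)       = Dual (abs x) (dx * signum x)
  signum (Dual x _)     = Dual (signum x) 0
  fromInteger n         = Dual (fromInteger n) 0          -- constants have zero derivative

-- A primitive applies the chain rule locally: d(sin u) = cos u * du.
sinD :: Dual -> Dual
sinD (Dual x dx) = Dual (sin x) (cos x * dx)

-- Value and derivative of f at x: seed the derivative slot with dx/dx = 1.
diff :: (Dual -> Dual) -> Double -> (Double, Double)
diff f x = let Dual y dy = f (Dual x 1) in (y, dy)
```

Loaded in GHCi, `diff (\x -> x * x + sinD x) 0` evaluates to `(0.0,1.0)`.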

The Jacobian of a function $latex f: \mathbb{R}^n \rightarrow \mathbb{R}^m$ is an m by n matrix. The chain rule says that when you compose the value functions, you compose their Jacobians by matrix multiplication. This is just the composition of the linear maps.
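Written out for $latex f: \mathbb{R}^n \rightarrow \mathbb{R}^k$ followed by $latex g: \mathbb{R}^k \rightarrow \mathbb{R}^m$, the chain rule is a single matrix product, with each Jacobian evaluated at the point that function actually sees:

```latex
J_{g \circ f}(x) = \underbrace{J_g\big(f(x)\big)}_{m \times k} \, \underbrace{J_f(x)}_{k \times n} \in \mathbb{R}^{m \times n}
```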

Conceptually, you initialize the process with an N by N identity matrix, corresponding to the fact that $latex \partial x_i/\partial x_j = \delta_{ij}$.

Vectorized versions of scalar functions (maps) will often use diag, since the Jacobian of a function applied elementwise is a diagonal matrix.
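Putting the last three points together, here is a small forward-mode sketch with explicit Jacobians. To keep it self-contained, plain lists stand in for hmatrix's Vector and Matrix, and every name here (Vec, Mat, mmul, mvec, ident, diag, Fwd, vsin, linear, runForward) is a hypothetical stand-in rather than hmatrix's API:

```haskell
import Data.List (transpose)

type Vec = [Double]
type Mat = [[Double]]  -- row-major: a list of rows

-- (m x k) times (k x n) gives (m x n): the chain rule's composition.
mmul :: Mat -> Mat -> Mat
mmul a b = [[sum (zipWith (*) row col) | col <- transpose b] | row <- a]

-- Matrix-vector product.
mvec :: Mat -> Vec -> Vec
mvec a v = map (sum . zipWith (*) v) a

-- N x N identity, encoding d x_i / d x_j = delta_ij.
ident :: Int -> Mat
ident n = [[if i == j then 1 else 0 | j <- [1 .. n]] | i <- [1 .. n]]

-- Square matrix with v on the diagonal.
diag :: Vec -> Mat
diag v = [[if i == j then x else 0 | j <- [0 .. length v - 1]] | (i, x) <- zip [0 ..] v]

-- Forward mode: carry the value and the Jacobian of that value with respect to
-- the original inputs; each step multiplies its local Jacobian on the left.
type Fwd = (Vec, Mat) -> (Vec, Mat)

-- Elementwise sin: the local Jacobian is diag (cos x).
vsin :: Fwd
vsin (x, j) = (map sin x, diag (map cos x) `mmul` j)

-- Multiplication by a constant matrix m: the local Jacobian is m itself.
linear :: Mat -> Fwd
linear m (x, j) = (mvec m x, m `mmul` j)

-- Seed the Jacobian with the identity, then run.
runForward :: Fwd -> Vec -> (Vec, Mat)
runForward f x = f (x, ident (length x))
```

Since Fwd is an ordinary function type, composition is plain (.): `runForward (vsin . linear [[1, 2], [3, 4]]) [0, 0]` returns `([0.0,0.0],[[1.0,2.0],[3.0,4.0]])`, because cos 0 = 1 makes the diagonal factor the identity.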

A couple points:

  1. Since the Jacobian j is always going to be multiplied in composition, it makes sense to factor this out into a Monad structure (Applicative maybe? Not sure we need full Monad power).
  2. There is an alternative to using explicit Matrix data types for linear maps: we could instead represent the Jacobians as functions (Vector Double) -> Vector Double. The downside is that you can’t inspect the elements. As far as I know, you need explicit matrices to do Gaussian elimination or QR decomposition. You can sample the function on basis vectors to reconstitute the matrix if need be, but this is somewhat roundabout. On the other hand, if your only objective is to multiply matrices, the function representation can be very efficient; a diagonal Jacobian, for example, becomes an elementwise multiplication. Instead of an explicit dense NxN identity matrix, you can use the function id :: a -> a, which does only some minimal pointer manipulation or is optimized away entirely. Since we are largely multiplying Jacobians, I think this is fine.
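The function representation from point 2 can be sketched like this, again with lists in place of hmatrix vectors and with hypothetical names (LinMap, FwdL, vsinL, linearL, runForwardL, column). The identity seed becomes id, and column shows the roundabout sampling needed to get matrix entries back:

```haskell
type Vec = [Double]

-- A Jacobian represented only by its action on a tangent vector.
type LinMap = Vec -> Vec

-- Forward mode carrying the value and the accumulated linear map.
type FwdL = (Vec, LinMap) -> (Vec, LinMap)

-- Elementwise sin: multiplying by diag (cos x) is just an elementwise product.
vsinL :: FwdL
vsinL (x, j) = (map sin x, zipWith (*) (map cos x) . j)

-- Multiplication by a constant matrix m (a list of rows).
linearL :: [[Double]] -> FwdL
linearL m (x, j) = (mvec x, mvec . j)
  where mvec v = map (sum . zipWith (*) v) m

-- The dense identity matrix is replaced by id.
runForwardL :: FwdL -> Vec -> (Vec, LinMap)
runForwardL f x = f (x, id)

-- Applying the map to the i-th basis vector recovers column i of an n-column Jacobian.
column :: LinMap -> Int -> Int -> Vec
column j n i = j [if k == i then 1 else 0 | k <- [0 .. n - 1]]
```

With `(y, j) = runForwardL (vsinL . linearL [[1, 2], [3, 4]]) [0, 0]`, sampling gives `column j 2 0 == [1, 3]` and `column j 2 1 == [2, 4]`, the columns of the explicit Jacobian.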

 

What we’ve shown so far is Forward Mode.

When you multiply matrices, you are free to associate them in any direction you like. D(C(BA)) is the association we’re using right now, but you are equally free to left associate them: ((DC)B)A. You can write this in right-associated form using the transpose: $latex (((DC)B)A)^T = A^T(B^T(C^T D^T))$.

This form is reverse mode automatic differentiation. Its advantage lies in the number of computations you have to do and the size of the intermediate values you have to hold. If one is going from many variables to a small result, this is preferable, because every intermediate product then has only as many rows as there are outputs.
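A rough operation count shows why. Assume naive matrix multiplication, where a p by q times q by r product costs pqr scalar multiplications, and take a hypothetical chain $latex \mathbb{R}^n \rightarrow \mathbb{R}^k \rightarrow \mathbb{R}^k \rightarrow \mathbb{R}$ with Jacobians A (k by n), B (k by k) and C (1 by k), ignoring the identity seed:

```latex
\text{right-associated (forward): } C(BA) \;\to\; k \cdot k \cdot n + 1 \cdot k \cdot n = k^2 n + kn
\text{left-associated (reverse): } (CB)A \;\to\; 1 \cdot k \cdot k + 1 \cdot k \cdot n = k^2 + kn
n = k = 1000: \quad k^2 n + kn = 1.001 \times 10^{9}, \qquad k^2 + kn = 2 \times 10^{6}
```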

It is actually exactly the same in implementation, except that you reverse the order of composition of the derivatives: we forward compose the value functions and reverse compose the derivative functions (matrices).


We have CPSed our derivative matrices, that is, converted them to continuation-passing style.
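A minimal untyped reverse-mode sketch in that style, using lists instead of hmatrix and hypothetical names (Rev, composeRev, vsinR, linearR, sumR, grad). Each function returns its value plus a backward map that multiplies by the transposed Jacobian; values compose forward while the backward maps compose in the opposite order:

```haskell
import Data.List (transpose)

type Vec = [Double]

-- Value plus a backward map from output cotangent to input cotangent.
type Rev = Vec -> (Vec, Vec -> Vec)

-- g after f: values flow f then g; backward maps run g' then f'.
composeRev :: Rev -> Rev -> Rev
composeRev g f x =
  let (y, backF) = f x
      (z, backG) = g y
  in (z, backF . backG)

-- Elementwise sin: diag (cos x) is its own transpose.
vsinR :: Rev
vsinR x = (map sin x, zipWith (*) (map cos x))

-- Multiplication by a constant matrix m (list of rows); backward multiplies by m^T.
linearR :: [[Double]] -> Rev
linearR m x = (mvec m x, mvec (transpose m))
  where mvec a v = map (sum . zipWith (*) v) a

-- Sum to a 1-vector; backward copies the single output cotangent to every input.
sumR :: Rev
sumR x = ([sum x], \dy -> map (const (sum dy)) x)

-- Gradient of a scalar-valued function: seed the backward pass with [1].
grad :: Rev -> Vec -> Vec
grad f x = snd (f x) [1]
```

For f(x) = sin(x1 + 2 x2) + sin(3 x1 + 4 x2), `grad (sumR `composeRev` (vsinR `composeRev` linearR [[1, 2], [3, 4]])) [0, 0]` gives `[4.0,6.0]`, one backward sweep for the whole gradient.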

Really, a better-typed version would not unify all the objects into a single type a. We’ve chosen to use Vector Double as our type, but a typed version would make more sense if we could tell the difference between R^n and R^m at the type level.

However, this will no longer be a monad; instead, you’ll have to specify a Category instance. I got down to this stuff by reading Conal Elliott’s new automatic differentiation paper, which heavily uses the category interface. I was trying to remove the need for constrained categories (it is possible, but I got bogged down in type errors) and to make it mesh nicely with hmatrix.

Let me also mention that the Arrow-style operators ***, dup, &&& and fst, along with the clever currying he mentions, also seem quite nice here. The tuple structure is nice for expressing direct sum spaces in matrices: (Vector a, Vector b) is the direct sum of those vector spaces.

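Anyway, an arrow in RD from a to b has the shape a -> (b, db -> da): it returns the output value together with a backward map from output differentials to input differentials.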
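A sketch of such arrows with a Category instance and a few of the Arrow-style combinators mentioned above. It assumes each type doubles as its own differential type (true for Double and tuples of Doubles, not in general), and runRD, par, fanout, fstR, sinRD and mulRD are hypothetical names; par and fanout play the roles of *** and &&&:

```haskell
import Prelude hiding (id, (.))
import Control.Category (Category (..))

-- Output value plus a backward map db -> da.
-- Assumption: a type serves as its own differential type.
newtype RD a b = RD (a -> (b, b -> a))

runRD :: RD a b -> a -> (b, b -> a)
runRD (RD f) = f

instance Category RD where
  id = RD (\x -> (x, \dx -> dx))
  RD g . RD f = RD $ \x ->
    let (y, f') = f x
        (z, g') = g y
    in (z, f' . g')  -- values forward, backward maps in reverse

-- Parallel composition on a direct sum (a, c).
par :: RD a b -> RD c d -> RD (a, c) (b, d)
par (RD f) (RD g) = RD $ \(a, c) ->
  let (b, f') = f a
      (d, g') = g c
  in ((b, d), \(db, dd) -> (f' db, g' dd))

-- Duplication: in the backward pass the two incoming differentials add.
dup :: Num a => RD a (a, a)
dup = RD (\x -> ((x, x), \(dx1, dx2) -> dx1 + dx2))

-- Fan-out built from par and dup.
fanout :: Num a => RD a b -> RD a c -> RD a (b, c)
fanout f g = par f g . dup

-- Projection: the dropped component receives a zero differential.
fstR :: Num b => RD (a, b) a
fstR = RD (\(a, _) -> (a, \da -> (da, 0)))

sinRD :: RD Double Double
sinRD = RD (\x -> (sin x, \dy -> cos x * dy))

mulRD :: RD (Double, Double) Double
mulRD = RD (\(x, y) -> (x * y, \dz -> (dz * y, dz * x)))
```

For example, `runRD (mulRD . dup) 3` squares its input, returning the value 9 and a backward map that sends 1 to 6. The fan-out in the backward direction is where the sum over paths in the chain rule shows up.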

This is a form I’ve seen before, though: it is a lens. Lenses have a getter (a -> b) that extracts a b from an a, and a setter (a -> b -> a) that, given an a and a new b, returns the updated a.
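To make the shape correspondence concrete, here is a small sketch (all names hypothetical) that splits a function of the form a -> (b, b -> a) into a getter-shaped and a setter-shaped half, and glues two such halves back together:

```haskell
-- Reverse-mode derivative of sin in the a -> (b, db -> da) shape.
sinRev :: Double -> (Double, Double -> Double)
sinRev x = (sin x, \dy -> cos x * dy)

-- The "getter" half: just the forward value.
getter :: (a -> (b, b -> a)) -> a -> b
getter f = fst . f

-- The "setter"-shaped half: given an a and a b, produce an a.
setterLike :: (a -> (b, b -> a)) -> a -> b -> a
setterLike f a = snd (f a)

-- Gluing two halves back together. If the halves came from splitting one
-- function, its forward pass is now computed twice (no sharing).
combine :: (a -> b) -> (a -> b -> a) -> a -> (b, b -> a)
combine get set a = (get a, set a)
```

Here the "setter" does not replace anything; it maps an output differential to an input differential, which is exactly why the analogy is loose.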

Is an automatic derivative function in some sense extracting an implicit, calculable value from the original vector and returning, in a sense, how to change the original function? It is unclear whether one should take the lens analogy very far.

The type of Lens’ a b, forall f. Functor f => (b -> f b) -> a -> f a, is isomorphic to a getter and setter pair, and therefore to a -> (b, b -> a), the same shape as our derivative functions. The type itself does not imply the lens laws relating getters and setters, so these functions are definitely not proper lawful lenses. It is just curious that conceptually they are not that far off.

The lens trick of replacing this function with a quantified type, either the van Laarhoven form (forall f.) or the profunctor form (forall p.), seems applicable here. We could then compose reverse mode functions using the ordinary (.) operator and abuse convenience functions from the lens library.
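A sketch of the van Laarhoven version, written without the lens library so it stays self-contained. DLens, Store, fromRev, toRev, sinL and squareL are hypothetical names; the Store functor is one way to pull the value and the backward map out of a single traversal, so the forward pass is not recomputed:

```haskell
{-# LANGUAGE RankNTypes #-}

-- Same shape as Lens' a b.
type DLens a b = forall f. Functor f => (b -> f b) -> a -> f a

-- A focused value plus a way to rebuild from a new focus.
data Store b a = Store b (b -> a)

instance Functor (Store b) where
  fmap h (Store b k) = Store b (h . k)

-- From the plain form a -> (b, db -> da); the forward pass runs once.
fromRev :: (a -> (b, b -> a)) -> DLens a b
fromRev f k x = let (y, back) = f x in fmap back (k y)

-- Back to the plain form by instantiating f at Store b.
toRev :: DLens a b -> a -> (b, b -> a)
toRev l x = case l (\y -> Store y id) x of
  Store y back -> (y, back)

sinL, squareL :: DLens Double Double
sinL = fromRev (\x -> (sin x, \dy -> cos x * dy))
squareL = fromRev (\x -> (x * x, \dy -> 2 * x * dy))
```

Composition is now ordinary (.), and it reads left to right in data-flow order, just like `_1 . _2` in the lens library: `toRev (sinL . squareL) 0.5` computes (sin 0.5)^2 along with a backward map sending 1 to 2 sin 0.5 cos 0.5 = sin 1. The backward maps still come out composed in reverse, because fmap wraps each outer stage's backward map around the inner ones.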

Neat if true.

 

 

2 thoughts on “Reverse Mode Auto Differentiation is Kind of Like a Lens”

    1. Wow! Thanks! I’m so glad you saw this! My thoughts were significantly more scattered when I wrote this. I can’t even super follow some of my points in this post, but I think the lens and reverse mode analogy is still pretty solid.

      I actually explored this a bit more here https://github.com/philzook58/ad-lens but never wrote about it. One difficulty was writing the function that maintains the sharing when transforming between the a -> (b , db -> da) and the forall f. (a -> f b) -> s -> f t versions of lens, but I think I got it. It’s hard to tell though, since sharing isn’t readily observable without peeking under the hood.

      The four “stab” type parameters of the full Lens type are actually quite useful as they show the difference between differentials dx and dy and the actual coordinates x and y, which should be only carefully mixed.

      I’ve since decided that using the lens trick to be able to use the standard (.) operator is a bit silly and not worth it. I suspect the compiler will be able to shred through the more boring form, more easily finding optimizations, and the technique will still work in a language that may not have the ability to make a polymorphic f :: * -> * or other fancy Haskell abilities (Futhark?).
