Uniform Continuity is Kind of Like a Lens

A really interesting topic is exact real arithmetic. It turns out there are systematic ways of calculating numerical results to arbitrarily fine accuracy.

In practice this is not used much as it is complicated and slow.

There are deep waters here.

The problem is made rather difficult by the fact that you can’t compute real numbers exactly; in some sense you have to compute better and better finite approximations.

One way of doing this is to compute a stream of arbitrarily good approximations. If someone needs a better approximation than you’ve already given, they pop the next one off.

Streams give you a kind of inverted control flow. They allow the results to pull on the inputs, going against the grain of the ordinary direction of computation. If you are interested in a final result of a certain accuracy, though, they seem somewhat inefficient. You have to search for how far to pull on the incoming streams, and the intermediate computations may not be helpful.
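As a small illustration of the pull-based style, here is a minimal sketch (not taken from either linked library) of a lazy stream of ever-tighter rational enclosures of sqrt 2, produced by bisection. The names `sqrt2Stream` and `approxTo` are made up for this example; the consumer alone decides how far to pull.

```haskell
-- An infinite stream of rational intervals (lo, hi) with lo^2 <= 2 <= hi^2,
-- each half as wide as the previous one (bisection).
sqrt2Stream :: [(Rational, Rational)]
sqrt2Stream = iterate step (1, 2)
  where
    step (lo, hi) =
      let mid = (lo + hi) / 2
      in if mid * mid <= 2 then (mid, hi) else (lo, mid)

-- The consumer pulls elements until the enclosure is narrow enough.
approxTo :: Rational -> [(Rational, Rational)] -> (Rational, Rational)
approxTo eps = head . dropWhile (\(lo, hi) -> hi - lo > eps)
```

Because the list is lazy, `approxTo (1/1000) sqrt2Stream` forces only ten bisection steps (the width after n steps is 1/2^n). The search for "how far to pull" is the `dropWhile`.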

Haskell chews infinite lists up for breakfast, so it’s a convenient place for such things: https://wiki.haskell.org/Exact_real_arithmetic and https://hackage.haskell.org/package/exact-real

A related but slightly different set of methods comes in the form of interval arithmetic. Interval arithmetic also gives precise statements of accuracy, maintaining bounds on the error as a number is carried along.

Interval arithmetic is very much like forward mode differentiation. In forward mode differentiation, you compute on dual numbers (x,dx) and carry along the derivatives as you go.

type ForwardMode x dx y dy = (x,dx) -> (y,dy)
type IntervalFun x delx y dely = (x,delx) -> (y, dely)
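To make the analogy concrete, here is a minimal sketch using the two type synonyms above, in midpoint-radius style: the interval version carries an error radius where forward mode carries a tangent. The function names are hypothetical, and floating-point rounding is ignored, so this is not a validated implementation.

```haskell
type ForwardMode x dx y dy = (x, dx) -> (y, dy)
type IntervalFun x delx y dely = (x, delx) -> (y, dely)

-- Forward mode: the tangent of x * y is x dy + y dx.
mulF :: ForwardMode (Double, Double) (Double, Double) Double Double
mulF ((x, y), (dx, dy)) = (x * y, x * dy + y * dx)

-- Interval version: if |x' - x| <= ex and |y' - y| <= ey then
-- |x' y' - x y| <= |x| ey + |y| ex + ex ey.  (Rounding ignored.)
mulI :: IntervalFun (Double, Double) (Double, Double) Double Double
mulI ((x, y), (ex, ey)) = (x * y, abs x * ey + abs y * ex + ex * ey)

-- sin: the derivative is cos, and since sin is 1-Lipschitz the radius
-- passes through unchanged.
sinF :: ForwardMode Double Double Double Double
sinF (x, dx) = (sin x, cos x * dx)

sinI :: IntervalFun Double Double Double Double
sinI (x, ex) = (sin x, ex)
```

The shapes line up exactly; the interval version replaces the linear derivative term with a bound that also accounts for the second-order `ex * ey` term.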

Conceptually, differentiation and these validated bounds are connected as well. They are both telling you something about how the function is behaving nearby. The derivative is mostly meaningful at exactly the point it is evaluated. It is extremely local. The verified bounds being carried along are sort of a very principled finite difference approximation.

But reverse mode differentiation is often where it is at. This is the algorithm that drives deep learning. Reverse mode differentiation can be modeled functionally as a kind of lens (http://www.philipzucker.com/reverse-mode-differentiation-is-kind-of-like-a-lens-ii/). The thing that makes reverse mode confusing is the backward pass. This is also inverted control flow, where the output pushes information to the input. The Lens structure does this too.

type Lens s t a b = s -> (a, b -> t)

It carries a function that goes in the reverse direction, and these functions are composed in the opposite direction of ordinary control flow. They are the “setters” in the ordinary usage of the Lens, but they are the backproppers for differentiation.
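A minimal sketch of that double reading: one composition function, used once with a field lens (the backward function is a setter) and once with a derivative lens (the backward function is a backpropagator). `compose`, `firstL`, and `squareL` are illustrative names, not from any library.

```haskell
type Lens s t a b = s -> (a, b -> t)

-- Forward parts compose left to right; backward parts compose right to left.
compose :: Lens a b c d -> Lens s t a b -> Lens s t c d
compose outer inner s =
  let (a, backA) = inner s
      (c, backC) = outer a
  in (c, backA . backC)

-- Ordinary use: focus on the first component; backward = setter.
firstL :: Lens (a, x) (b, x) a b
firstL (a, x) = (a, \b -> (b, x))

-- AD use: x^2; backward = multiply the incoming sensitivity by 2x.
squareL :: Lens Double Double Double Double
squareL x = (x * x, \dy -> 2 * x * dy)
```

Composing `squareL` with itself gives x^4: at x = 3 the forward value is 81 and seeding the backward pass with 1 returns 108, which is 4 * 3^3.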

By analogy one might try

type RealF x delta y epsilon = Lens x delta y epsilon -- = x -> (y, epsilon -> delta)

There is something pleasing here compared to interval arithmetic in that the output epsilon drives the input delta. The second function is kind of a Skolemized \delta(\epsilon) from the definition of continuity.

Although it kind of makes sense, there is something unsatisfying about this. How do you compute the x -> y? It seems you already need to know the accuracy before you can build this function.

So it seems to me that actually a better definition is

type RealF x delta y epsilon = Lens epsilon y delta x -- = epsilon -> (delta, x -> y)

This type surprised me and is rather nice in many respects. It lets you actually calculate x -> y, has that lazy pull-based feel without infinite streams, and has delta as a function of epsilon.

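I have heard, although I don’t understand why, that uniform continuity is the more constructive definition (see constructive analysis by Bridger, and https://en.wikipedia.org/wiki/Uniform_continuity). This definition seems to match that: delta is computed from epsilon alone, before any x is supplied, so one delta has to work at every point, which is exactly the uniformity condition.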

In addition, we are able to use approximations of the actual function if we know the accuracy it needs to be computed to. For example, if we need 0.01 accuracy in the output, we can ask for only 0.009 accuracy in the input (sin is 1-Lipschitz, so that contributes at most 0.009 of output error) and spend the remaining 0.001 on approximating sine itself. For small inputs that means we only need the x term of the Taylor series of sine, since its truncation error is at most |x|^3/6, which is below 0.001 for |x| up to about 0.18. The inaccuracy of the input and the inaccuracy of our approximation of sine combine to give the total inaccuracy of the output. If we know the needed accuracy allows it, we can work with fast floating point operations. If we need better, we can switch over to mpfr, etc.
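Here is a minimal sketch of the `epsilon -> (delta, x -> y)` idea for sine. Several assumptions are mine, not the post's: tolerances are plain `Double`s, inputs are assumed to stay in [-1, 1], and the budget is split evenly (half for input error, half for Taylor truncation) rather than 0.009/0.001. `RealF`, `compR`, `runR`, `sinR`, and `taylorSin` are hypothetical names.

```haskell
-- Given the output accuracy eps, return the required input accuracy
-- delta and an approximate function that is good enough for that eps.
newtype RealF x y = RealF (Double -> (Double, x -> y))

-- Compose g after f: g's input tolerance becomes f's output tolerance.
compR :: RealF b c -> RealF a b -> RealF a c
compR (RealF g) (RealF f) = RealF $ \eps ->
  let (deltaG, g') = g eps
      (deltaF, f') = f deltaG
  in (deltaF, g' . f')

-- Evaluate to accuracy eps.
runR :: RealF x y -> Double -> x -> y
runR (RealF f) eps x = snd (f eps) x

factorial :: Int -> Double
factorial m = product (map fromIntegral [1 .. m])

-- Sum of the Taylor terms of sin up to and including x^n (n odd).
taylorSin :: Int -> Double -> Double
taylorSin n x = sum [ (-1) ^ k * x ^ (2 * k + 1) / factorial (2 * k + 1)
                    | k <- [0 .. (n - 1) `div` 2] ]

-- Sine for inputs assumed to lie in [-1, 1]. sin is 1-Lipschitz, so an
-- input error of eps/2 costs at most eps/2. On [-1, 1], stopping after
-- x^n costs at most 1/(n+2)!, so pick the first odd n with that <= eps/2.
sinR :: RealF Double Double
sinR = RealF $ \eps ->
  let n = head [ m | m <- [1, 3 ..], 1 / factorial (m + 2) <= eps / 2 ]
  in (eps / 2, taylorSin n)
```

For eps = 0.01 this asks for 0.005 input accuracy and uses the series up to x^5. Asking for more accuracy automatically pulls in more terms, and nothing is computed until an epsilon is supplied.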

This seems nice for MetaOCaml staging or other compile-time macro techniques. If the epsilon required is known at compile time, it makes sense to me that one could use MetaOCaml to produce fast unrolled code. In addition, if you know the needed accuracy, you can switch between methods and avoid the runtime overhead. The stream-based approach seems to have a lot of context switching and perhaps unnecessary intermediate computations. It isn’t as bad as it seems, since these intermediate computations are usually necessary to compute anyhow, but still.

We can play the same monoidal category games with these lenses as ever. We can use dup, par, add, mul, sin, cos etc. and wire things up in diagrams and what have you.
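A minimal sketch of a few of these combinators for the `epsilon -> (delta, x -> y)` form. Assumptions (not from the post): tolerances are plain `Double`s, and a pair counts as "within delta" when both components are (a sup-norm convention). All names are hypothetical.

```haskell
import Control.Arrow ((***))

newtype RealF x y = RealF (Double -> (Double, x -> y))

runR :: RealF x y -> Double -> x -> y
runR (RealF f) eps = snd (f eps)

compR :: RealF b c -> RealF a b -> RealF a c
compR (RealF g) (RealF f) = RealF $ \eps ->
  let (dg, g') = g eps
      (df, f') = f dg
  in (df, g' . f')

-- dup: copying loses no accuracy.
dupR :: RealF x (x, x)
dupR = RealF $ \eps -> (eps, \x -> (x, x))

-- add: errors in the two summands add, so each gets half the budget.
addR :: RealF (Double, Double) Double
addR = RealF $ \eps -> (eps / 2, uncurry (+))

-- par: each side needs its own delta; the pair needs the smaller one.
parR :: RealF a b -> RealF c d -> RealF (a, c) (b, d)
parR (RealF f) (RealF g) = RealF $ \eps ->
  let (df, f') = f eps
      (dg, g') = g eps
  in (min df dg, f' *** g')

-- x + x wired from dup and add: asks for eps/2 accuracy on x.
doubleR :: RealF Double Double
doubleR = addR `compR` dupR
```

Multiplication is the interesting case: it is not uniformly continuous on the whole plane, so a `mulR` could only choose its delta once the inputs are assumed to be bounded.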

This might be a nice type for use in a theorem prover. The Lens type, combined with the appropriate properties that the intervals go to zero and stay consistent for arbitrary epsilon, seems like enough? { RealF | something something something }

Relation to Backwards error analysis?

Does this have nice properties like backprop for high-dimensional inputs? That’s where backprop really shines: high- to low-dimensional functions.

A Sketch of Gimped Interval Propagation with Lenses

David Sanders (who lives in Julia land, https://github.com/JuliaIntervals) explained to me last night a bit of how an interval constraint propagation library works. He described it as being very similar to backpropagation, which sets off alarm bells for me.

Backpropagation can be implemented in a point-free functional style using the lens pattern (http://www.philipzucker.com/reverse-mode-differentiation-is-kind-of-like-a-lens-ii/). Lenses are, generally speaking, a natural way to express, in a functional style, forward-backward algorithms that share information between the two passes.

I also note that Conal Elliott explicitly mentions interval computation in his compiling to categories work (http://conal.net/papers/compiling-to-categories/, https://github.com/conal/concat), and he does have something working there.

Interval arithmetic itself has already been implemented in Haskell in Ed Kmett’s intervals package (https://hackage.haskell.org/package/intervals-0.9.1/docs/Numeric-Interval.html), so we can just use that.

The interesting thing the backward pass gives you is that everything feels a bit more relational rather than functional. The backward pass allows you to infer new information using constraints given down the line. For example, fuse :: Lens (a,a) a lets you enforce that two variables are actually equal. The lens pattern lets you store the forward pass intervals in a closure, so that you can intersect them with the backward pass intervals.
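A minimal sketch of the store-and-intersect idea. To keep it self-contained, a hand-rolled closed-interval type `I` stands in for the intervals package (whose API is not used here); empty intersections are not detected, and `meet`, `addL`, and `fuse` are illustrative names.

```haskell
type Lens' a b = a -> (b, b -> a)

-- Closed interval [lo, hi]. Empty results (lo > hi) are not detected.
data I = I Double Double deriving (Eq, Show)

meet :: I -> I -> I
meet (I a b) (I c d) = I (max a c) (min b d)

addI :: I -> I -> I
addI (I a b) (I c d) = I (a + c) (b + d)

subI :: I -> I -> I
subI (I a b) (I c d) = I (a - d) (b - c)

-- Forward: interval sum. Backward: intersect the incoming constraint on z
-- with the forward result, then narrow each remembered input using
-- x = z - y and y = z - x.
addL :: Lens' (I, I) I
addL (x, y) =
  let z = addI x y
  in (z, \zc -> let z' = meet zc z
                in (meet x (subI z' y), meet y (subI z' x)))

-- fuse: two intervals describing the same quantity.
fuse :: Lens' (I, I) I
fuse (x, y) =
  let m = meet x y
  in (m, \zc -> let z' = meet zc m in (z', z'))
```

With x, y in [0, 10], the forward pass gives [0, 20]; learning downstream that the sum is in [0, 1] narrows both inputs to [0, 1] on the way back. One forward-backward sweep is all this does; there is no iteration to a fixed point.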

I make no guarantees what I have here is right. It’s a very rough first pass. It compiles, so that is cool I guess.

Here’s my repo in case I fix more things up and you wanna check it out https://github.com/philzook58/ad-lens/blob/master/src/Numeric/ADLens/Interval.hs

Now having said that, to my knowledge Propagators are a more appropriate technique for this domain. https://www.youtube.com/watch?v=s2dknG7KryQ https://www.youtube.com/watch?v=nY1BCv3xn24 I don’t really know propagators though. It’s on my to do list.

This lens approach has a couple of problems. It is probably doing way more work than it should, and we aren’t iterating to a fixed point.

Maybe an iterated lens would get us closer?

data Lens s t a b = Lens (a -> (b , (b -> (a, Lens s t a b))))

This is one way to go about the iterative process of updating a neural network in a functional way, by evaluating it over and over and backpropagating. The updated weights will be stored in those closures. It seems kind of nice. It is clearly some relative of iteratees and streaming libraries like pipes and conduit (which are also a compositional bidirectional programming pattern), the main difference being that it enforces a particular ordering of passes (for better or worse). Also, I haven’t put in any monadic effects, which are to some degree the point of those libraries, but which also cloud what is going on conceptually.
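A minimal sketch of that iterated lens used for training, assuming plain gradient descent on squared error. `compL`, `scaleW`, `trainStep`, and `at1` are hypothetical names; the phantom `s` and `t` parameters from the type above are kept but unused.

```haskell
-- The iterated lens from the text; s and t are carried along unused.
data Lens s t a b = Lens (a -> (b, b -> (a, Lens s t a b)))

-- Run inner then outer forward; on the way back, return the input
-- sensitivity together with updated versions of both stages.
compL :: Lens s t b c -> Lens s t a b -> Lens s t a c
compL (Lens g) (Lens f) = Lens $ \a ->
  let (b, backF) = f a
      (c, backG) = g b
  in (c, \dc -> let (db, g') = backG dc
                    (da, f') = backF db
                in (da, compL g' f'))

-- A one-weight layer y = w * x. The backward pass returns dx and a new
-- layer whose weight has taken one gradient step (learning rate lr).
scaleW :: Double -> Double -> Lens s t Double Double
scaleW lr w = Lens $ \x ->
  (w * x, \dy -> (w * dy, scaleW lr (w - lr * dy * x)))

-- One step on the loss 0.5 * (y - target)^2, whose gradient is y - target.
trainStep :: Lens s t Double Double -> (Double, Double) -> Lens s t Double Double
trainStep (Lens f) (x, target) =
  let (y, back) = f x
      (_, next) = back (y - target)
  in next

-- The current network output at input 1.
at1 :: Lens s t Double Double -> Double
at1 (Lens f) = fst (f 1)
```

Folding `trainStep` over examples threads the updated weights through the closures, with no explicit weight bookkeeping. Fitting x = 2 to target 6 drives a single `scaleW` layer to w = 3, and two composed layers fitting 1 to 4 drive the product of their weights to 4.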

Another interesting possibility is the type

type Lens s t a b = s -> (a, b -> t)

Lens s (Interval s) a (Interval a)

This has pieces that might be helpful for talking about continuous functions in a constructive way. It has the forward definition of the function, and then the inverse image of intervals. The inverse image function depends on the original evaluation point? Does this actually make sense? The definition of continuity is that for every range interval around the output, no matter how small, this inverse image function must still return an interval containing a neighborhood of the evaluation point. Continuity is compositional and plays nicely with many arithmetic and structural combinators. So maybe something like this might be a nice abstraction for proof-carrying continuous functions in Coq or Agda? Pure conjecture.
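One tentative answer to the question of whether the inverse image should depend on the evaluation point: for a non-injective function, the stored point is what picks the branch. Here is a minimal sketch for squaring; the `Interval` type, `compose`, and `squareL` are hypothetical, and it only shows the shape, with no continuity proof attached.

```haskell
type Lens s t a b = s -> (a, b -> t)

data Interval a = Interval a a deriving (Eq, Show)

compose :: Lens a b c d -> Lens s t a b -> Lens s t c d
compose outer inner s =
  let (a, preA) = inner s
      (c, preC) = outer a
  in (c, preA . preC)

-- x^2 is not injective, so the preimage of [lo, hi] has two branches.
-- The evaluation point x, kept in the closure, selects the branch
-- containing x. Negative lo is clamped to 0.
squareL :: Lens Double (Interval Double) Double (Interval Double)
squareL x = (x * x, preimage)
  where
    preimage (Interval lo hi)
      | x >= 0    = Interval (sqrt (max 0 lo)) (sqrt hi)
      | otherwise = Interval (negate (sqrt hi)) (negate (sqrt (max 0 lo)))
```

Composition pulls an output interval back through each stage in turn: for x^4 at x = 2, the range [1, 81] pulls back to [1, 3].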

Neural Networks with Weighty Lenses (DiOptics?)

I wrote a while back how you can make a pretty nice DSL for reverse mode differentiation based on the same type as Lens. I’d heard some interesting rumblings on the internet around these ideas and so was revisiting them.

type Lens s t a b = s -> (a, b -> t)
type AD x dx y dy = x -> (y, dy -> dx)

Composition is defined identically for reverse mode and for lens.

The forward computation shares info with the backward differential propagation, which corresponds to a transposed Jacobian.

After chewing on it a while, I realized this really isn’t that exotic. How it works is that you store the reverse mode computation graph, and all necessary saved data from the forward pass in the closure of the (dy -> dx). I also have a suspicion that if you defunctionalized this construction, you’d get the Wengert tape formulation of reverse mode ad.

Second, Lens is just a nice structure for bidirectional computation, with one forward pass and one backward pass which may or may not be getting/setting. There are other examples for using it like this.

It is also pretty similar to the standard “dual number” form type FAD x dx y dy = (x,dx)->(y,dy) for forward mode AD. We can bring the two closer by a CPS/Yoneda transformation and then some rearrangement.

     x -> (y, dy -> dx) 
==>  x -> (y, forall s. (dx -> s) -> (dy -> s))
==>  forall s. (x, dx -> s) -> (y, dy -> s)

and meet it in the middle with

(x,dx) -> (y,dy)
==> forall s. (x, s -> dx) -> (y, s -> dy)
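A small sketch of the first chain of rewrites as actual code, showing that the reverse-mode lens and its CPS form carry the same information. `CPSAD`, `toCPS`, `fromCPS`, and `sinAD` are made-up names; the newtype is just there to keep the rank-2 type easy for GHC to check.

```haskell
{-# LANGUAGE RankNTypes #-}

type AD x dx y dy = x -> (y, dy -> dx)

-- CPS form: rather than returning dy -> dx, post-compose it onto the
-- continuation dx -> s that the caller already holds.
newtype CPSAD x dx y dy = CPSAD (forall s. (x, dx -> s) -> (y, dy -> s))

toCPS :: AD x dx y dy -> CPSAD x dx y dy
toCPS f = CPSAD (\(x, k) -> let (y, back) = f x in (y, k . back))

-- Choosing s = dx and the identity continuation recovers the lens form.
fromCPS :: CPSAD x dx y dy -> AD x dx y dy
fromCPS (CPSAD g) x = g (x, id)

sinAD :: AD Double Double Double Double
sinAD x = (sin x, \dy -> dy * cos x)
```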

I ended the previous post somewhat unsatisfied by how ungainly writing that neural network example was, and I called for Conal Elliott’s compiling to categories plugin as a possible solution. The trouble is piping the weights all over the place. This piping is very frustrating in point-free form, especially when you know it would be so trivial in pointful style. While the inputs and outputs of layers of the network compose nicely (you no longer need to know about the internal computations), the weights do not. As we get more and more layers, we get more and more weights. The weights are in some sense not as compositional as the inputs and outputs of the layers, or compose in a different way that you need to maintain access to.

I thought of a very slight conceptual twist that may help.

The idea is we keep the weights out to the side in their own little type parameter slots. Then we define composition such that it composes input/outputs while tupling the weights. Basically we throw the repetitive complexity appearing in piping the weights around into the definition of composition itself.

These operations are easily seen as 2 dimensional diagrams.

(Figure: three layers composed, exposing the weights from all layers.)
The 2-D arrow things can be built out of the 1-D arrows of the original basic AD lens by bending the weights up and down. Ultimately they are describing the same thing.

Here are the core reverse-mode AD lens combinators:

import Control.Arrow ((***))

type Lens'' a b = a -> (b, b -> a)

comp :: (b -> (c, (c -> b))) -> (a -> (b, (b -> a))) -> (a -> (c, (c -> a)))
comp f g x = let (b, dg) = g x in
             let (c, df) = f b in
             (c, dg . df)

id' :: Lens'' a a
id' x = (x, id) 

relu' :: (Ord a, Num a) => Lens'' a a
relu' = \x -> (frelu x, brelu x) where
        frelu x | x > 0 = x
                | otherwise = 0
        brelu x dy | x > 0 = dy
                   | otherwise = 0

add' :: Num a => Lens'' (a,a) a 
add' = \(x,y) -> (x + y, \ds -> (ds, ds))

dup' :: Num a => Lens'' a (a,a)
dup' = \x -> ((x,x), \(dx,dy) -> dx + dy)

sub' :: Num a => Lens'' (a,a) a 
sub' = \(x,y) -> (x - y, \ds -> (ds, -ds))

mul' :: Num a => Lens'' (a,a) a 
mul' = \(x,y) -> (x * y, \dz -> (dz * y, x * dz))

recip' :: Fractional a => Lens'' a a 
recip' = \x-> (recip x, \ds -> - ds / (x * x))

div' :: Fractional a => Lens'' (a,a) a 
div' = (\(x,y) -> (x / y, \d -> (d/y,-x*d/(y * y))))

sin' :: Floating a => Lens'' a a
sin' = \x -> (sin x, \dx -> dx * (cos x))

cos' :: Floating a => Lens'' a a
cos' = \x -> (cos x, \dx -> -dx * (sin x))

pow' :: Num a => Integer -> Lens'' a a
pow' n = \x -> (x ^ n, \dx -> (fromInteger n) * dx * x ^ (n-1)) 

cmul :: Num a => a -> Lens'' a a
cmul c = \x -> (c * x, \dx -> c * dx)

exp' :: Floating a => Lens'' a a
exp' = \x -> let ex = exp x in
                      (ex, \dx -> dx * ex)

fst' :: Num b => Lens'' (a,b) a
fst' = (\(a,b) -> (a, \ds -> (ds, 0)))

snd' :: Num a => Lens'' (a,b) b
snd' = (\(a,b) -> (b, \ds -> (0, ds)))

-- some monoidal combinators
swap' :: Lens'' (a,b) (b,a)
swap' = (\(a,b) -> ((b,a), \(db,da) -> (da, db)))

assoc' :: Lens'' ((a,b),c) (a,(b,c))
assoc' = \((a,b),c) -> ((a,(b,c)), \(da,(db,dc)) -> ((da,db),dc))

assoc'' :: Lens'' (a,(b,c)) ((a,b),c)
assoc'' = \(a,(b,c)) -> (((a,b),c), \((da,db),dc)->  (da,(db,dc)))

par' :: Lens'' a b -> Lens'' c d -> Lens'' (a,c) (b,d)
par' l1 l2 = l3 where
    l3 (a,c) = let (b , j1) = l1 a in
               let (d, j2) = l2 c in
               ((b,d) , j1 *** j2) 
first' :: Lens'' a b -> Lens'' (a, c) (b, c)
first' l = par' l id'

second' :: Lens'' a b -> Lens'' (c, a) (c, b)
second' l = par' id' l

labsorb :: Lens'' ((),a) a
labsorb (_,a) = (a, \a' -> ((),a'))

labsorb' :: Lens'' a ((),a)
labsorb' a = (((),a), \(_,a') -> a')

rabsorb :: Lens'' (a,()) a
rabsorb = comp labsorb swap'
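As a usage sketch, here is f(x) = x sin x wired point-free from these combinators (copied so the block stands alone), with the derivative read off by seeding the backward pass with 1. The helpers `xsinx` and `grad` are illustrative names, not part of the original file.

```haskell
import Control.Arrow ((***))

type Lens'' a b = a -> (b, b -> a)

comp :: Lens'' b c -> Lens'' a b -> Lens'' a c
comp f g x =
  let (b, dg) = g x
      (c, df) = f b
  in (c, dg . df)

id' :: Lens'' a a
id' x = (x, id)

dup' :: Num a => Lens'' a (a, a)
dup' x = ((x, x), \(dx, dy) -> dx + dy)

par' :: Lens'' a b -> Lens'' c d -> Lens'' (a, c) (b, d)
par' l1 l2 (a, c) =
  let (b, j1) = l1 a
      (d, j2) = l2 c
  in ((b, d), j1 *** j2)

sin' :: Floating a => Lens'' a a
sin' x = (sin x, \dx -> dx * cos x)

mul' :: Num a => Lens'' (a, a) a
mul' (x, y) = (x * y, \dz -> (dz * y, x * dz))

-- f(x) = sin x * x: copy x, apply sin to one copy, multiply.
xsinx :: Lens'' Double Double
xsinx = mul' `comp` par' sin' id' `comp` dup'

-- Seed the backward pass with 1 to read off the derivative.
grad :: Num b => Lens'' a b -> a -> a
grad l x = snd (l x) 1
```

The backward pass produces x cos x + sin x, with `dup'` summing the contributions of the two uses of x.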

And here are the two-dimensional combinators. I tried to write them point-free in terms of the combinators above to demonstrate that there is no monkey business going on.

type WAD' w w' a b = Lens'' (w,a) (w',b)
type WAD'' w a b = WAD' w () a b -- terminate the weights for a closed network
{- For any monoidal category we can construct this composition? -}
-- horizontal composition (the local type signatures below need {-# LANGUAGE ScopedTypeVariables #-})
hcompose :: forall w w' w'' w''' a b c. WAD' w' w'' b c -> WAD' w w''' a b -> WAD' (w',w) (w'',w''') a c
hcompose f g = comp f' g' where 
               f' :: Lens'' ((w',r),b) ((w'',r),c)
               f' = (first' swap') `comp` assoc'' `comp` (par' id' f) `comp` assoc' `comp`  (first' swap') 
               g' :: Lens'' ((r,w),a) ((r,w'''),b)
               g' = assoc'' `comp` (par' id' g) `comp` assoc' 



rotate :: WAD' w w' a b -> WAD' a b w w'                                      
rotate f = swap' `comp` f `comp` swap'

-- vertical composition of weights
vcompose :: WAD' w'  w'' c d -> WAD' w w' a b -> WAD' w w'' (c, a) (d, b)
vcompose f g = rotate (hcompose (rotate f)  (rotate g) )                             

-- a double par.
diagpar :: forall w w' a b w'' w''' c d. WAD' w  w' a b -> WAD' w'' w''' c d 
           -> WAD' (w,w'') (w',w''') (a, c) (b, d)
diagpar f g = t' `comp` (par' f g) `comp` t where
                t :: Lens'' ((w,w''),(a,c)) ((w,a), (w'',c)) -- yikes. just rearrangements.
                t =  assoc'' `comp` (second' ((second' swap') `comp` assoc' `comp` swap')) `comp` assoc'
                t' :: Lens'' ((w',b), (w''',d)) ((w',w'''),(b,d)) -- the transpose of t
                t' =  assoc'' `comp` (second'  ( swap'  `comp` assoc'' `comp` (second' swap')))  `comp` assoc'

id''' :: WAD' () () a a
id''' = id'






-- rotate:: WAD' w a a w
-- rotate = swap'

liftIO :: Lens'' a b -> WAD' w w a b
liftIO = second'

liftW :: Lens'' w w' -> WAD' w w' a a
liftW = first'


wassoc' = liftW assoc' 
wassoc'' = liftW assoc'' 

labsorb'' :: WAD' ((),w) w a a
labsorb'' = first' labsorb

labsorb''' :: WAD' w ((),w) a a
labsorb''' = first' labsorb'

wswap' :: WAD' (w,w') (w',w) a a
wswap' = first' swap'
-- and so on we can lift all combinators
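As a usage sketch, here are two one-weight layers stacked with `hcompose` (copied with the combinators it needs so the block stands alone). The layer `scaleLayer` is a hypothetical example: it consumes its weight and closes its weight output with (). The point is that both weights end up in one tuple, outer layer first, with no manual plumbing.

```haskell
{-# LANGUAGE ScopedTypeVariables #-}
import Control.Arrow ((***))

type Lens'' a b = a -> (b, b -> a)
type WAD' w w' a b = Lens'' (w, a) (w', b)

comp :: Lens'' b c -> Lens'' a b -> Lens'' a c
comp f g x =
  let (b, dg) = g x
      (c, df) = f b
  in (c, dg . df)

id' :: Lens'' a a
id' x = (x, id)

swap' :: Lens'' (a, b) (b, a)
swap' (a, b) = ((b, a), \(db, da) -> (da, db))

assoc' :: Lens'' ((a, b), c) (a, (b, c))
assoc' ((a, b), c) = ((a, (b, c)), \(da, (db, dc)) -> ((da, db), dc))

assoc'' :: Lens'' (a, (b, c)) ((a, b), c)
assoc'' (a, (b, c)) = (((a, b), c), \((da, db), dc) -> (da, (db, dc)))

par' :: Lens'' a b -> Lens'' c d -> Lens'' (a, c) (b, d)
par' l1 l2 (a, c) =
  let (b, j1) = l1 a
      (d, j2) = l2 c
  in ((b, d), j1 *** j2)

first' :: Lens'' a b -> Lens'' (a, c) (b, c)
first' l = par' l id'

hcompose :: forall w w' w'' w''' a b c.
            WAD' w' w'' b c -> WAD' w w''' a b -> WAD' (w', w) (w'', w''') a c
hcompose f g = comp f' g'
  where
    f' :: Lens'' ((w', r), b) ((w'', r), c)
    f' = first' swap' `comp` assoc'' `comp` par' id' f `comp` assoc' `comp` first' swap'
    g' :: Lens'' ((r, w), a) ((r, w'''), b)
    g' = assoc'' `comp` par' id' g `comp` assoc'

-- Hypothetical layer y = w * x with a closed-off weight output.
scaleLayer :: Num a => WAD' a () a a
scaleLayer (w, x) = (((), w * x), \((), dy) -> (dy * x, dy * w))
```

With outer weight 3, inner weight 2, and input 5, the network outputs 30. Seeding the backward pass with 1 gives weight gradients (10, 15) and input gradient 6, all delivered in the tupled weight slot.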

I wonder if this is actually nice?

I asked around and it seems like this idea may be what davidad is talking about when he refers to dioptics

http://events.cs.bham.ac.uk/syco/strings3-syco5/slides/dalrymple.pdf

Perhaps this will initiate a convo.

Edit: He confirms that what I’m doing appears to be a dioptic. Also he gave a better link http://events.cs.bham.ac.uk/syco/strings3-syco5/papers/dalrymple.pdf

He is up to some interesting diagrams

https://twitter.com/davidad/status/1179760373030801408?s=20

Bits and Bobbles

  • Does this actually work or help make things any better?
  • Recurrent neural nets flip my intended role of weights and inputs.
  • Do conv-nets naturally require higher dimensional diagrams?
  • This weighty style seems like a good fit for my Gauss-Seidel and iterative LQR solvers. A big problem I hit there was getting all the information to the outside, which is a similar issue to getting the weights around in a neural net.

Concolic Weakest Precondition is Kind of Like a Lens

That’s a mouthful.

Lenses are described as functional getters and setters. The simple lens type is

type Lens a b = a -> (b, b -> a)

The getter is

a -> b

and the setter is

a -> b -> a

This type does not constrain lenses to obey the usual laws of getters and setters. So we can use/abuse lens structures for nontrivial computations that have forward and backward passes that share information. Jules Hedges in particular seems to be a proponent of this idea.

I’ve described before how to encode reverse mode automatic differentiation in this style. I have suspicions that you can make iterative LQR and Gauss-Seidel iteration have this flavor too, but I’m not super sure. My attempts ended somewhat unsatisfactorily a while back, but I think it’s not hopeless. The trouble was that you usually want the whole vector back, not just its ends.

I’ve got another example in imperative program analysis that kind of makes sense and might be useful though. Toy repo here: https://github.com/philzook58/wp-lens

In program analysis it sometimes helps to run a program both concretely and symbolically. Concolic = CONCrete / symbOLIC. Symbolic stuff can slowly find hard things and concrete execution just sprays super fast and can find the dumb things really quick.  

We can use a lens structure to organize a DSL for describing a simple imperative language.

The forward pass is for the concrete execution. The backward pass is for transforming the postcondition into a precondition in a weakest precondition analysis. Weakest precondition semantics is a way of specifying what is occurring in an imperative language. It tells how each statement transforms postconditions (predicates about the state after the execution) into preconditions (predicates about the state before the execution). The concrete execution helps unroll loops and avoid branching if-then-else behavior that would make the symbolic stuff harder to process. I’ve been flipping through Dijkstra’s book on this. Interesting stuff, interesting man.

I often think of a state machine as a function taking s -> s. However, this is kind of restrictive. It is possible to have heterogeneous transformations s -> s’. Why not? I think I am often thinking about finite state machines, which we really don’t intend to have a changing state size. Perhaps we allocated new memory or something, or brought something into or out of scope. We could model this by assuming the memory was always there, but it seems wasteful and perhaps confusing. We would need to know a priori everything we will need, which seems like it might break compositionality.

We could model our language making some data type like
data Imp = Skip | Print String | Assign String Expr | Seq Imp Imp | ...
and then build an interpreter

interp :: Imp -> s -> s'


But we can also cut out the middle man and directly define our language using combinators.

type Stmt s s' = s ->s'

To me this has some flavor of a finally tagless style.


Likewise for expressions. Expressions evaluate to something in the context of the state (they can lookup variables), so let’s just use

type Expr s a = s -> a

And, confusingly (sorry), I think it makes sense to use Lens in their original getter/setter intent for variables. So Lens structure is playing double duty.

type Var s a = Lens' s a

With that said, here we go.


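import Prelude hiding (sequence) -- our sequence below would clash with Prelude's

type Stmt s s' = s -> s' 
type Lens' a b = a -> (b, b -> a)
set :: Lens' s a -> s -> a -> s
set l s a = let (_, f) = l s in f a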

type Expr s a = s -> a
type Var s a = Lens' s a

skip :: Stmt s s
skip = id

sequence :: Stmt s s' -> Stmt s' s'' -> Stmt s s''
sequence = flip (.)

assign :: Var s a -> Expr s a -> Stmt s s
assign v e = \s -> set v s (e s)

(===) :: Var s a -> Expr s a -> Stmt s s
v === e = assign v e

ite :: Expr s Bool -> Stmt s s' -> Stmt s s' -> Stmt s s'
ite e stmt1 stmt2 = \s -> if (e s) then stmt1 s else stmt2 s

while :: Expr s Bool -> Stmt s s -> Stmt s s
while e stmt = \s -> if (e s) then ((while e stmt) (stmt s)) else s

assert :: Expr s Bool -> Stmt s s  
assert e = \s -> if (e s) then s else undefined 

abort :: Stmt s s'  
abort = const undefined
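A small usage sketch of the concrete interpreter: summing 1..n with a while loop. The two-variable state `(i, acc)` and the variable lenses `varI` and `varAcc` are hypothetical; the combinators are copied (and lightly condensed) so the block runs on its own.

```haskell
import Prelude hiding (sequence)

type Lens' a b = a -> (b, b -> a)
type Stmt s s' = s -> s'
type Expr s a = s -> a
type Var s a = Lens' s a

get :: Var s a -> Expr s a
get l s = fst (l s)

set :: Var s a -> s -> a -> s
set l s a = snd (l s) a

sequence :: Stmt s s' -> Stmt s' s'' -> Stmt s s''
sequence = flip (.)

(===) :: Var s a -> Expr s a -> Stmt s s
v === e = \s -> set v s (e s)

while :: Expr s Bool -> Stmt s s -> Stmt s s
while e stmt s = if e s then while e stmt (stmt s) else s

-- Hypothetical program state: (i, acc).
varI, varAcc :: Var (Int, Int) Int
varI (n, a) = (n, \n' -> (n', a))
varAcc (n, a) = (a, \a' -> (n, a'))

-- while i > 0 { acc := acc + i; i := i - 1 }
sumDown :: Stmt (Int, Int) (Int, Int)
sumDown = while (\s -> get varI s > 0)
                ((varAcc === (\s -> get varAcc s + get varI s))
                   `sequence` (varI === (\s -> get varI s - 1)))
```

Running `sumDown (4, 0)` yields `(0, 10)`. The variables are ordinary getter/setter lenses here, which is the "double duty" mentioned above.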

Weakest precondition can be done similarly, instead we start from the end and work backwards

Predicates are roughly sets. A simple type for sets is

type Pred s = s -> Bool 
Now, this doesn’t have much deductive power, but I think it demonstrates the principles simply. We could replace Pred with perhaps an SMT solver expression, or some data type for predicates, for which we’ll need to implement things like substitution. Let’s not today.

A function

a -> b

is equivalent to

forall c. (b -> c) -> (a -> c)

This is some kind of CPS / Yoneda transformation thing. Going from a state transformer s -> s' to a predicate transformer (s' -> Bool) -> (s -> Bool) is somewhat evocative of that, with c specialized to Bool. I’m not being very precise here at all.
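A small sketch of that correspondence. `toK`, `fromK`, and `wpOf` are illustrative names. For a deterministic, terminating state transformer, pulling the postcondition back along the transition, as `wpOf` does, gives exactly its weakest precondition.

```haskell
{-# LANGUAGE RankNTypes #-}

-- a -> b  ~  forall c. (b -> c) -> (a -> c)
-- (the forall on the right can float out, so toK is only rank 1).
toK :: (a -> b) -> (b -> c) -> (a -> c)
toK f k = k . f

fromK :: (forall c. (b -> c) -> (a -> c)) -> (a -> b)
fromK g = g id

type Pred s = s -> Bool

-- c specialized to Bool: a state transformer induces a predicate
-- transformer that pulls a postcondition back along the transition.
wpOf :: (s -> s') -> Pred s' -> Pred s
wpOf f post = post . f
```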

Without further ado, here’s how I think a weakest precondition looks roughly.


import Prelude hiding (sequence) -- our sequence below would clash with Prelude's

type Lens' a b = a -> (b, b -> a)
set l s a = let (_, f) = l s in f a

type Expr s a = s -> a
type Var s a = Lens' s a
type Pred s = s -> Bool
type Stmt s s' = Pred s' -> Pred s 

skip :: Stmt s s
skip = \post -> let pre = post in pre

sequence :: Stmt s s' -> Stmt s' s'' -> Stmt s s''
sequence = (.)

assign :: Var s a -> Expr s a -> Stmt s s
assign v e = \post -> let pre s = post (set v s (e s)) in pre

(===) :: Var s a -> Expr s a -> Stmt s s
v === e = assign v e

ite :: Expr s Bool -> Stmt s s' -> Stmt s s' -> Stmt s s'
ite e stmt1 stmt2 = \post -> let pre s = if (e s) then (stmt1 post) s else (stmt2 post) s in pre

abort :: Stmt s s'  
abort = \post -> const False

assert :: Expr s Bool -> Stmt s s  
assert e = \post -> let pre s = (e s) && (post s) in pre

{-
-- tougher. Needs loop invariant
while :: Expr s Bool -> Stmt s s -> Stmt s s
while e stmt = \post -> let pre s = if (e s) then ((while e stmt) (stmt post)) s else  in pre
-}
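A small usage sketch of the predicate-transformer version on a one-variable state (the whole state is the `Int` x; `varX` and `prog` are hypothetical). The combinators are copied, slightly condensed, so the block stands alone.

```haskell
import Prelude hiding (sequence)

type Lens' a b = a -> (b, b -> a)
type Expr s a = s -> a
type Var s a = Lens' s a
type Pred s = s -> Bool
type Stmt s s' = Pred s' -> Pred s

set :: Var s a -> s -> a -> s
set l s a = snd (l s) a

-- wp(f; g, Q) = wp(f, wp(g, Q))
sequence :: Stmt s s' -> Stmt s' s'' -> Stmt s s''
sequence = (.)

(===) :: Var s a -> Expr s a -> Stmt s s
v === e = \post s -> post (set v s (e s))

assert :: Expr s Bool -> Stmt s s
assert e = \post s -> e s && post s

-- The whole state is the variable x.
varX :: Var Int Int
varX s = (s, id)

-- assert (x /= 5); x := x + 1
prog :: Stmt Int Int
prog = assert (/= 5) `sequence` (varX === (+ 1))
```

Applied to the postcondition x > 0, `prog` yields the precondition x /= 5 && x + 1 > 0, which holds at 0 and fails at -1 and at 5.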

Finally here is a combination of the two above that uses the branching structure of the concrete execution to aid construction of the precondition. Although I haven’t expanded it out, we are using the full s t a b parametrization of lens in the sense that states go forward and predicates come back.


import Prelude hiding (sequence) -- our sequence below would clash with Prelude's

type Lens' a b = a -> (b, b -> a)
set l s a = let (_, f) = l s in f a


type Expr s a = s -> a
type Var s a = Lens' s a
type Pred a = a -> Bool
type Stmt s s' = s -> (s', Pred s' -> Pred s) -- eh. Screw the newtype

skip :: Stmt s s
skip = \x -> (x, id)


sequence :: Stmt s s' -> Stmt s' s'' -> Stmt s s''
sequence f g =   \s -> let (s', j) = f s in
                       let (s'', j') = g s' in
                           (s'', j . j')
assign :: Var s a -> Expr s a -> Stmt s s
assign v e = \s -> (set v s (e s), \p -> \s -> p (set v s (e s)))

--if then else
ite :: Expr s Bool -> Stmt s s' -> Stmt s s' -> Stmt s s'
ite e stmt1 stmt2 = \s -> 
                    if (e s) 
                    then let (s', wp) = stmt1 s in
                         (s', \post -> \s -> (e s) && (wp post s))
                    else let (s', wp) = stmt2 s in
                            (s', \post -> \s -> (not (e s)) && (wp post s))

assert :: Pred s -> Stmt s s
assert p = \s -> (s, \post -> let pre s = (post s) && (p s) in pre)

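while :: Expr s Bool -> Stmt s s -> Stmt s s
-- Unroll along the concrete run: execute the body once, recurse on the new
-- state, and record the guard's value as part of the path condition (like ite).
while e stmt = \s -> if e s
                     then let (s', wp1) = stmt s in
                          let (s'', wp2) = (while e stmt) s' in
                          (s'', \post -> \s0 -> (e s0) && (wp1 (wp2 post) s0))
                     else (s, \post -> \s0 -> (not (e s0)) && (post s0))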

{-

-- declare and forget can change the size and shape of the state space.
-- These are heterogeneous state commands
declare :: Iso (s,Int) s' -> Int -> Stmt s s'   
declare iso defalt = \s -> (to iso (s, defalt), \p -> \s0 -> p $ to iso (s0, defalt)) 

forget :: Lens' s s' -> Stmt s s' -- forgets a chunk of state

declare_bracket :: Iso (s,Int) s' -> Int ->  Stmt s' s' -> Stmt s s
declare_bracket iso defalt stmt = (declare iso defalt) `sequence` stmt `sequence` (forget (_1 . iso))
-}

Neat. Useful? Me dunno.
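To see the concrete run steer the precondition, here is a small self-contained sketch of the combined version on a summing loop. It assumes a `while` that runs the body before recursing and records the guard at each iteration, in the same path-condition style as `ite`; the state `(i, acc)` and the variable lenses are hypothetical.

```haskell
import Prelude hiding (sequence)

type Lens' a b = a -> (b, b -> a)
type Expr s a = s -> a
type Var s a = Lens' s a
type Pred a = a -> Bool
type Stmt s s' = s -> (s', Pred s' -> Pred s)

get :: Var s a -> Expr s a
get l s = fst (l s)

set :: Var s a -> s -> a -> s
set l s a = snd (l s) a

sequence :: Stmt s s' -> Stmt s' s'' -> Stmt s s''
sequence f g s = let (s', j) = f s
                     (s'', j') = g s'
                 in (s'', j . j')

assign :: Var s a -> Expr s a -> Stmt s s
assign v e s = (set v s (e s), \post s0 -> post (set v s0 (e s0)))

-- Unroll the loop along the concrete run; the guard's value at each
-- iteration becomes part of the path condition.
while :: Expr s Bool -> Stmt s s -> Stmt s s
while e body s
  | e s =
      let (s', wpBody) = body s
          (s'', wpRest) = while e body s'
      in (s'', \post s0 -> e s0 && wpBody (wpRest post) s0)
  | otherwise = (s, \post s0 -> not (e s0) && post s0)

varI, varAcc :: Var (Int, Int) Int
varI (n, a) = (n, \n' -> (n', a))
varAcc (n, a) = (a, \a' -> (n, a'))

-- while i > 0 { acc := acc + i; i := i - 1 }
sumDown :: Stmt (Int, Int) (Int, Int)
sumDown = while (\s -> get varI s > 0)
                (assign varAcc (\s -> get varAcc s + get varI s)
                   `sequence` assign varI (\s -> get varI s - 1))
```

Starting from (2, 0), the concrete run ends at (0, 3), and the returned predicate transformer only describes that two-iteration path. Starting states that would take a different number of iterations, such as (3, 0), fail the path condition even though no loop invariant was ever written.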

Applicative Bidirectional Programming and Automatic Differentiation

I got referred to an interesting paper by a comment of /u/syrak.

http://www2.sf.ecei.tohoku.ac.jp/~kztk/papers/kztk_jfp_am_2018.pdf

Applicative bidirectional programming (PDF), by Kazutaka Matsuda and Meng Wang

In it, they use a couple of interesting tricks to make Lens programming more palatable. Lenses often need to be programmed in a point-free style, which is rough, but by using some combinators, they are able to program lenses in a pointful style (with some pain points still left over). It is a really interesting, well-written paper. Lots ‘o’ Yoneda and laws. I’m not doing it justice. Check it out!

A while back I noted that reverse mode auto-differentiation has a type very similar to a lens and in fact you can build a working reverse mode automatic differentiation DSL out of lenses and lens-combinators. Many things about lenses, but not all, transfer over to automatic differentiation. The techniques of Matsuda and Wang do appear to transfer fairly well.

This is interesting to me for another reason. Their lift2 and unlift2 functions remind me very much of my recent approach to compiling to categories. The lift2 function is fanning a lens pair. This is basically what my FanOutput typeclass automated. unlift2 is building the input for a function by supplying a tuple of projection lenses. This is what my BuildInput typeclass did. I think their style may extend to many monoidal cartesian categories, not just lenses.

lift2 :: Lens (a,b) c -> (forall s. Num s => (Lens s a, Lens s b) -> Lens s c)
lift2 l (x,y) = lift l (fan x y)

unlift2 :: (Num a, Num b) => (forall s. Num s => (Lens s a, Lens s b) -> Lens s c) -> Lens (a,b) c
unlift2 f = f (fst', snd')

One can use the function b -> a in many of the situations one can use a in. You can do elegant things by making b -> a an instance of the Num typeclass, for example. This little fact seems to extend to other categories as well. By making Lens s a an instance of Num when a is a Num, we can use reasonable-looking notation for arithmetic.
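For reference, here is a minimal sketch of the function instance being alluded to: pointwise arithmetic, with numeric literals as constant functions. This instance is not in base; it is written out here for illustration only.

```haskell
-- Pointwise arithmetic on functions: (f + g) x = f x + g x, and a
-- literal n stands for the constant function returning n.
instance Num b => Num (a -> b) where
  f + g = \x -> f x + g x
  f - g = \x -> f x - g x
  f * g = \x -> f x * g x
  negate f = negate . f
  abs f = abs . f
  signum f = signum . f
  fromInteger n = const (fromInteger n)

-- x^2 + 3x + 1, written without mentioning x.
poly :: Int -> Int
poly = id * id + 3 * id + 1
```

`poly 2` evaluates to 11, and `(sin + cos) 0` evaluates to 1. The `Num (Lens s a)` instance plays the same game, with `lift2` doing the pointwise plumbing.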

t1 :: Num a => Lens (a,a) a
t1 = unlift2 $ \(x,y) -> x + y*y + y * 7

They spend some time discussing the necessity of a Poset typeclass. For actual lawful lenses, the dup implementation needs a way to recombine multiple adjustments to the same object. In the AD-lens case, dup takes care of this by adding together the differentials. This means that everywhere they needed an Eq typeclass, we can use a Num typeclass. It may be useful to build a wrapped type data NZ a = Zero | NonZero a, like their Tag type, to accelerate the many 0 values that may be propagating through the system.
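A minimal sketch of that `NZ` idea, assuming its job is just to short-circuit arithmetic on differentials that are known to be zero. Note that `Zero` means "structurally zero"; a `NonZero 0` can still arise from cancellation.

```haskell
-- Zero marks a differential known to be zero, so arithmetic can skip it.
data NZ a = Zero | NonZero a deriving (Eq, Show)

instance Num a => Num (NZ a) where
  Zero + y = y
  x + Zero = x
  NonZero a + NonZero b = NonZero (a + b)
  Zero * _ = Zero
  _ * Zero = Zero
  NonZero a * NonZero b = NonZero (a * b)
  negate Zero = Zero
  negate (NonZero a) = NonZero (negate a)
  abs Zero = Zero
  abs (NonZero a) = NonZero (abs a)
  signum Zero = Zero
  signum (NonZero a) = NonZero (signum a)
  fromInteger 0 = Zero
  fromInteger n = NonZero (fromInteger n)
```

Because `fromInteger 0` is `Zero`, the `const 0` backward passes of `fst'`, `snd'`, and `constLens` would produce `Zero` at type `NZ a`, and `dup`'s addition then skips them without doing any arithmetic.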

Unfortunately, as is, the performance of this is abysmal. Maybe there is a way to fix it? Unlifting and lifting destroys a lot of sharing and often necessitates adding many redundant zeros. Why are you doing reverse mode differentiation unless you care about performance? Forward mode is simpler to implement. In the intended use case of Matsuda and Wang, they are working with actual lawful lenses, which have far less computational content than AD-lenses. Good lawful lenses should just be shuffling stuff around a little. Maybe one can hope GHC is insanely intelligent and can optimize these zeros away. One point in favor of that is that our differentiation is completely pure (no mutation). Nevertheless, I suspect it will not without help. Being careful and unlifting and lifting manually may also help. In principle, I think the Lensy approach could be pretty fast (since all it is is just packing together exactly what you need to differentiate into a data type), but how to make it fast while still being easily programmable? It is also nice that it is pretty simple to implement. It is the simplest method that I know of if you needed to port operable reverse mode differentiation to a new library (Massiv?) or another functional language (Futhark?). And a smart compiler really does have a shot at finding optimizations/fusions.

While I was at it, unrelated to the paper above, I think I made a working generic auto differentiable fold lens combinator. Pretty cool. mapAccumL is a hot tip.
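The fold combinator itself lives in the repo. As a guess at what such a thing might look like (not necessarily the version there), here is a minimal sketch: `mapAccumL` runs the step lens forward while collecting the backward closures, and `mapAccumR` replays them from the right. `lfoldl` and `mulStep` are hypothetical names.

```haskell
import Data.Traversable (mapAccumL, mapAccumR)

type Lens'' a b = a -> (b, b -> a)

-- Differentiable left fold over a list, given a step lens (acc, x) -> acc.
lfoldl :: Lens'' (b, a) b -> Lens'' (b, [a]) b
lfoldl step (b0, xs) =
  let (bN, backs) = mapAccumL (\b x -> step (b, x)) b0 xs
      -- Walk the saved backward closures from the last step to the first,
      -- threading the accumulator sensitivity and collecting each dx.
      back dbN = mapAccumR (\db k -> k db) dbN backs
  in (bN, back)

mulStep :: Num a => Lens'' (a, a) a
mulStep (b, x) = (b * x, \d -> (d * x, d * b))
```

`lfoldl mulStep (1, [2, 3, 4])` computes 24, and seeding the backward pass with 1 gives `(24, [12, 8, 6])`: the sensitivity to the initial accumulator followed by the partial derivatives with respect to each element.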

For practical Haskell purposes, all of this is a little silly with the good Haskell AD packages around, the most prominent being

http://hackage.haskell.org/package/ad

It is somewhat interesting to note the similarity of the type forall s. Lens s appearing in the Matsuda and Wang approach to that of the forall s. BVar s monad appearing in the backprop package. In this case I believe that the s type variable plays the same role it does in the ST monad, protecting a mutating Wengert tape state held in the monad, but I haven’t dug much into it. I don’t know enough about backprop to know what to make of this similarity.

http://hackage.haskell.org/package/backprop

The github repo with my playing around and stream of consciousness commentary is here

{-# LANGUAGE NoImplicitPrelude, TypeSynonymInstances, RankNTypes #-}
module Numeric.ADLens.AppBi where
-- import Numeric.ADLens.Lens
import Control.Category
import Prelude hiding (id, (.), div) -- div is redefined below as a lens
import Control.Arrow ((***))
import Data.Functor.Const
import Data.Traversable
newtype Lens x y = Lens (x -> (y, y -> x)) 
type L s a = Num s => Lens s a

instance Category Lens where
  id = Lens (\x -> (x, id))
  (Lens f) . (Lens g) = Lens $ \x -> let (y, g') = g x in
                                           let (z, f') = f y in
                                           (z, g' . f') 



grad'' (Lens f) x = let (y,j) = (f x) in j 1

lift :: Lens a b -> (forall s. Lens s a -> Lens s b)
lift l l' = l . l'

unlift :: Num a => (forall s. Num s => Lens s a -> Lens s b) -> Lens a b
unlift f = f id


dup :: Num a => Lens a (a,a)
dup = Lens $ \x -> ((x,x), \(dx,dy) -> dx + dy)

par :: Lens a b -> Lens c d -> Lens (a,c) (b,d)
par (Lens f) (Lens g) = Lens l'' where
    l'' (a,c) = ((b,d), f' *** g') where
        (b,f') = f a
        (d,g') = g c

fan :: Num s => Lens s a -> Lens s b -> Lens s (a,b)
fan x y = (par x y) . dup 

-- impredicative polymorphism error when we use L in type sig. I'm just going to avoid that.
lift2 :: Lens (a,b) c -> (forall s. Num s => (Lens s a, Lens s b) -> Lens s c)
lift2 l (x,y) = lift l (fan x y)

unlift2 :: (Num a, Num b) => (forall s. Num s => (Lens s a, Lens s b) -> Lens s c) -> Lens (a,b) c
unlift2 f = f (fst', snd')

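-- Componentwise Num for pairs; every method must be defined, since the
-- defaults for negate and (-) are mutually recursive.
instance (Num a, Num b) => Num (a,b) where
  (x,y) + (a,b) = (x + a, y + b)
  (x,y) * (a,b) = (x * a, y * b)
  negate (x,y) = (negate x, negate y)
  abs (x,y) = (abs x, abs y)
  signum (x,y) = (signum x, signum y)
  fromInteger x = (fromInteger x, fromInteger x)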

fst' :: Num b => Lens (a,b) a
fst' = Lens (\(a,b) -> (a, \ds -> (ds, 0)))

snd' :: Num a => Lens (a,b) b
snd' = Lens (\(a,b) -> (b, \ds -> (0, ds)))

unit :: Num s => Lens s () -- ? This isn't right.
unit = Lens (\s -> ((), const 0))

add :: Num a => Lens (a,a) a 
add = Lens $ \(x,y) -> (x + y, \ds -> (ds, ds))

sub :: Num a => Lens (a,a) a 
sub = Lens $ \(x,y) -> (x - y, \ds -> (ds, -ds))

mul :: Num a => Lens (a,a) a 
mul = Lens $ \(x,y) -> (x * y, \dz -> (dz * y, x * dz))

recip' :: Fractional a => Lens a a 
recip' = Lens $ \x-> (recip x, \ds -> - ds / (x * x))

div :: Fractional a => Lens (a,a) a 
div = Lens $ (\(x,y) -> (x / y, \d -> (d/y,-x*d/(y * y))))

-- They called this "new" Section 3.2
constLens :: Num s => a -> Lens s a
constLens x = Lens (const (x, const 0))

-- or rather we might define add = unlift2 (+)
instance (Num s, Num a) => Num (Lens s a) where
	x + y = (lift2 add) (x,y)
	x - y = (lift2 sub) (x,y)
	x * y = (lift2 mul) (x,y)
	abs = error "TODO"
	fromInteger x = constLens (fromInteger x)

t1 :: Num a => Lens (a,a) a
t1 = unlift2 $ \(x,y) -> x + y*y + y * 7

-- See the section on lifting list functions from the biapplicative paper.
-- These could be Isos.
lcons :: Lens (a,[a]) [a]
lcons =  Lens $ \(a,as) -> (a : as, \xs -> (head xs, tail xs))
lnil :: Lens () [b]
lnil = Lens $ const ([], const ())

lsequence :: Num s => [Lens s a] -> Lens s [a]
lsequence [] = lift lnil unit
lsequence (x : xs) = lift2 lcons (x, lsequence xs)

llift :: Num s => Lens [a] b -> [Lens s a] -> Lens s b
llift l xs = lift l (lsequence xs)


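-- Pointwise Num for lists. fromInteger gives an infinite constant list,
-- which zipWith truncates to the length of the other operand.
instance (Num a) => Num [a] where
  (+) = zipWith (+)
  (*) = zipWith (*)
  (-) = zipWith (-)
  abs = map abs
  signum = map signum
  fromInteger x = repeat (fromInteger x)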

-- We need to hand f a list of the accessor lenses
-- [Lens [a] a]
-- This feels quite wrong. Indexing into a list is naughty.
-- But that is what they do. Shrug.
lunlift :: Num a => (forall s. Num s => [Lens s a] -> Lens s b) -> Lens [a] b
lunlift f = Lens $ \xs -> 
					let n = length xs in
					let inds = [0 .. n-1] in
					let ls = map (lproj n) inds in
					let (Lens f') = f ls in
					f' xs

t2 :: Num a => Lens [a] a					
t2 = lunlift sum
t3 :: Num a => Lens [a] a					
t3 = lunlift product

lproj :: Num a => Int -> Int -> Lens [a] a
lproj n' ind = Lens $ \xs -> ((xs !! ind), \x' -> replace ind x' zeros) where
	replace 0 x (y:ys) = x : ys
	replace n x (y:ys) = y : (replace (n-1) x ys)
	zeros = replicate n' 0

	
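-- Note: for lists, (<*>) is the Cartesian product rather than a pointwise zip,
-- so the backward pass only lines up with the forward pass for "zippy"
-- Applicatives such as ZipList (or Maybe, Identity).
lensmap :: Applicative f => Lens a b -> Lens (f a) (f b)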
lensmap (Lens f) = Lens $ \fa ->
								let fbb = fmap f fa in
								let fb = fmap fst fbb in
								let fb2s = fmap snd fbb in
								(fb, \fb' -> fb2s <*> fb')

-- Types work, but does this actually make sense?
lsequenceA :: (Applicative f, Applicative t, Traversable f, Traversable t) => Lens (t (f a)) (f (t a))
lsequenceA = Lens $ \tfa -> (sequenceA tfa, sequenceA)

ltraverse :: (Applicative f, Applicative t, Traversable f, Traversable t) =>
             Lens a (f b) -> Lens (t a) (f (t b))
ltraverse f = lsequenceA . (lensmap f)

lensfoldl :: Traversable t => Lens (a, b) a -> Lens (a, t b) a
lensfoldl (Lens f) = Lens $ \(s, t) -> let (y, tape) = mapAccumL (curry f) s t  in
						  (y,  \db ->  mapAccumR (\db' f -> (f db')) db tape)
lensfoldr :: Traversable t => Lens (a, b) a -> Lens (a, t b) a
lensfoldr (Lens f) = Lens $ \(s, t) -> let (y, tape) = mapAccumR (curry f) s t  in
						(y,  \db ->  mapAccumL (\db' f -> (f db')) db tape)						  

t5 = grad'' (lensfoldl mul) (1, [1,1,2,3])


liftC :: Num a => (Lens a b -> Lens c d) -> (forall s. Num s => Lens s a -> Lens s b) -> (forall t. Num t => Lens t c -> Lens t d)
liftC c f = lift (c (unlift f))

ungrad :: Lens (a,b) c -> (a -> Lens b c)
ungrad (Lens f) a = Lens (\b -> let (c,j) = f (a,b) in (c, snd . j))
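To see the lift/unlift machinery run end to end, here is a condensed, self-contained sketch of the same idea, specialised to functions of two numbers. It is a simplification under my own assumptions rather than the module above: `compose`, `addL`, `mulL`, `constL`, `fstL`, `sndL` and `grad` are hypothetical stand-ins for `.`, `add`, `mul`, `constLens`, `fst'`, `snd'` and `grad''`, and `abs`/`signum` are left undefined because the example never needs them.

```haskell
{-# LANGUAGE RankNTypes #-}

-- Reverse-mode "lens": a value plus a Jacobian-transpose closure.
newtype Lens x y = Lens (x -> (y, y -> x))

-- Forward passes run g then f; backward passes run f' then g'.
compose :: Lens y z -> Lens x y -> Lens x z
compose (Lens f) (Lens g) = Lens $ \x ->
  let (y, g') = g x
      (z, f') = f y
  in (z, g' . f')

-- Componentwise Num for pairs, so gradients w.r.t. (x, y) can be summed.
instance (Num a, Num b) => Num (a, b) where
  (a, b) + (c, d) = (a + c, b + d)
  (a, b) * (c, d) = (a * c, b * d)
  negate (a, b) = (negate a, negate b)
  abs (a, b) = (abs a, abs b)
  signum (a, b) = (signum a, signum b)
  fromInteger n = (fromInteger n, fromInteger n)

-- Run two lenses on the same input; the backward pass adds the cotangents.
fan :: Num s => Lens s a -> Lens s b -> Lens s (a, b)
fan (Lens f) (Lens g) = Lens $ \s ->
  let (a, f') = f s
      (b, g') = g s
  in ((a, b), \(da, db) -> f' da + g' db)

-- Binary primitives with their Jacobian transposes.
addL, mulL :: Num a => Lens (a, a) a
addL = Lens $ \(x, y) -> (x + y, \d -> (d, d))
mulL = Lens $ \(x, y) -> (x * y, \d -> (d * y, x * d))

-- Constants have zero gradient.
constL :: Num s => a -> Lens s a
constL c = Lens $ \_ -> (c, const 0)

-- Arithmetic on "lenses out of s" lifts the primitives (the lift step).
instance (Num s, Num a) => Num (Lens s a) where
  x + y = compose addL (fan x y)
  x * y = compose mulL (fan x y)
  negate x = compose (Lens $ \v -> (negate v, negate)) x
  abs = error "abs: not needed in this sketch"
  signum = error "signum: not needed in this sketch"
  fromInteger = constL . fromInteger

-- Projection lenses used as the "variables".
fstL :: Num b => Lens (a, b) a
fstL = Lens $ \(a, _) -> (a, \d -> (d, 0))

sndL :: Num a => Lens (a, b) b
sndL = Lens $ \(_, b) -> (b, \d -> (0, d))

-- The unlift step: instantiate s to the input pair and feed in projections.
unlift2 :: (Num a, Num b)
        => (forall s. Num s => (Lens s a, Lens s b) -> Lens s c)
        -> Lens (a, b) c
unlift2 f = f (fstL, sndL)

-- Gradient of a scalar-valued lens: run the backward pass on 1.
grad :: Num y => Lens x y -> x -> x
grad (Lens f) x = snd (f x) 1

-- The same polynomial as t1 above: x + y*y + 7*y.
t1 :: Lens (Double, Double) Double
t1 = unlift2 (\(x, y) -> x + y * y + y * 7)
```

Here `grad t1 (2, 3)` gives `(1.0, 13.0)`, matching the partial derivatives 1 and 2y + 7. The two uses of `y` in `y * y` each contribute a cotangent, and `fan` sums them, which is why the input type needs a `Num` instance.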

Reverse Mode Differentiation is Kind of Like a Lens II

For those looking for more on automatic differentiation in Haskell:

Ed Kmett’s ad package

http://hackage.haskell.org/package/ad

Conal Elliott is making the rounds with a new take on AD (GOOD STUFF).

http://conal.net/papers/essence-of-ad/

Justin Le has been making excellent posts and has another library he’s working on.

https://blog.jle.im/entry/introducing-the-backprop-library.html

 

And here we go:

Reverse mode automatic differentiation is kind of like a lens. Here is the type for a non-fancy lens

type Lens s t a b = s -> (a, b -> t)

When you compose two lenses, you compose the getters (s -> a) and you compose the partially applied setter (b -> t) in the reverse direction.

We can define a type for a reverse mode differentiable function

type AD x dx y dy = x -> (y, dy -> dx)

When you compose two differentiable functions you compose the functions and you flip compose the Jacobian transposes (dy -> dx). It is this flip composition which gives reverse mode its name. The dependence of the Jacobian on the base point x corresponds to the dependence of the setter on the original object.

The implementations of composition for Lens and AD are identical.
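As a minimal sketch (the names `composeAD`, `sinAD`, `squareAD` and `derivSinSquare` are illustrative, not from the repo), here is that composition written against the `AD` synonym, followed by a derivative extracted by seeding the backward pass with 1:

```haskell
-- Reverse-mode differentiable function: value plus Jacobian transpose.
type AD x dx y dy = x -> (y, dy -> dx)

-- Forward passes compose in order; backward passes compose flipped.
-- This is the same code as composing simple (non-fancy) lenses.
composeAD :: AD y dy z dz -> AD x dx y dy -> AD x dx z dz
composeAD f g = \x ->
  let (y, g') = g x
      (z, f') = f y
  in (z, g' . f')

sinAD :: AD Double Double Double Double
sinAD x = (sin x, \dy -> dy * cos x)

squareAD :: AD Double Double Double Double
squareAD x = (x * x, \dy -> 2 * x * dy)

-- d/dx sin(x^2) = 2x cos(x^2): run the composite's backward pass on 1.
derivSinSquare :: Double -> Double
derivSinSquare x = snd (composeAD sinAD squareAD x) 1
```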

Both of these things are described by the same box diagram (cribbed from the profunctor optics paper www.cs.ox.ac.uk/people/jeremy.gibbons/publications/poptics.pdf ).

 

 

This is a very simple way of implementing reverse mode automatic differentiation using only non-exotic features of a functional programming language. Since it is so bare bones and functional, is this a good way to achieve the vision of the gorgeous post by Christopher Olah? http://colah.github.io/posts/2015-09-NN-Types-FP/ I do not know.

Now, to be clear, these ARE NOT lenses. Please, I don’t want to cloud the water, do not call these lenses. They’re pseudolenses or something. A very important part of what makes a lens a lens is that it obeys the lens laws, in which the getter and setter behave as one would expect. Our “setter” is a functional representation of the Jacobian transpose and our getter is the function itself. These do not obey lens laws in general.

Chain Rule AND Jacobian

What is reverse mode differentiation? One’s thinking is muddled by defaulting to the Calc I perspective of one-dimensional functions. Thinking is also muddled by the general conception that the gradient is a vector. This is slightly sloppy talk and can lead to confusion. It definitely has confused me.

The right setting for intuition is $R^n \rightarrow R^m$ functions.

If one looks at a multidimensional-to-multidimensional function like this, you can form a matrix of partial derivatives known as the Jacobian. In the scalar-to-scalar case this is a $1\times 1$ matrix, which we can think of as just a number. In the multi-to-scalar case this is a $1\times n$ matrix, which we somewhat fuzzily can think of as a vector.

The chain rule is a beautiful thing. It is what makes differentiation so elegant and tractable.

For many-to-many functions, if you compose them you matrix multiply their Jacobians.
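Concretely, for $f : \mathbb{R}^n \to \mathbb{R}^m$ followed by $g : \mathbb{R}^m \to \mathbb{R}^k$, the chain rule and its transpose read as follows. Transposing reverses the order of the factors, which is the flipped composition of the `dy -> dx` functions described above.

```latex
\[
\underbrace{J_{g\circ f}(x)}_{k \times n}
  = \underbrace{J_g\big(f(x)\big)}_{k \times m}\;\underbrace{J_f(x)}_{m \times n},
\qquad
J_{g\circ f}(x)^{\mathsf T} = J_f(x)^{\mathsf T}\, J_g\big(f(x)\big)^{\mathsf T}.
\]
```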

Just to throw in some category theory spice (who can resist), the chain rule is a functor between the category of differentiable functions and the category of vector spaces where composition is given by Jacobian multiplication. This is probably wholly unhelpful.

The cost of multiplying an $a \times b$ matrix $A$ by a $b \times c$ matrix $B$ is $O(abc)$. If we add a third, $c \times d$ matrix $C$, we can associate the product $ABC$ to the left or right, $(AB)C$ vs $A(BC)$, choosing which product to form first. These two associations have different costs: $abc + acd$ for left association or $bcd + abd$ for right association. We want to multiply through the smallest dimension over and over. For functions that are ultimately many-to-scalar, the Jacobian of the final stage is a single row, so we want to start multiplying from that output end of the chain and carry a row vector backward. That is reverse mode.
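To make the costs concrete, here is a tiny sketch that counts scalar multiplications for both associations, with $A$ of size $a \times b$, $B$ of size $b \times c$ and $C$ of size $c \times d$. The names `leftCost` and `rightCost` are illustrative. Think of $A$ as the $1 \times b$ Jacobian of the final, scalar-valued stage.

```haskell
-- Scalar multiplications needed to form (AB)C.
leftCost :: Int -> Int -> Int -> Int -> Int
leftCost a b c d = a * b * c + a * c * d

-- Scalar multiplications needed to form A(BC).
rightCost :: Int -> Int -> Int -> Int -> Int
rightCost a b c d = b * c * d + a * b * d
```

With a = 1 and b = c = d = 100, `leftCost 1 100 100 100` is 20000 while `rightCost 1 100 100 100` is 1010000, so forming $AB$ first, starting from the scalar output, is about fifty times cheaper.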

For a clearer explanation of the importance of the association, maybe this will help https://en.wikipedia.org/wiki/Matrix_chain_multiplication

 

Functional representations of matrices

A Matrix data type typically gives you full inspection of the elements. If you partially apply the matrix-vector product function (!* :: Matrix -> Vector -> Vector) to a matrix m, you get a vector-to-vector function (m !*) :: Vector -> Vector. In the sense that a matrix is data representing a linear map, this type looks gorgeous. It is so evocative of purpose.

If all you want to do is multiply matrices or perform matrix-vector products, this is not a bad way to go. A function in Haskell is a thing that exposes only a single interface, the ability to be applied. Very often the loss of Gaussian elimination or eigenvalue decompositions is quite painfully felt. For simple automatic differentiation, though, it isn’t so bad.

You can inefficiently reconstitute a matrix from its functional form by applying it to a basis of vectors.
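A small sketch of that reconstruction, using plain lists of `Double` as a stand-in for a real matrix library (`mulMV`, `basis` and `toMatrix` are hypothetical helpers). Partially applying `mulMV` to a list of rows gives the functional form, and applying that function to each standard basis vector recovers one column at a time.

```haskell
import Data.List (transpose)

type Vec = [Double]

-- Dense matrix stored as rows; partially applied, it is a Vec -> Vec function.
mulMV :: [[Double]] -> Vec -> Vec
mulMV rows v = [sum (zipWith (*) r v) | r <- rows]

-- Standard basis vector e_i of length n.
basis :: Int -> Int -> Vec
basis n i = [if j == i then 1 else 0 | j <- [0 .. n - 1]]

-- Rebuild the rows of a matrix from its functional form f : R^n -> R^m.
-- Each f (basis n i) is a column; transposing the columns gives the rows.
toMatrix :: Int -> (Vec -> Vec) -> [[Double]]
toMatrix n f = transpose [f (basis n i) | i <- [0 .. n - 1]]
```

`toMatrix 2 (mulMV [[1,2],[3,4],[5,6]])` gives back `[[1.0,2.0],[3.0,4.0],[5.0,6.0]]`, at the cost of one matrix-vector product per column.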

One weakness of the functional form is that the type does not constrain the function to actually act linearly on the vectors.

One big advantage of the functional form is that you can intermix different matrix types (sparse, low-rank, dense) with no friction, just so long as they all have some way of being applied to the same kind of vector. You can also use functions like (id :: a -> a) as the identity matrix, which are not built from any underlying matrix type at all.
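To illustrate the mixing, here is a sketch with three deliberately different representations (dense rows, a stored diagonal, and a rank-one $u w^{\mathsf T}$ that is never materialised), all plain `[Double] -> [Double]` functions. The helper names are hypothetical, and, as noted above, nothing in the types checks that they are actually linear.

```haskell
type Vec = [Double]

-- Dense: a full list of rows.
dense :: [[Double]] -> Vec -> Vec
dense rows v = [sum (zipWith (*) r v) | r <- rows]

-- Diagonal: stores only the diagonal entries.
diagonal :: Vec -> Vec -> Vec
diagonal ds = zipWith (*) ds

-- Rank one: u w^T applied as u * (w . v), without building the matrix.
rankOne :: Vec -> Vec -> Vec -> Vec
rankOne u w v = map (* sum (zipWith (*) w v)) u

-- Matrix multiplication is function composition; id is the identity matrix.
combined :: Vec -> Vec
combined = dense [[1, 2], [3, 4]] . diagonal [10, 100] . id . rankOne [1, 1] [1, -1]
```

For example, `combined [3, 1]` evaluates to `[420.0, 860.0]`: the rank-one factor gives `[2, 2]`, the diagonal `[20, 200]`, and the dense matrix `[420, 860]`.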

To match the lens, we need to represent the Jacobian transpose as the function (dy -> dx) mapping differentials in the output space to differentials in the input space.

The Lens Trick

A lens is the combination of a getter (a function that grabs a piece out of a larger object) and a setter (a function that takes the object and a new piece and returns the object with that piece replaced).

The common form of lens used in Haskell doesn’t look like the above. It looks like this.

type Lens s t a b = forall f. Functor f => (a -> f b) -> (s -> f t)

This form has exactly the same content as the previous form (A non obvious fact. See the Profunctor Optics paper above. Magic neato polymorphism stuff), with the added functionality of being able to compose using the regular Haskell (.) operator.
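A minimal sketch of the van Laarhoven form. The `_1`/`_2` names mirror the lens library's conventions, but these are local definitions, not that library's code. Choosing the functor `Const` gives the getter, `Identity` gives the setter, and lenses compose with ordinary `(.)`, reading from the outer structure inward.

```haskell
{-# LANGUAGE RankNTypes #-}
import Data.Functor.Const (Const (..))
import Data.Functor.Identity (Identity (..))

-- Simple van Laarhoven lens, polymorphic in the functor f.
type Lens' s a = forall f. Functor f => (a -> f a) -> s -> f s

-- Focus on the first / second component of a pair.
_1 :: Lens' (a, b) a
_1 k (a, b) = fmap (\a' -> (a', b)) (k a)

_2 :: Lens' (a, b) b
_2 k (a, b) = fmap (\b' -> (a, b')) (k b)

-- f = Const: the continuation's result is carried out unchanged (getter).
view :: Lens' s a -> s -> a
view l = getConst . l Const

-- f = Identity: replace the focus and rebuild the structure (setter).
-- Argument order matches the set used later in this post: object first.
set :: Lens' s a -> s -> a -> s
set l s x = runIdentity (l (const (Identity x)) s)
```

For example, `view (_1 . _2) ((1, 'x'), True)` is `'x'`, and `set (_1 . _2) ((1, 'x'), True) 'y'` is `((1, 'y'), True)`.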

I think a good case can be made to NOT use the lens trick (do as I say, not as I do). It obfuscates sharing and obfuscates your code to the compiler (I assume the compiler’s optimizations understand polymorphic functor types less well than they understand tuples and functions), meaning the compiler has less opportunity to help you out. But it is also pretty cool. So… I dunno. Edit:

/u/mstksg points out that compilers actually LOVE the van Laarhoven representation (the lens trick), because when f is finally specialized it becomes a newtype wrapper, which has no runtime cost. Then the compiler can just chew the thing apart.

https://www.reddit.com/r/haskell/comments/9oc9dq/reverse_mode_differentiation_is_kind_of_like_a/

One thing that is extra scary about the fancy form is that it makes it less clear how much data is likely to be shared between the forward and backward pass. Another alternative to the lens that shows this is the following.

type AD x dx y dy = (x -> y, x -> dy -> dx)

This form gives the same end result. However, it cannot share computation between the forward and backward passes and therefore isn’t the same performance-wise. One nontrivial function that took me some head scratching converts the fancy lens directly to the regular lens without destroying sharing. I think this does it:

unfancy :: Lens' a b -> (a -> (b, b -> a))
unfancy l = getCompose . l (\b -> Compose (b, id))
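Here is a self-contained sketch exercising `unfancy` on a small example, together with a hypothetical `unshare` function that produces the `(x -> y, x -> dy -> dx)` form. In `unshare`, the backward component calls `h` again, so the forward pass is recomputed rather than shared; that is the performance difference described above.

```haskell
{-# LANGUAGE RankNTypes #-}
import Data.Functor.Compose (Compose (..))

type Lens' a b = forall f. Functor f => (b -> f b) -> a -> f a

-- Shared form -> van Laarhoven form (the lens'' of this post).
lens'' :: (a -> (b, b -> a)) -> Lens' a b
lens'' h g x = fmap j (g b) where (b, j) = h x

-- Van Laarhoven form -> shared form, via the functor Compose ((,) b) ((->) b).
unfancy :: Lens' a b -> (a -> (b, b -> a))
unfancy l = getCompose . l (\b -> Compose (b, id))

-- Unshared form: the second component reruns h on every backward call.
unshare :: (a -> (b, b -> a)) -> (a -> b, a -> b -> a)
unshare h = (fst . h, snd . h)

-- Multiplication as a reverse-mode "lens" on pairs.
mulL :: Num a => Lens' (a, a) a
mulL = lens'' (\(x, y) -> (x * y, \dz -> (dz * y, x * dz)))
```

`fst (unfancy mulL (3, 4))` is 12 and `snd (unfancy mulL (3, 4)) 1` is `(4, 3)`; the unshared pair from `unshare` produces the same numbers.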

 

Some code

I have some small exploration of the concept in this git repo: https://github.com/philzook58/ad-lens

Again, really check out Conal Elliott’s AD paper and enjoy the many, many apostrophes to follow.

Some basic definitions and transformations between fancy and non-fancy lenses. Extracting the gradient is similar to the set function: grad assumes a many-to-one function and runs its backward pass (the Jacobian transpose) on 1.

{-# LANGUAGE RankNTypes #-}
import Data.Functor.Identity
import Data.Functor.Const
import Data.Functor.Compose

type Lens' a b = forall f. Functor f => (b -> f b) -> a -> f a

lens'' :: (a -> (b, b -> a)) -> Lens' a b
lens'' h g x = fmap j fb where
    (b, j) = h x
    fb = g b

over :: Lens' a b -> ((b -> b) -> a -> a)
over l f = runIdentity . l (Identity . f)

set :: Lens' a b -> a -> b -> a
set l = flip (\x -> (over l (const x)))

view :: Lens' a b -> a -> b
view l = getConst . l Const

unlens'' :: Lens' a b -> (a -> (b, b -> a))
unlens'' l = getCompose . l (\b -> Compose (b, id))

constlens :: Lens' (a,b) c -> b -> Lens' a c
constlens l b = lens'' $ \a -> let (c, df) = f (a,b) in
                             (c, fst . df) where 
                                           f = unlens'' l


grad :: Num b => Lens' a b -> a -> a
grad l = (flip (set l)) 1
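A minimal usage sketch of `grad`, restating just enough of the definitions above to be self-contained (`mul'` is the same multiplication lens that appears in the next block). Setting the output of a `lens''` to 1 runs its Jacobian transpose on 1.

```haskell
{-# LANGUAGE RankNTypes #-}
import Data.Functor.Identity (Identity (..))

type Lens' a b = forall f. Functor f => (b -> f b) -> a -> f a

lens'' :: (a -> (b, b -> a)) -> Lens' a b
lens'' h g x = fmap j (g b) where (b, j) = h x

-- With f = Identity, "setting" the output to b runs the backward pass on b.
set :: Lens' a b -> a -> b -> a
set l x b = runIdentity (l (const (Identity b)) x)

-- Gradient of a many-to-one function: the backward pass applied to 1.
grad :: Num b => Lens' a b -> a -> a
grad l x = set l x 1

-- f(x, y) = x * y, with Jacobian transpose dz -> (dz * y, x * dz).
mul' :: Num a => Lens' (a, a) a
mul' = lens'' (\(x, y) -> (x * y, \dz -> (dz * y, x * dz)))
```

Here `grad mul' (3, 4)` evaluates to `(4, 3)`, the partial derivatives $(y, x)$.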

Basic 1D functions and arrow/categorical combinators

{-# LANGUAGE RankNTypes #-}
import Control.Arrow ((***))
import Numeric.ADLens.Lens

-- add and dup are dual!

add' :: Num a => Lens' (a,a) a 
add' = lens'' $ \(x,y) -> (x + y, \ds -> (ds, ds))

dup' :: Num a => Lens' a (a,a)
dup' = lens'' $ \x -> ((x,x), \(dx,dy) -> dx + dy)

sub' :: Num a => Lens' (a,a) a 
sub' = lens'' $ \(x,y) -> (x - y, \ds -> (ds, -ds))

mul' :: Num a => Lens' (a,a) a 
mul' = lens'' $ \(x,y) -> (x * y, \dz -> (dz * y, x * dz))

recip' :: Fractional a => Lens' a a 
recip' = lens'' $ \x-> (recip x, \ds -> - ds / (x * x))

div' :: Fractional a => Lens' (a,a) a 
div' = lens'' $ (\(x,y) -> (x / y, \d -> (d/y,-x*d/(y * y))))

sin' :: Floating a => Lens' a a
sin' = lens'' $ \x -> (sin x, \dx -> dx * (cos x))

cos' :: Floating a => Lens' a a
cos' = lens'' $ \x -> (cos x, \dx -> -dx * (sin x))

pow' :: Num a => Integer -> Lens' a a
pow' n = lens'' $ \x -> (x ^ n, \dx -> (fromInteger n) * dx * x ^ (n-1)) 

--cmul :: Num a => a -> Lens' a a
--cmul c = lens (* c) (\x -> \dx -> c * dx)

exp' :: Floating a => Lens' a a
exp' = lens'' $ \x -> let ex = exp x in
                      (ex, \dx -> dx * ex)

fst' :: Num b => Lens' (a,b) a
fst' = lens'' (\(a,b) -> (a, \ds -> (ds, 0)))

snd' :: Num a => Lens' (a,b) b
snd' = lens'' (\(a,b) -> (b, \ds -> (0, ds)))

swap' :: Lens' (a,b) (b,a)
swap' = lens'' (\(a,b) -> ((b,a), \(db,da) -> (da, db)))

assoc' :: Lens' ((a,b),c) (a,(b,c))
assoc' = lens'' $ \((a,b),c) -> ((a,(b,c)), \(da,(db,dc)) -> ((da,db),dc))

par' :: Lens' a b -> Lens' c d -> Lens' (a,c) (b,d)
par' l1 l2 = lens'' f3 where
    f1 = unlens'' l1
    f2 = unlens'' l2
    f3 (a,c) = ((b,d), df1 *** df2) where
        (b,df1) = f1 a
        (d,df2) = f2 c

fan' :: Num a => Lens' a b -> Lens' a c -> Lens' a (b,c)
fan' l1 l2 = lens'' f3 where
    f1 = unlens'' l1
    f2 = unlens'' l2
    f3 a = ((b,c), \(db,dc) -> df1 db + df2 dc) where
        (b,df1) = f1 a
        (c,df2) = f2 a

first' :: Lens' a b -> Lens' (a, c) (b, c)
first' l = par' l id

second' :: Lens' a b -> Lens' (c, a) (c, b)
second' l = par' id l

relu' :: (Ord a, Num a) => Lens' a a
relu' = lens'' $ \x -> (frelu x, brelu x) where
        frelu x | x > 0 = x
                | otherwise = 0
        brelu x dy | x > 0 = dy
                   | otherwise = 0
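As a quick check that composition with plain `(.)` really applies the chain rule, here is a self-contained sketch that restates `sin'` and `pow'` and compares the gradient of $\sin^2 x$ against both the analytic $2 \sin x \cos x$ and a central finite difference. `sinSquared` and `centralDiff` are hypothetical helpers, and the step size `1e-5` is an arbitrary choice.

```haskell
{-# LANGUAGE RankNTypes #-}
import Data.Functor.Const (Const (..))
import Data.Functor.Identity (Identity (..))

type Lens' a b = forall f. Functor f => (b -> f b) -> a -> f a

lens'' :: (a -> (b, b -> a)) -> Lens' a b
lens'' h g x = fmap j (g b) where (b, j) = h x

view :: Lens' a b -> a -> b
view l = getConst . l Const

grad :: Num b => Lens' a b -> a -> a
grad l x = runIdentity (l (const (Identity 1)) x)

sin' :: Floating a => Lens' a a
sin' = lens'' (\x -> (sin x, \dx -> dx * cos x))

pow' :: Num a => Integer -> Lens' a a
pow' n = lens'' (\x -> (x ^ n, \dx -> fromInteger n * dx * x ^ (n - 1)))

-- Van Laarhoven (.) reads in data-flow order: first sin, then square.
sinSquared :: Floating a => Lens' a a
sinSquared = sin' . pow' 2

-- Central finite difference, used only as an independent check.
centralDiff :: (Double -> Double) -> Double -> Double
centralDiff f x = (f (x + h) - f (x - h)) / (2 * h) where h = 1e-5
```

Note that `sin' . pow' 2` means "apply `sin'`, then square", the opposite of how ordinary function composition reads.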

Some List based stuff.

{-# LANGUAGE RankNTypes #-}
import Numeric.ADLens.Lens
import Data.List (sort)
import Control.Applicative (ZipList (..))

-- replicate and sum are dual!

sum' :: Num a => Lens' [a] a
sum' = lens'' $ \xs -> (sum xs, \dy -> replicate (length xs) dy)

replicate' :: Num a => Int -> Lens' a [a]
replicate' n = lens'' $ \x -> (replicate n x, sum)

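-- The backward sum only terminates when the incoming gradient list is finite,
-- e.g. when a downstream zip' has truncated it.
repeat' :: Num a => Lens' a [a]
repeat' = lens'' $ \x -> (repeat x, sum)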

map' :: Lens' a b -> Lens' [a] [b]
map' l = lens'' $ \xs -> let (bs, fs) = unzip . map (unlens'' l) $ xs in 
                       (bs, getZipList . ((ZipList fs) <*>) . ZipList)

zip' :: Lens' ([a], [b]) [(a,b)]
zip' = lens'' $ \(as,bs) -> (zip as bs, unzip)

unzip' :: Lens' [(a,b)] ([a], [b])
unzip' = lens'' $ \xs -> (unzip xs, uncurry zip)

maximum' :: (Num a, Ord a) => Lens' [a] a
maximum' = lens'' $ \(x:xs) -> let (best, bestind, lenxs) = argmaximum x 0 1 xs in
                               (best, \dy -> onehot bestind lenxs dy) where
    -- walk the list tracking the best value, its index, and the length so far
    argmaximum best bestind len [] = (best, bestind, len)
    argmaximum best bestind curind (y:ys)
      | y > best  = argmaximum y curind (curind + 1) ys
      | otherwise = argmaximum best bestind (curind + 1) ys
    -- list of length m that is x at index n and 0 elsewhere
    onehot n m x = [if i == n then x else 0 | i <- [0 .. m - 1]]

sort' :: Ord a => Lens' [a] [a]
sort' = lens'' $ \xs -> let (sxs, indices) = unzip . sort $ zip xs [0 ..] in
                        (sxs, desort indices) where
                          desort indices = snd . unzip . sort . zip indices
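A self-contained sketch of the backward passes of `maximum'` and `sort'`, written with `foldl1` and list comprehensions rather than hand-rolled recursion, assuming ties in `maximum'` go to the first maximal element. `backprop` is a hypothetical helper that runs the backward pass on a given output cotangent.

```haskell
{-# LANGUAGE RankNTypes #-}
import Data.Functor.Identity (Identity (..))
import Data.List (sort)

type Lens' a b = forall f. Functor f => (b -> f b) -> a -> f a

lens'' :: (a -> (b, b -> a)) -> Lens' a b
lens'' h g x = fmap j (g b) where (b, j) = h x

-- Run the backward pass on the output cotangent db.
backprop :: Lens' a b -> a -> b -> a
backprop l x db = runIdentity (l (const (Identity db)) x)

-- Gradient of maximum: one-hot at the (first) argmax.
maximum' :: (Num a, Ord a) => Lens' [a] a
maximum' = lens'' (\xs ->
  let (best, i) = foldl1 pick (zip xs [0 ..])
      pick (b, j) (y, k) = if y > b then (y, k) else (b, j)
  in (best, \dy -> [if k == i then dy else 0 | k <- [0 .. length xs - 1]]))

-- Gradient of sort: route each output cotangent back to its original slot.
sort' :: Ord a => Lens' [a] [a]
sort' = lens'' (\xs ->
  let (sxs, idx) = unzip (sort (zip xs [(0 :: Int) ..]))
      unsort ds = map snd (sort (zip idx ds))
  in (sxs, unsort))
```

For example, `backprop maximum' [3, 1, 2] 1` gives the one-hot `[1, 0, 0]`, and `backprop sort' [3, 1, 2] [10, 20, 30]` gives `[30, 10, 20]`: each cotangent travels back to the slot its value came from.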

And some functionality from HMatrix

{-# LANGUAGE RankNTypes, FlexibleContexts #-}
import Numeric.LinearAlgebra
import Numeric.LinearAlgebra.Devel (zipVectorWith)
import Numeric.ADLens.Lens
-- import Data.Vector as V

dot' :: (Container Vector t, Numeric t) => Lens' (Vector t, Vector t) t
dot' = lens'' $ \(v1,v2) -> (v1 <.> v2, \ds -> (scale ds v2, scale ds v1))

mdot' :: (Product t, Numeric t) => Lens' (Matrix t, Vector t) (Vector t)
mdot' = lens'' $ \(a,v) -> (a #> v, \dv -> (outer dv v, dv <# a))

add' :: Additive c => Lens' (c, c) c
add' = lens'' $ \(v1,v2) -> (add v1 v2, \dv -> (dv, dv))

-- The cotangent of a sum is broadcast to every entry with konst.
sumElements' :: (Container Vector t, Numeric t) => Lens' (Vector t) t
sumElements' = lens'' $ \v -> (sumElements v, \ds -> konst ds (size v))

reshape' :: Container Vector t => Int -> Lens' (Vector t) (Matrix t)
reshape' n = lens'' $ \v -> (reshape n v,  \dm -> flatten dm)

-- conjugate transpose not trace
tr'' ::  (Transposable m mt, Transposable mt m) => Lens' m mt
tr'' = lens'' $ \x -> (tr x, \dt -> tr dt)


flatten' :: (Num t, Container Vector t) => Lens' (Matrix t) (Vector t)
flatten' = lens'' $ \m -> let c = cols m in  -- reshape takes the number of columns
                          (flatten m,  \dm -> reshape c dm)


norm_2' :: (Container c R, Normed (c R), Linear R c) => Lens' (c R) R
-- d ||v|| / dv = v / ||v|| (undefined at v = 0)
norm_2' = lens'' $ \v -> let nv = norm_2 v in (nv, \dnv -> scale (dnv / nv) v )



cmap' :: (Element b, Container Vector e) => (Lens' e b) -> Lens' (Vector e) (Vector b)
cmap' l = lens'' $ \c -> (cmap f c, \dc -> zipVectorWith f' c dc) where
        f = view l
        f' = set l
 
{-
maxElement' :: Container c e => Lens' (c e) e
maxElement' = lens'' $ \v -> let i = maxIndex v in (v ! i, dv -> scalar 0)
-}

det' :: Field t => Lens' (Matrix t) t
det' = lens'' $ \m -> let (minv, (lndet, phase)) = invlndet m in
                    let detm = phase * exp lndet in
                    -- d det(A) / dA = det(A) * transpose(inv A); tr conjugates, so this is for real A
                    (detm, \ds -> scale (ds * detm) (tr minv))

diag' :: (Num a, Element a) => Lens' (Vector a) (Matrix a)
diag' = lens'' $ \v -> (diag v, takeDiag)

takeDiag' :: (Num a, Element a) => Lens' (Matrix a) (Vector a) 
takeDiag' = lens'' $ \m -> (takeDiag m, diag)
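For reference, these are the standard identities behind the backward passes of `sumElements'`, `norm_2'` and `det'`, for a vector $v$ and an invertible real matrix $A$. The cotangent of a sum is broadcast to every entry, the factor of 2 appears only for the squared norm, and the determinant's gradient uses the transposed inverse. The last line is how `det'` rebuilds the determinant from the log-magnitude $\ell$ and sign $s$ that come back with the inverse.

```latex
\[
\frac{\partial}{\partial v}\sum_i v_i = \mathbf{1},
\qquad
\frac{\partial \lVert v\rVert_2}{\partial v} = \frac{v}{\lVert v\rVert_2}\quad (v \neq 0),
\qquad
\frac{\partial \lVert v\rVert_2^2}{\partial v} = 2v,
\]
\[
\frac{\partial \det A}{\partial A} = \det(A)\, A^{-\mathsf T},
\qquad
\det A = s\, e^{\ell}, \quad \ell = \ln\lvert \det A \rvert,\ s = \operatorname{sign}(\det A).
\]
```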

In practice, I don’t think this is a very ergonomic approach without something like Conal Elliott’s Compiling to Categories plugin. You have to program in a point-free arrow style (inspired very directly by Conal’s above AD paper), which is pretty nasty IMO. The neural network code here is inscrutable, and it is only a three-layer neural network.

import Numeric.ADLens.Lens
import Numeric.ADLens.Basic
import Numeric.ADLens.List
import Numeric.ADLens.HMatrix


import Numeric.LinearAlgebra

type L1 = Matrix Double
type L2 = Matrix Double
type L3 = Matrix Double



type Input = Vector Double
type Output = Vector Double
type Weights = (L1,(L2,(L3,())))

class TupleSum a where
	tupsum :: a -> a -> a
instance TupleSum () where
	tupsum _ _ = ()
instance (Num a, TupleSum b) => TupleSum (a,b) where
	tupsum (a,x) (b,y) = (a + b, tupsum x y)

-- A dense relu neural network example
swaplayer :: Lens' ((Matrix t, b), Vector t) (b, (Matrix t, Vector t))
swaplayer = first' swap' . assoc' 

mmultlayer :: Numeric t => Lens' (b, (Matrix t, Vector t)) (b, Vector t)
mmultlayer = second' mdot'

relulayer :: Lens' (b, Vector Double) (b, Vector Double)
relulayer = second' $ cmap' relu'

uselayer :: Lens' ((Matrix Double, b), Vector Double) (b, Vector Double)
uselayer = swaplayer . mmultlayer . relulayer

runNetwork :: Lens' (Weights, Input) ((), Output)
runNetwork =  uselayer . uselayer . uselayer

main :: IO ()
main = do
   putStrLn "Starting Tests"
   print $ grad (pow' 2) 1
   print $ grad (pow' 4) 1
   print $ grad (map' (pow' 2) . sum') $ [1 .. 5]
   print $ grad (map' (pow' 4) . sum') $ [1 .. 5]
   print $ map (\x -> 4 * x ^ 3 )  [1 .. 5]
   l1 <- randn 3 4
   l2 <- randn 2 3
   l3 <- randn 1 2
   let weights = (l1,(l2,(l3,())))
   print $ view runNetwork (weights, vector [1,2,3,4])
   putStrLn "The neural network gradients"
   print $ set runNetwork (weights, vector [1,2,3,4]) ((), vector [1])
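A cheap way to sanity-check gradients like the ones printed above is to compare them against central finite differences. Here is a small sketch for functions on lists of `Double`; `fdGrad` and `maxAbsDiff` are hypothetical helpers, not part of the repo, and the step `1e-6` is an arbitrary choice.

```haskell
-- Central finite-difference gradient of f : [Double] -> Double.
fdGrad :: ([Double] -> Double) -> [Double] -> [Double]
fdGrad f xs = [ (f (bump i h) - f (bump i (-h))) / (2 * h) | i <- [0 .. length xs - 1] ]
  where
    h = 1e-6
    -- perturb only the i-th coordinate by d
    bump i d = [ if j == i then x + d else x | (j, x) <- zip [0 ..] xs ]

-- Largest absolute disagreement between two gradient vectors.
maxAbsDiff :: [Double] -> [Double] -> Double
maxAbsDiff a b = maximum (zipWith (\x y -> abs (x - y)) a b)
```

For instance, the `grad (map' (pow' 4) . sum') [1 .. 5]` line in `main` could be compared against `fdGrad (sum . map (^ 4)) [1 .. 5]`; the two should agree to several decimal places.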