Typeclasses, Automatic Differentiation and more

The following code typechecks:

						import Data.List (partition)

						quickSort :: [Int] -> [Int]
						quickSort []  = []
						quickSort (x:xs)  = quickSort l ++ [x] ++ quickSort r
						  where
						    (l, r) = partition (< x) xs
					
Does the following code typecheck?

						quickSort :: [a] -> [a]
						quickSort []  = []
						quickSort (x:xs)  = quickSort l ++ [x] ++ quickSort r
						  where
						    (l, r) = partition (< x) xs
					
Why?

Enter Typeclasses

						
							quickSort :: [a] -> [a]
							quickSort []  = []
							quickSort (x:xs)  = quickSort l ++ [x] ++ quickSort r
							  where
							    (l, r) = partition (< x) xs
						
  • a must be comparable, because (< x) compares elements of type a.
  • Int, Char, [Int] are comparable.
  • Int -> Int, a -> Bool are not comparable.
The types Int, Char and [Int] are instances of the typeclass Ord, so we can require a to be one too:

							quickSort :: (Ord a) => [a] -> [a]
							quickSort []  = []
							quickSort (x:xs)  = quickSort l ++ [x] ++ quickSort r
							  where
							    (l, r) = partition (< x) xs
						
typechecks!
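A complete, loadable version for experimenting in GHCi might look like this. It is a sketch: it adds the Data.List import and recursively sorts both partitions, and examples is a hypothetical name used only to show the one function at three different Ord types.

```haskell
import Data.List (partition)

-- Sorts any list whose element type has an Ord instance.
quickSort :: Ord a => [a] -> [a]
quickSort []     = []
quickSort (x:xs) = quickSort l ++ [x] ++ quickSort r
  where
    -- elements smaller than the pivot x go left, the rest go right
    (l, r) = partition (< x) xs

-- The same function used at three different Ord types.
examples :: ([Int], String, [[Int]])
examples =
  ( quickSort [3, 1, 2]
  , quickSort "typeclass"
  , quickSort [[2, 1], [1, 5], [1]]
  )
```

Here examples evaluates to ([1,2,3],"acelpssty",[[1],[1,5],[2,1]]); lists of Int are compared lexicographically, which is why [[Int]] is an Ord type too.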
Let's create a type for polynomials:

							newtype Polynomial = Poly [Int]
						

							*Main> Poly [1,2,3]
							<interactive>:10:1:
							    No instance for (Show Polynomial)
							      arising from a use of ‘print’
							    In the first argument of ‘print’, namely ‘it’
							    In a stmt of an interactive GHCi command: print it
							
We need to make Polynomial a member of Show:

							newtype Polynomial = Poly [Int]
								deriving Show
						

							*Main> Poly [1,2,3]
							Poly [1,2,3]
						

It works now, but the output is not very readable.

Let's define our own show:

							instance Show Polynomial where
							  show (Poly [])     = "0"
							  show (Poly (x:xs)) = show x ++ concatMap (" + " ++) terms
							    where
							      -- pair the i-th remaining coefficient with the power x^i
							      -- (no trailing " + " when there is only a constant term)
							      terms = zipWith (\a i -> show a ++ " x^" ++ show i) xs [1 :: Int ..]
						

							*Main Data.List> Poly [1..5]
							1 + 2 x^1 + 3 x^2 + 4 x^3 + 5 x^4
						
Our pretty printing works now!
Let's order them...

...make them a member of the typeclass Ord


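							instance Ord Polynomial where
							  (Poly xs) `compare` (Poly ys)
							    | length xs < length ys   = LT
							    | length xs > length ys   = GT
							    -- equal lengths: compare coefficients lexicographically,
							    -- starting from the constant term (safe on empty lists,
							    -- and consistent with the derived (==))
							    | otherwise               = compare xs ys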
							
But that does not compile:

									Prelude Data.List> :r
									[1 of 1] Compiling Main             ( randomStuff.hs, interpreted )

									randomStuff.hs:20:10:
									    No instance for (Eq Polynomial)
									      arising from the superclasses of an instance declaration
									    In the instance declaration for ‘Ord Polynomial’
									Failed, modules loaded: none.
								
  • The Ord class extends the class Eq.
  • The definition of the Ord typeclass looks something like this:
    
    									class  (Eq a) => Ord a  where
    									  compare               :: a -> a -> Ordering
    									  (<), (<=), (>=), (>)  :: a -> a -> Bool
    									  max, min              :: a -> a -> a
    								
  • This means every Ord instance must also be an Eq instance, so the Eq methods (==) and (/=) are always available alongside the Ord methods.
  • We often say that Ord is a subclass of Eq.

						newtype Polynomial = Poly [Int]
						  deriving Eq
						
... and that fixes our code.
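Assembled into one module, the Int polynomial so far might look like the sketch below. The Ord instance here is an assumption that compares by number of coefficients and then lexicographically (keeping compare consistent with the derived (==)), and sorted is a hypothetical example. Once compare exists, (<), max and Data.List's sort all work without further code.

```haskell
import Data.List (sort)

-- Coefficients in increasing powers of x: Poly [1,2,3] is 1 + 2x + 3x^2.
newtype Polynomial = Poly [Int]
  deriving Eq   -- needed because Eq is a superclass of Ord

instance Show Polynomial where
  show (Poly [])     = "0"
  show (Poly (x:xs)) = show x ++ concatMap (" + " ++) terms
    where
      terms = zipWith (\a i -> show a ++ " x^" ++ show i) xs [1 :: Int ..]

-- Assumed order: fewer coefficients first, ties broken lexicographically.
instance Ord Polynomial where
  compare (Poly xs) (Poly ys) = compare (length xs, xs) (length ys, ys)

-- sort only needs Ord, which Polynomial now has.
sorted :: [Polynomial]
sorted = sort [Poly [1, 2, 3], Poly [5], Poly [0, 1]]
```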
We want to do arithmetic on these polynomials...

Enter the Num Typeclass


							instance Num Polynomial where
							  (Poly fs) + (Poly [])         = Poly fs
							  (Poly []) + (Poly gs)         = Poly gs
							  (Poly (f:fs)) + (Poly (g:gs)) = Poly (f+g : rest)
							    where (Poly rest) = (Poly fs) + (Poly gs)
							  (Poly (f:fs)) * (Poly (g:gs)) = Poly (f*g : rest)
							    where (Poly rest) = (Poly [f])*(Poly gs) + (Poly fs)*(Poly (g:gs))
							  (Poly _) * (Poly _) = Poly []
							  fromInteger n       = Poly [fromInteger n]
							  abs _               = undefined  -- no sensible absolute value for a polynomial
							  signum (Poly xs)    = Poly (map signum xs)
							  negate (Poly xs)    = Poly (map (\x -> -x) xs)
						

							*Main> (Poly [1,4,5])+(Poly [1,2,1])
							2 + 6 x^1 + 6 x^2
						
addition works!

								*Main> (Poly [1,4,5])*(Poly [1,2,1])
								1 + 6 x^1 + 14 x^2 + 14 x^3 + 5 x^4
							
... so does multiplication

							*Main> (Poly [1,1])^5
							1 + 5 x^1 + 10 x^2 + 10 x^3 + 5 x^4 + 1 x^5
						
exponentiation comes for free!
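  • (^) is not a method of Num but a Prelude function, (^) :: (Num a, Integral b) => a -> b -> a, written in terms of (*), so we do not need to define it manually.
  • Num has default definitions of (-) in terms of negate and vice versa, so defining either one is enough.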
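The reason (^) comes for free is that exponentiation needs nothing beyond (*). A minimal sketch under a hypothetical name, power, using repeated squaring (GHC's own (^) is also based on repeated squaring, though its code differs):

```haskell
-- Hypothetical 'power': uses only (*) and the literal 1 (via fromInteger),
-- so it works for any Num instance (Integer, Double, Polynomial, ...).
power :: Num a => a -> Int -> a
power _ 0 = 1
power x n
  | n < 0     = error "power: negative exponent"
  | even n    = let h = power x (n `div` 2) in h * h   -- x^(2k)  = (x^k)^2
  | otherwise = x * power x (n - 1)                    -- x^(k+1) = x * x^k
```

Because the signature only asks for Num a, the same function works on the Polynomial type once it has a Num instance.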

							newtype Polynomial = Poly [Int]
							  deriving Eq
						
But why stick to Int polynomials?

							newtype Polynomial a = Poly [a]
								deriving Eq
						
We now have polynomials over arbitrary coefficient types.

							instance Show (Polynomial a)

							instance Ord (Polynomial a)

							instance Num (Polynomial a)
						
But how do you show, compare, or add arbitrary types?!
We want conditional instances instead:

							instance Show a => Show (Polynomial a)

							instance Ord a => Ord (Polynomial a)

							instance Num a => Num (Polynomial a)
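One possible way to fill in these conditional instances is sketched below. The Ord order (by length, then lexicographic), the error-raising abs and signum, and the helper names addCoeffs and mulCoeffs are assumptions rather than the only sensible choices; multiplication uses the identity \((f + xF) \cdot G = fG + x(FG)\).

```haskell
-- Coefficients in increasing powers of x, over any coefficient type a.
newtype Polynomial a = Poly [a]
  deriving Eq

-- Showable whenever the coefficients are showable.
instance Show a => Show (Polynomial a) where
  show (Poly [])     = "0"
  show (Poly (x:xs)) = show x ++ concatMap (" + " ++) terms
    where
      terms = zipWith (\c k -> show c ++ " x^" ++ show k) xs [1 :: Int ..]

-- Ordered whenever the coefficients are ordered.
instance Ord a => Ord (Polynomial a) where
  compare (Poly xs) (Poly ys) = compare (length xs, xs) (length ys, ys)

-- Arithmetic whenever the coefficients support arithmetic.
instance Num a => Num (Polynomial a) where
  Poly fs + Poly gs = Poly (addCoeffs fs gs)
  Poly fs * Poly gs = Poly (mulCoeffs fs gs)
  negate (Poly fs)  = Poly (map negate fs)
  fromInteger n     = Poly [fromInteger n]
  abs               = error "abs: not meaningful for polynomials"
  signum            = error "signum: not meaningful for polynomials"

-- Pointwise sum; the longer list's tail is kept as-is.
addCoeffs :: Num a => [a] -> [a] -> [a]
addCoeffs fs []         = fs
addCoeffs [] gs         = gs
addCoeffs (f:fs) (g:gs) = (f + g) : addCoeffs fs gs

-- (f + x*F) * G = f*G + x*(F*G)
mulCoeffs :: Num a => [a] -> [a] -> [a]
mulCoeffs [] _      = []
mulCoeffs _ []      = []
mulCoeffs (f:fs) gs = addCoeffs (map (f *) gs) (0 : mulCoeffs fs gs)
```

Since (-) is left to Num's default, p - q becomes p + negate q.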
						
A Word on Kinds

...the types of types

  • Polynomial Int is really a type: we can produce values which inhabit Polynomial Int.
    
    								*Main> :t (Poly [1,2,3]) :: Polynomial Int
    								(Poly [1,2,3]) :: Polynomial Int :: Polynomial Int
    							
  • Polynomial is not really a concrete type. It is a type constructor.
  • Concrete types are of the kind *
    
    									*Main> :k Polynomial Int
    									Polynomial Int :: *
    								
  • Type constructors that take one type argument, such as Polynomial, are of the kind * -> *
    
    									*Main> :k Polynomial
    									Polynomial :: * -> *
    								
  • Guess what the kind of (->) is?
    
    									*Main> :k (->)
    									(->) :: * -> * -> *
    								
  • Typeclasses take a type (or, often, a type constructor) and produce something of kind Constraint
    
    									*Main> :k Ord
    									Ord :: * -> GHC.Prim.Constraint
    								
There are typeclasses which take in a type constructor...

Meet Functor


								*Main> :k Functor
								Functor :: (* -> *) -> GHC.Prim.Constraint
							

The types over which you can fmap
  • Could be the usual lists
    
    									instance Functor [] where   -- already provided by base; shown for illustration
    									  fmap = map
    								
  • Or, say, binary trees
    
    									data Tree a = Leaf a | Fork (Tree a) (Tree a)
    
    									instance Functor Tree where
    										fmap f (Leaf x) = Leaf (f x)
    										fmap f (Fork l r) = Fork (fmap f l) (fmap f r)
    								
The Functor laws

							fmap id      ≡ id              -- identity law
							fmap (f . g) ≡ fmap f . fmap g -- composition law
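The laws can be spot-checked on a concrete value. A minimal sketch, reusing the Tree type above with Eq and Show derived so results can be compared; sample, identityLaw and compositionLaw are hypothetical names, and checking one value is evidence rather than proof (a property-testing library such as QuickCheck would try many values).

```haskell
-- The Tree from above, with Eq and Show derived for comparing results.
data Tree a = Leaf a | Fork (Tree a) (Tree a)
  deriving (Eq, Show)

instance Functor Tree where
  fmap f (Leaf x)   = Leaf (f x)
  fmap f (Fork l r) = Fork (fmap f l) (fmap f r)

sample :: Tree Int
sample = Fork (Leaf 1) (Fork (Leaf 2) (Leaf 3))

-- fmap id must leave the tree unchanged.
identityLaw :: Bool
identityLaw = fmap id sample == id sample

-- Mapping a composition equals composing two maps.
compositionLaw :: Bool
compositionLaw = fmap ((* 2) . (+ 1)) sample == (fmap (* 2) . fmap (+ 1)) sample
```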
						
Let's make Polynomial an instance of Functor too:

						instance Functor Polynomial where
						  fmap f (Poly xs)    = Poly (map f xs)
						
It works...

								*Main> fmap (+1) (Poly [1,2,3])
								2 + 3 x^1 + 4 x^2
							
Automatic Differentiation
Idea: carry the whole tower of derivatives around with each value

							data DX a = DX { val :: a, dx :: DX a }
						
The derivatives of constants are zero:

							instance Num n => Num (DX n) where
							  fromInteger x = DX (fromInteger x) 0
						
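...notice the circularity in the definition here: the 0 is itself a DX n produced by this same fromInteger, so a constant carries an infinite tower of zero derivatives, which laziness makes harmless.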
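To see the circularity concretely, here is a minimal sketch with only fromInteger implemented (the other methods just raise errors) and a hypothetical helper, tower, that lists a value followed by its derivatives:

```haskell
-- A value together with the (infinite) tower of its derivatives.
data DX a = DX { val :: a, dx :: DX a }

-- Only fromInteger is implemented; the other methods are stubs for this demo.
instance Num n => Num (DX n) where
  fromInteger x = DX (fromInteger x) 0   -- this 0 is fromInteger 0 :: DX n again
  (+)    = error "(+): not implemented in this demo"
  (*)    = error "(*): not implemented in this demo"
  abs    = error "abs: not implemented in this demo"
  signum = error "signum: not implemented in this demo"
  negate = error "negate: not implemented in this demo"

-- Hypothetical helper: the value followed by all of its derivatives.
tower :: DX a -> [a]
tower (DX x x') = x : tower x'
```

In GHCi, take 4 (tower (3 :: DX Int)) gives [3,0,0,0]: the tower is infinite, but laziness only builds as much of it as take asks for.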
A couple more lines for the numeric operations:

							instance Num n => Num (DX n) where
							  fromInteger x = DX (fromInteger x) 0
							  DX x₀ x' + DX y₀ y' = DX (x₀ + y₀) (x' + y')
							  DX x₀ x' - DX y₀ y' = DX (x₀ - y₀) (x' - y')
							  x@(DX x₀ x') * y@(DX y₀ y') = DX (x₀ * y₀) (x * y' + y * x')
							  signum (DX x₀ x') = DX (signum x₀) 0
							  abs x@(DX x₀ x') = DX (abs x₀) (signum x * x')
						
and the quotient rule:

							instance Fractional n => Fractional (DX n) where
							  fromRational n = DX (fromRational n) 0
							  x@(DX x₀ x') / y@(DX y₀ y') =
							    DX (x₀ / y₀) ((x' * y - x * y') / y ^ 2)
						
To check equality (and ordering), we compare only the values, ignoring the derivatives:

							instance Eq a => Eq (DX a) where
							  a == b = val a == val b
							instance Ord a => Ord (DX a) where
							  compare a b = compare (val a) (val b)
							
and see the first few derivatives:

							instance Show a => Show (DX a) where
								show (DX x (DX x' (DX x'' _))) = show [x, x', x'']
						
and the variable of differentiation, whose first derivative is 1:

							var x = DX x 1
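The instances so far can be gathered into one module, which is enough to reproduce the example that follows. This sketch uses ASCII names (x0 instead of x₀) and replaces the Show instance with a hypothetical helper, derivs n, which lists the value followed by its first n - 1 derivatives.

```haskell
-- A value together with the (infinite) tower of its derivatives.
data DX a = DX { val :: a, dx :: DX a }

instance Num n => Num (DX n) where
  fromInteger n = DX (fromInteger n) 0                         -- constants: derivative 0
  DX x0 x' + DX y0 y' = DX (x0 + y0) (x' + y')
  DX x0 x' - DX y0 y' = DX (x0 - y0) (x' - y')
  x@(DX x0 x') * y@(DX y0 y') = DX (x0 * y0) (x * y' + y * x')  -- product rule
  signum (DX x0 _) = DX (signum x0) 0
  abs x@(DX x0 x') = DX (abs x0) (signum x * x')

instance Fractional n => Fractional (DX n) where
  fromRational r = DX (fromRational r) 0
  x@(DX x0 x') / y@(DX y0 y') =
    DX (x0 / y0) ((x' * y - x * y') / y ^ 2)                   -- quotient rule

-- Compare only the values, ignoring the derivatives.
instance Eq a => Eq (DX a) where
  a == b = val a == val b

instance Ord a => Ord (DX a) where
  compare a b = compare (val a) (val b)

-- The variable of differentiation: its first derivative is 1.
var :: Num a => a -> DX a
var x = DX x 1

-- Hypothetical helper: the first n entries of [value, f', f'', ...].
derivs :: Int -> DX a -> [a]
derivs n = take n . go
  where
    go (DX v v') = v : go v'
```

For example, derivs 2 ((\x -> 1 / x) (var 2)) evaluates to [0.5,-0.25], matching \(\frac{d}{dx}\frac{1}{x} = -\frac{1}{x^2}\) at \(x = 2\).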
						
Let's find out the derivative of \(7 x ^2 + 3 x + 2\) at \(x = 5\)

							*Main> (\ x -> 7*x^2 + 3*x + 2) (var 5)
							[192,73,14]
						
Let's try differentiating this square root function, which refines a guess with Newton's update until guess^2 is within eps of x:

							mySqrt :: (Num a, Ord a, Fractional a) => a -> a -> a
							mySqrt eps x = go 1
							  where
							    go guess
							      | abs (guess^2 - x) < eps   = guess
							      | otherwise                 = go newGuess
							        where
							          newGuess  =   guess - (guess^2 - x)/(2*guess)
						
What is the value of \(\frac{d\sqrt{x}}{dx}\) at \(x = 2\)?

								*Main> mySqrt 0.001 (var 2)
								[1.4142156862745099,0.35356593617839294,-8.832952635110178e-2]
							
...which is close to the exact value \(\frac{1}{2 \sqrt{x}}\):

								*Main> (1/(2*(sqrt 2)))
								0.35355339059327373
							
and I can get that precision by tightening eps:

								*Main> mySqrt 0.0000000000001 (var 2)
								[1.4142135623730951,0.35355339059327373,-8.838834764831843e-2]
							
Newton's Method
\(x_{n+1} = x_{n} - \frac{f(x_n)}{f'(x_n)}\)
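As a worked example of the update, take \(f(x) = x^2 - 2\), so \(f'(x) = 2x\), and start from \(x_0 = 1\), the same function and starting point as the \(\sqrt{2}\) example in this section:

\[
\begin{aligned}
x_1 &= 1 - \frac{1^2 - 2}{2 \cdot 1} = \frac{3}{2} \\
x_2 &= \frac{3}{2} - \frac{(3/2)^2 - 2}{2 \cdot 3/2} = \frac{3}{2} - \frac{1/4}{3} = \frac{17}{12} \approx 1.41667 \\
x_3 &= \frac{17}{12} - \frac{(17/12)^2 - 2}{2 \cdot 17/12} = \frac{17}{12} - \frac{1/144}{17/6} = \frac{577}{408} \approx 1.414216
\end{aligned}
\]

compared with \(\sqrt{2} \approx 1.4142136\).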

							newtons eps f guess
							  | abs (f guess) < eps   = guess
							  | otherwise             = newtons eps f newGuess
							    where
							      newGuess              = guess - (x/x')
							      (DX x (DX x' _))     = f (var guess)
						
What type is this?
First try:

							newtons :: (Num a, Fractional a, Ord a) => a -> (a -> a) -> a -> a
						
doesn't work...

								randomStuff.hs:75:34:
						    Couldn't match expected type ‘a’ with actual type ‘DX a’
						      ‘a’ is a rigid type variable bound by
						          the type signature for
						            newtons :: (Num a, Fractional a) => a -> (a -> a) -> a -> a
						          at randomStuff.hs:69:12
						    Relevant bindings include
						      guess :: a (bound at randomStuff.hs:70:15)
						      f :: a -> a (bound at randomStuff.hs:70:13)
						      eps :: a (bound at randomStuff.hs:70:9)
						      newtons :: a -> (a -> a) -> a -> a (bound at randomStuff.hs:70:1)
						    In the first argument of ‘f’, namely ‘(var guess)’
						    In the expression: f (var guess)
							
Second try:

							newtons :: (Num a, Ord a, Fractional a) => a -> (DX a -> DX a) -> a -> a
						
...doesn't work either

							newtons eps f guess
								| abs (f guess) < eps   = guess
								| otherwise             = newtons eps f newGuess
									where
										newGuess              = guess - (x/x')
										(DX x (DX x' _))     = f (var guess)
						
  • f is applied to guess :: a (in f guess) and to var guess :: DX a (in f (var guess)), so it must work both as a -> a and as DX a -> DX a
  • Idea: newtons is not just polymorphic, but it also wants a polymorphic function as an argument
  • Meet Rank2Types

							{-# LANGUAGE Rank2Types #-}
							newtons :: (Num a, Ord a, Fractional a) => a -> (forall b. (Num b, Ord b, Fractional b) => b -> b) -> a -> a
						
...typechecks!
and we can find \(\sqrt{2}\) without any hassle at all!

							*Main> newtons 0.0000001 (\x -> x^2 - 2) 1
							1.4142135623746899
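Putting it together, here is a self-contained sketch of the Newton solver. It uses RankNTypes (which also accepts rank-2 types like this one), repeats the DX instances from earlier with ASCII names, and drops the redundant Num constraint, since Fractional already implies Num.

```haskell
{-# LANGUAGE RankNTypes #-}

data DX a = DX { val :: a, dx :: DX a }

instance Num n => Num (DX n) where
  fromInteger n = DX (fromInteger n) 0
  DX x0 x' + DX y0 y' = DX (x0 + y0) (x' + y')
  DX x0 x' - DX y0 y' = DX (x0 - y0) (x' - y')
  x@(DX x0 x') * y@(DX y0 y') = DX (x0 * y0) (x * y' + y * x')
  signum (DX x0 _) = DX (signum x0) 0
  abs x@(DX x0 x') = DX (abs x0) (signum x * x')

instance Fractional n => Fractional (DX n) where
  fromRational r = DX (fromRational r) 0
  x@(DX x0 x') / y@(DX y0 y') = DX (x0 / y0) ((x' * y - x * y') / y ^ 2)

instance Eq a => Eq (DX a) where
  a == b = val a == val b

instance Ord a => Ord (DX a) where
  compare a b = compare (val a) (val b)

var :: Num a => a -> DX a
var x = DX x 1

-- f must be polymorphic: it is used at type a (to test convergence)
-- and at type DX a (to get f and f' at the current guess).
newtons :: (Ord a, Fractional a)
        => a -> (forall b. (Ord b, Fractional b) => b -> b) -> a -> a
newtons eps f guess
  | abs (f guess) < eps = guess
  | otherwise           = newtons eps f newGuess
  where
    DX fx (DX fx' _) = f (var guess)   -- value and first derivative at guess
    newGuess         = guess - fx / fx'
```

The same call works for other roots, for instance newtons 1e-9 (\x -> x*x*x - 8) 1 converges to approximately 2.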
						
Symbolic Differentiation
Let's get a symbolic type working first...

							data Op = Plus | Minus | Times | Divide deriving Eq

							data Sym a = Lit a
							          | Var String
							          | Expr Op (Sym a) (Sym a) deriving Eq
						
Overloading the operations is easy...

							instance Num a => Num (Sym a) where
							 fromInteger = Lit . fromInteger
							 (+) = Expr Plus
							 (-) = Expr Minus
							 (*) = Expr Times

							instance Fractional a => Fractional (Sym a) where
							 fromRational = Lit . fromRational
							 (/) = Expr Divide
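To see what these instances build, here is a standalone sketch that derives Show (instead of using a custom pretty-printer) so the raw expression tree is visible; example is a hypothetical name, and abs and signum simply raise errors here.

```haskell
data Op = Plus | Minus | Times | Divide
  deriving (Eq, Show)

data Sym a = Lit a
           | Var String
           | Expr Op (Sym a) (Sym a)
  deriving (Eq, Show)

-- Arithmetic on Sym builds a syntax tree instead of computing a number.
instance Num a => Num (Sym a) where
  fromInteger = Lit . fromInteger
  (+)    = Expr Plus
  (-)    = Expr Minus
  (*)    = Expr Times
  abs    = error "abs: not supported symbolically"
  signum = error "signum: not supported symbolically"

instance Fractional a => Fractional (Sym a) where
  fromRational = Lit . fromRational
  (/)          = Expr Divide

-- Ordinary-looking code; the literals 2 and 3 become Lit nodes.
example :: Sym Int
example = (Var "x" + 2) * 3
```

Here example equals Expr Times (Expr Plus (Var "x") (Lit 2)) (Lit 3). Note that negate comes from Num's default definition in terms of (-), so negate (Var "y") becomes Expr Minus (Lit 0) (Var "y").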
						
We need some pretty printing...

							instance Show Op where
							  show Plus       = "+"
							  show Minus      = "-"
							  show Divide     = "/"
							  show Times      = "*"


							instance Show a => Show (Sym a) where
							  show (Var x)            = x
							  show (Lit n)            = show n
							  show (Expr op e1 e2)    = "(" ++ (show e1) ++ ")" ++ show op ++ "(" ++ (show e2) ++ ")"
						
What is the derivative of \(x^2 + 2 x + 7\)?

							*Main> val . dx $ (\x -> (x^2 + 2*x + 7)) (var (Var "x"))
							((((x)*(1))+((x)*(1)))+(((2)*(1))+((x)*(0))))+(0)
						

...really need some simplification!


							step :: (Eq a, Num a, Fractional a) => Sym a -> Sym a
							step (Expr op (Lit n₁) (Lit n₂)) =
							      Lit $ (case op of Plus -> (+); Minus -> (-); Times -> (*); Divide -> (/)) n₁ n₂
							step (Expr Plus
							        (Expr Times (Lit n₁) (Var x))
							        (Expr Times (Lit n₂) (Var y)))
							          | x == y = Expr Times (Expr Plus (Lit n₁) (Lit n₂)) (Var x)
							step (Expr Plus (Var x) (Var y))
							          | x == y = Expr Times (Lit 2) (Var x)
							step (Expr Plus 0 n)      = n
							step (Expr Plus n 0)      = n
							step (Expr Times 1 n)     = n
							step (Expr Times n 1)     = n
							step (Expr Times 0 n)     = 0
							step (Expr Times n 0)     = 0
							step (Expr op e₁ e₂)      = Expr op (step e₁) (step e₂)
							step (atom)               = atom

							simplify = until . iterate step
							  where until (x₁ : x₂ : xs) | x₁ == x₂  = x₁
							                             | otherwise = until (x₂ : xs)
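simplify is a fixed-point iteration: apply step until the expression stops changing. The same pattern in isolation, under the hypothetical name fixpoint, with a toy rewrite removeFirstZero (also hypothetical) that removes one zero per pass so that several passes are needed:

```haskell
-- Apply f until the result stops changing; Eq is needed to detect that.
fixpoint :: Eq a => (a -> a) -> a -> a
fixpoint f x
  | x' == x   = x
  | otherwise = fixpoint f x'
  where
    x' = f x

-- Toy one-pass rewrite: removes only the first 0 it finds.
removeFirstZero :: [Int] -> [Int]
removeFirstZero []       = []
removeFirstZero (0 : xs) = xs
removeFirstZero (x : xs) = x : removeFirstZero xs
```

Like simplify, this loops forever if the rewrite never settles. The local until in simplify also shadows the Prelude function until :: (a -> Bool) -> (a -> a) -> a -> a.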
						

							*Main> simplify . val . dx $ (\x -> (x^2 + 2*x + 7)) (var (Var "x"))
							((2.0)*(x))+(2.0)
						
Not so bad now!
How about the derivative of \(\frac{1}{x}\)?

								*Main> simplify . val . dx $ (\x -> (1/x)) (var (Var "x"))
								(-1.0)/((x)*(x))
							

Thank You!