Tensors
This post is part of a series of notes on machine learning.
This post is literate Haskell; you can load the source into GHCi and play along.
First some boilerplate.
{-# LANGUAGE LambdaCase #-}
module Tensors where
import Data.Array
import Data.Foldable
import Control.Applicative
import qualified Text.PrettyPrint as PP
import Test.QuickCheck
import Indices
import IndexIsos
Earlier, we defined two algebras whose elements represent the possible sizes of multidimensional arrays and possible indices into multidimensional arrays, respectively. We did this in such a way that the possible indices into an array with (vector space) dimension \(k\) can be mapped to \(\{0,1, \ldots, k-1\}\) in a canonical way. With this in hand, we can define a tensor of size \(s \in \mathbb{S}\) as a mapping from the indices of \(s\) to \(\mathbb{R}\). And thanks to the canonical mapping to integers, we can implement our tensors in memory using a linear array. In math notation, we will identify each \(s \in \mathbb{S}\) with its indices, and think of tensors as elements of \(\mathbb{R}^s\) (that is, functions from indices to real numbers).
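To make the canonical mapping to integers concrete, here is a minimal sketch of one way to flatten indices. The Size, Index, dimOf, flatten and indicesOf below are simplified, hypothetical stand-ins for the definitions in the earlier posts, which may differ. In particular, letting the left component of a product index vary fastest is an assumption, chosen to agree with the orderOf (3*3) output later in this post.

```haskell
-- Hypothetical stand-ins for the Size and Index algebras of the earlier
-- posts; the real Indices module may define these differently.
data Size = Size Integer | Size :+ Size | Size :* Size
  deriving (Eq, Show)

data Index = Index Integer | L Index | R Index | Index :& Index
  deriving (Eq, Show)

-- Vector space dimension of a size.
dimOf :: Size -> Integer
dimOf (Size k) = k
dimOf (u :+ v) = dimOf u + dimOf v
dimOf (u :* v) = dimOf u * dimOf v

-- Position of an index in the linear array backing a tensor.
-- Sums put all left indices before all right ones; products let the
-- left component vary fastest (an assumption for this sketch).
flatten :: Size -> Index -> Integer
flatten (Size _) (Index i) = i
flatten (u :+ _) (L i) = flatten u i
flatten (u :+ v) (R j) = dimOf u + flatten v j
flatten (u :* v) (i :& j) = flatten u i + dimOf u * flatten v j
flatten s t = error ("flatten: " ++ show t ++ " is not an index of " ++ show s)

-- All indices of a size, listed in flattened order.
indicesOf :: Size -> [Index]
indicesOf (Size k) = map Index [0 .. k - 1]
indicesOf (u :+ v) = map L (indicesOf u) ++ map R (indicesOf v)
indicesOf (u :* v) = [i :& j | j <- indicesOf v, i <- indicesOf u]
```

With this layout, map (flatten s) (indicesOf s) is exactly [0 .. dimOf s - 1], which is what lets a tensor live in a single linear array.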
We’ll say that two tensors are strictly equal, denoted $==, if they have the same sizes and the same entries at each index.
(Strict equality is too, well, strict. We’ll nail down the real equality on tensors in a moment.)
A tensor “is” a map from indices to \(\mathbb{R}\)s. The tensor function lets us build a tensor by supplying this map.
tensor :: Size -> (Index -> r) -> Tensor r
tensor s f = T s (array (0,(dimOf s)-1) entries)
  where
    entries = [(flatten s t, f t) | t <- indicesOf s]
To retrieve the entry of a tensor at a given index, we evaluate the tensor as a function. We’ll call this at. So in math notation, we’d write \(\mathsf{at}(A,i) = A(i)\) or \(A_i\).
We’re actually going to define two slightly different versions of at. The first, at', works only on nonzero sizes, but for all entry types. The second, at, treats the size zero vector as if it has entry 0 at every possible index, but of course only makes sense for numeric entry types. (Looking ahead, there’s a good reason for doing this, having to do with dual numbers and automatic differentiation.)
at' :: Tensor r -> Index -> r
at' (T s a) t = if t `isIndexOf` s
  then a ! (flatten s t)
  else error $ "at: incompatible index " ++ show t
    ++ " for size " ++ show s
at :: (Num r) => Tensor r -> Index -> r
at a t =
  if (size a) ~= 0
    then 0
    else a `at'` t
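So tensor and at obey the following identities: \(\mathsf{at}(\mathsf{tensor}(s,f), i) = f(i)\) for every size \(s\), every function \(f\) on the indices of \(s\), and every index \(i\) of \(s\); and \(\mathsf{tensor}(\mathsf{size}(A), \mathsf{at}(A,-)) = A\) for every tensor \(A\).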
We’ll also define some helper functions to make building tensors more convenient. For instance, a uniform tensor has the same value at each index.
uniform :: Size -> r -> Tensor r
uniform s x = tensor s (\_ -> x)
ones, zeros :: (Num r) => Size -> Tensor r
ones s = uniform s 1
zeros s = uniform s 0
We’ll use the notation \(\mathsf{Zero}_s\) to denote the zero tensor of size \(s\).
We can use at' and the canonical isomorphism on index sets to define equality for tensors.
instance (Eq r) => Eq (Tensor r) where
  a@(T u _) == b@(T v _) = if u ~= v
    then all (\i -> (a`at'`i) == (b`at'`(mapIndex u v i))) (indicesOf u)
    else False
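We’ll see the reason for this weak equality in a bit. But for now, note that two tensors can be equal without being strictly equal; for instance, in the identity-law computation for Applicative below, a tensor of size \(1 \otimes u\) turns out to be equal to one of size \(u\).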
More generally, strict equality implies equality, but not vice versa.
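A hedged sketch of the distinction, using a deliberately simplified model in which a tensor is just its size plus its entries in flattened order (FlatTensor, strictEq and weakEq are hypothetical names). Only for illustration, it assumes that sizes are equivalent exactly when their dimensions agree and that the canonical isomorphism pairs up indices with the same flattened position; the real ~= and mapIndex come from the earlier posts. Sizes 3 and \(1 \otimes 3\) make a safe test case, since the Applicative identity-law computation below relies on \(1 \otimes u\) and \(u\) being equivalent.

```haskell
-- Hypothetical flat model for this sketch only: a tensor is its size
-- together with its entries listed in canonical (flattened) order.
data Size = Size Integer | Size :+ Size | Size :* Size
  deriving (Eq, Show)

dimOf :: Size -> Integer
dimOf (Size k) = k
dimOf (u :+ v) = dimOf u + dimOf v
dimOf (u :* v) = dimOf u * dimOf v

data FlatTensor r = FlatTensor Size [r]

-- Strict equality: the same size and the same entries.
strictEq :: (Eq r) => FlatTensor r -> FlatTensor r -> Bool
strictEq (FlatTensor u xs) (FlatTensor v ys) = u == v && xs == ys

-- Weak equality, ASSUMING sizes count as equivalent when their
-- dimensions agree and that the canonical isomorphism pairs up
-- indices with the same flattened position.
weakEq :: (Eq r) => FlatTensor r -> FlatTensor r -> Bool
weakEq (FlatTensor u xs) (FlatTensor v ys) = dimOf u == dimOf v && xs == ys
```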
The simplest possible (nontrivial) tensor has size 1; we will call these cells.
We’ll also provide a simple way to construct vectors and matrices with natural number size.
vec :: [r] -> Tensor r
vec xs = tensor k (\(Index i) -> xs !! (fromIntegral i))
  where
    k = Size $ fromIntegral $ length xs
mat :: [[r]] -> Tensor r
mat [] = tensor 0 (const undefined)
mat [[]] = tensor 0 (const undefined)
mat xss = tensor (r :* c) $
  \((Index i) :& (Index j))
    -> (xss !! (fromIntegral i)) !! (fromIntegral j)
  where
    r = Size $ fromIntegral $ length xss
    c = Size $ fromIntegral $ length $ head xss
The downside of defining our tensors recursively is that it’s less clear what the index of a given entry is. To help out with this, we’ll define two helpers: indexOf, which builds a tensor of a given size whose entries are equal to their indices, and orderOf, which shows how the entries of a tensor are linearized internally.
indexOf :: Size -> Tensor Index
indexOf s = tensor s id
orderOf :: Size -> Tensor Integer
orderOf s = tensor s (flatten s)
This works because we can pass tensor any function on indices. For example, here are three different views of a size \(3 \otimes 3\) tensor.
$> ones (3*3)
1 1 1
1 1 1
1 1 1
$> indexOf (3*3)
(0,0) (0,1) (0,2)
(1,0) (1,1) (1,2)
(2,0) (2,1) (2,2)
$> orderOf (3*3)
0 3 6
1 4 7
2 5 8
Try using indexOf on more complicated sizes.
Tensor as a Functor
One of the first questions we ask about type constructors is whether they are naturally members of any interesting classes. It’s not too surprising that Tensor is a functor, where fmap is “pointwise” function application.
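To verify the functor laws, we check that fmap id == id and that fmap (g . f) == fmap g . fmap f. (Remember that $== means strict equality.) The identity law is a one-line computation: if a has size u, then fmap id a $== tensor u (\i -> id (a`at`i)) $== tensor u (\i -> a`at`i) $== a. For the composition law: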
fmap g (fmap f a)
$== fmap g (fmap f a@(T u _))
$== fmap g (tensor u (\i -> f (a`at`i)))
$== tensor u (\i -> g ((tensor u (\j -> f (a`at`j))) `at` i))
$== tensor u (\i -> g (f (a`at`i)))
$== tensor u (\i -> (g . f) (a`at`i))
$== fmap (g . f) a
We can also define a Foldable instance for tensors, using the canonical order on indices.
From here we can immediately take the sum and maximum of a tensor. We’ll also define a kind of zip for tensors of equivalent size; I had trouble finding a good general class for zippable functors in the libraries.
tzip :: Tensor a -> Tensor b -> Tensor (a,b)
tzip a@(T u _) b@(T v _) = if u ~= v
  then tensor u (\i -> (a`at'`i, b`at'`(mapIndex u v i)))
  else error "zip: tensors must have equivalent size"
tzipWith :: (a -> b -> c) -> Tensor a -> Tensor b -> Tensor c
tzipWith f a b = fmap (uncurry f) $ tzip a b
Tensor is also applicative. (Making this work is the main motivation for defining equality the way we did.)
instance Applicative Tensor where
  pure = cell
  a@(T u _) <*> b@(T v _) = tensor (u :* v) $
    \(i :& j) -> (a `at'` i) (b `at'` j)
We need to see that this implementation satisfies the applicative laws. First the identity law, pure id <*> a == a for all a:
pure id <*> a
$== cell id <*> a@(T u _)
$== (tensor 1 (const id)) <*> a@(T u _)
$== tensor (1 :* u) (\(i :& j) -> id (a`at`j))
$== tensor (1 :* u) (\(i :& j) -> a`at`j)
== tensor u (\j -> a`at`j)
$== a
Next we establish the composition law, pure (.) <*> a <*> b <*> c == a <*> (b <*> c):
pure (.) <*> a <*> b <*> c
$== tensor 1 (const (.)) <*> a@(T u _) <*> b <*> c
$== tensor (1 :* u) (\(i :& j) -> (.) (a`at`j))
<*> b@(T v _) <*> c
$== tensor ((1 :* u) :* v)
(\((i :& j) :& k) -> (a`at`j) . (b`at`k))
<*> c@(T w _)
$== tensor (((1 :* u) :* v) :* w)
(\(((i :& j) :& k) :& l) ->
(a`at`j) . (b`at`k) $ (c`at`l))
== tensor (u :* (v :* w))
(\(j :& (k :& l)) -> (a`at`j) $ (b`at`k) (c`at`l))
$== a <*> tensor (v :* w) (\(k :& l) -> (b`at`k) (c`at`l))
$== a <*> (b <*> c)
The homomorphism law:
pure f <*> pure x == pure (f x)
pure f <*> pure x
$== tensor 1 (const f) <*> tensor 1 (const x)
$== tensor (1 :* 1) (\(i :& j) -> f x)
== tensor 1 (\_ -> f x)
$== pure (f x)
And the interchange law:
a <*> pure x == pure ($ x) <*> a
a <*> pure x
$== a@(T u _) <*> tensor 1 (const x)
$== tensor (u :* 1) (\(i :& j) -> (a`at`i) x)
== tensor (1 :* u) (\(j :& i) -> ($ x) (a`at`i))
$== tensor 1 (const ($ x)) <*> a@(T u _)
$== pure ($ x) <*> a
It may seem like overkill to go to the trouble of defining equality the way we did just to make Tensor an applicative functor, and it is; we won’t need the applicativeness much. But there’s a payoff: the outer product of tensors is defined in terms of <*>.
While we’re at it, Tensor is also Alternative.
instance Alternative Tensor where
  empty = tensor 0 (\_ -> undefined)
  a@(T u _) <|> b@(T v _) = tensor (u :+ v) $
    \case
      L i -> a `at'` i
      R j -> b `at'` j
We should also verify the Alternative laws. First the monoid laws that everyone agrees Alternatives should satisfy. Left identity:
empty <|> a == a
empty <|> a
$== tensor 0 (const undefined) <|> a@(T u _)
$== tensor (0 :+ u) (\case L i -> undefined; R j -> a`at`j)
== tensor u (\j -> a`at`j)
$== a
Right identity:
a <|> empty == a
a <|> empty
$== a@(T u _) <|> tensor 0 (const undefined)
$== tensor (u :+ 0) (\case L i -> a`at`i; R j -> undefined)
== tensor u (\i -> a`at`i)
$== a
Associativity:
(a <|> b) <|> c == a <|> (b <|> c)
(a <|> b) <|> c
$== (a@(T u _) <|> b@(T v _)) <|> c
$== (tensor (u :+ v)
(\case
L i -> a`at`i
R j -> b`at`j
)) <|> c@(T w _)
$== tensor ((u :+ v) :+ w)
(\case
L l -> case l of
L i -> a`at`i
R j -> b`at`j
R k -> c`at`k)
$== tensor ((u :+ v) :+ w)
(\case
L (L i) -> a`at`i
L (R j) -> b`at`j
R k -> c`at`k)
== tensor (u :+ (v :+ w))
(\case
L i -> a`at`i
R (L j) -> b`at`j
R (R k) -> c`at`k)
$== tensor (u :+ (v :+ w))
(\case
L i -> a`at`i
R l -> case l of
L j -> b`at`j
R k -> c`at`k)
$== a <|> tensor (v :+ w)
(\case
L j -> b`at`j
R m -> c`at`m)
$== a <|> (b <|> c)
And some of the laws that only hold for some Alternative instances (including this one). Left zero:
empty <*> a == empty
empty <*> a
$== tensor 0 (const undefined) <*> a@(T u _)
$== tensor (0 :* u) (\(i :& j) -> undefined (a `at` j))
== tensor 0 (\_ -> undefined)
$== empty
Right zero:
a <*> empty == empty
a <*> empty
$== a@(T u _) <*> tensor 0 (const undefined)
$== tensor (u :* 0) (\(i :& j) -> (a`at`i) undefined)
== tensor 0 (\_ -> undefined)
$== empty
Vector Arithmetic
Tensors are vectors, so they should have the usual vector operations of plus, negate, and scale. Other vector spaces will show up later, so we’ll define these operations with a class.
class Vector t where
  (.@) :: (Num r) => r -> t r -> t r
  (.+) :: (Num r) => t r -> t r -> t r
  neg :: (Num r) => t r -> t r
  (.-) :: (Num r) => t r -> t r -> t r
  a .- b = a .+ (neg b)
  vsum :: (Num r) => [t r] -> t r
  vsum = foldr1 (.+)
instance Vector Tensor where
  r .@ a = fmap (r*) a
  a .+ b
    | size a ~= 0 = b
    | size b ~= 0 = a
    | otherwise = tzipWith (+) a b
  neg = fmap negate
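The point of making this a class is that other types can share the same operations. As a sketch, here is the class reproduced so the example stands alone, together with an instance for a hypothetical pair type V2 that is not part of this series' code.

```haskell
-- The Vector class from this post, reproduced so the sketch stands alone.
class Vector t where
  (.@) :: (Num r) => r -> t r -> t r
  (.+) :: (Num r) => t r -> t r -> t r
  neg :: (Num r) => t r -> t r
  (.-) :: (Num r) => t r -> t r -> t r
  a .- b = a .+ neg b
  vsum :: (Num r) => [t r] -> t r
  vsum = foldr1 (.+)

-- A hypothetical second vector space (not from this series): pairs.
data V2 r = V2 r r
  deriving (Eq, Show)

instance Vector V2 where
  r .@ V2 x y = V2 (r * x) (r * y)
  V2 a b .+ V2 c d = V2 (a + c) (b + d)
  neg (V2 x y) = V2 (negate x) (negate y)
```

Since vsum is foldr1 (.+), it is only defined on nonempty lists.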
The Hadamard or entrywise product is also handy. While we’re at it, we also define entrywise quotients.
(.*) :: (Num r) => Tensor r -> Tensor r -> Tensor r
(.*) = tzipWith (*)
(./) :: (Fractional r) => Tensor r -> Tensor r -> Tensor r
(./) = tzipWith (/)
Thinking of tensors as vectors, we can dot them together in the usual way.
In math notation, if \(A,B \in \mathbb{R}^s\), \[\mathsf{dot}(A,B) = \sum_{i \in s} A_i B_i.\] The ‘dot square’ of a tensor will also be handy later.
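The dot product is the sum of the entrywise product, that is, the sum of the entries of A .* B. A minimal sketch of the formula on plain lists, which stand in for the entries of two tensors of equivalent size listed in the same canonical order; dot and dotSquare here are illustrative definitions, not the post's code.

```haskell
-- Sketch only: plain lists stand in for the entries of two tensors of
-- equivalent size, listed in the same canonical order.
dot :: (Num r) => [r] -> [r] -> r
dot xs ys
  | length xs == length ys = sum (zipWith (*) xs ys)
  | otherwise = error "dot: sizes must match"

-- The 'dot square' of a vector: its dot product with itself.
-- (dotSquare is a hypothetical name for this sketch.)
dotSquare :: (Num r) => [r] -> r
dotSquare xs = dot xs xs
```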
We also have some tensor-centric operations. First is oplus, which constructs a tensor with sum shape.
In a rough and handwavy way, if \(a \in \mathbb{R}^u\) and \(b \in \mathbb{R}^v\), then \[a \oplus b \in \mathbb{R}^u \oplus \mathbb{R}^v \cong \mathbb{R}^{u \oplus v},\] and \(\oplus\) is the operator that achieves this isomorphism.
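If the canonical flattening places every left index of \(u \oplus v\) before every right one (an assumption of this sketch), then on the underlying linear arrays \(\oplus\) is plain concatenation. The Flat, fromList and oplusFlat names below are hypothetical.

```haskell
import Data.Array

-- Hypothetical linear storage of a tensor: entries at 0 .. dim-1.
type Flat r = Array Integer r

fromList :: [r] -> Flat r
fromList xs = listArray (0, fromIntegral (length xs) - 1) xs

-- Sketch of oplus on linear storage, ASSUMING the flattening places
-- L i at position i and R j at position (dim u + j).
oplusFlat :: Flat r -> Flat r -> Flat r
oplusFlat a b = fromList (elems a ++ elems b)
```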
The function otimes, which constructs a tensor with product shape, is called the dyadic or outer product.
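Entrywise, \((a \otimes b)_{i \& j} = a_i b_j\). With the Applicative instance above, (*) <$> a <*> b has exactly this entry at i :& j. As a standalone sketch, here is the same construction on Data.Array, with a pair index (i, j) standing in for i :& j; outer is a hypothetical name.

```haskell
import Data.Array

-- Sketch of the outer product on Data.Array, standing in for Tensor:
-- the pair index (i, j) plays the role of i :& j, and the entry there
-- is a_i * b_j.
outer :: (Num r) => Array Int r -> Array Int r -> Array (Int, Int) r
outer a b = array ((ilo, jlo), (ihi, jhi))
  [ ((i, j), (a ! i) * (b ! j)) | i <- range (ilo, ihi), j <- range (jlo, jhi) ]
  where
    (ilo, ihi) = bounds a
    (jlo, jhi) = bounds b
```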
Structural Arithmetic
Now we’ll define some structural operators on tensors; these are functions that manipulate the size of a tensor, combine tensors into more complicated ones, or extract subparts. These are mostly based on extract, which defines a new tensor in terms of an existing one.
extract :: Size -> (Index -> Index) -> Tensor r -> Tensor r
extract u f a = tensor u (\i -> a `at'` (f i))
For example, we can extract “terms” from a summand tensor using extract, like so.
termL, termR :: Tensor r -> Tensor r
termL a@(T (u :+ _) _) = extract u L a
termL _ = error "termL: argument must have sum shape"
termR a@(T (_ :+ v) _) = extract v R a
termR _ = error "termR: argument must have sum shape"
In math notation we have \(\mathsf{termL} : \mathbb{R}^{s \oplus t} \rightarrow \mathbb{R}^s\) given by \(\mathsf{termL}(A)_i = A_{\mathsf{l}(i)}\), and \(\mathsf{termR}\) is similar.
Next we have projection operators, which take a tensor in \(\mathbb{R}^{s \otimes t}\) and fix one of the index components. In the usual matrix language, projection would extract one row or one column of a matrix. There are two of these, with the following signature.
projR, projL :: Index -> Tensor r -> Tensor r
projR i a@(T (u :* v) _) = if (i `isIndexOf` u)
  then extract v (i :&) a
  else error "projR: index and size not compatible."
projR _ _ = error "projR: tensor argument must have product shape"
projL j a@(T (u :* v) _) = if (j `isIndexOf` v)
  then extract u (:& j) a
  else error "projL: index and size not compatible."
projL _ _ = error "projL: tensor argument must have product shape"
In math notation we have \(\mathsf{projR} : s \rightarrow \mathbb{R}^{s \otimes t} \rightarrow \mathbb{R}^t\) given by \(\mathsf{projR}(i,A)_j = A_{i \& j}\), and \(\mathsf{projL}\) is similar.
Now \(\mathbb{R}^{u \otimes v}\) and \(\mathbb{R}^{v \otimes u}\) are not equal, but they are canonically isomorphic; likewise \(\mathbb{R}^{u \oplus v}\) and \(\mathbb{R}^{v \oplus u}\). The function comm achieves this.
comm :: Tensor r -> Tensor r
comm a@(T (u :* v) _) =
  extract (v :* u) f a
  where
    f (j :& i) = (i :& j)
comm a@(T (u :+ v) _) =
  extract (v :+ u) (opIndex PlusComm) a
comm _ = error "comm: wrong shape"
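The extract pattern has a close relative in the standard library: Data.Array's ixmap also takes new bounds, a function from new indices to old ones, and a source array. Here is a sketch on plain two-dimensional arrays (transpose and row are illustrative names): comm on a product-shaped tensor plays the role of transpose, and projR plays the role of row.

```haskell
import Data.Array

-- ixmap has the same shape as extract: new bounds, a map from new
-- indices to old ones, and the source array.

-- Like comm on a product-shaped tensor: swap the index components.
transpose :: Array (Int, Int) r -> Array (Int, Int) r
transpose a = ixmap ((clo, rlo), (chi, rhi)) (\(j, i) -> (i, j)) a
  where
    ((rlo, clo), (rhi, chi)) = bounds a

-- Like projR: fix the left index component (a row of a matrix).
row :: Int -> Array (Int, Int) r -> Array Int r
row i a = ixmap (clo, chi) (\j -> (i, j)) a
  where
    ((_, clo), (_, chi)) = bounds a
```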
Similarly, \(\mathbb{R}^{u \otimes (v \otimes w)}\) and \(\mathbb{R}^{(u \otimes v) \otimes w}\) are canonically isomorphic, and likewise for \(\oplus\).
assocL, assocR :: Tensor r -> Tensor r
assocL a@(T (u :* (v :* w)) _) =
  extract ((u :* v) :* w) (opIndex TimesAssocR) a
assocL a@(T (u :+ (v :+ w)) _) =
  extract ((u :+ v) :+ w) (opIndex PlusAssocR) a
assocL _ = error "assocL: argument has wrong shape"
assocR a@(T ((u :* v) :* w) _) =
  extract (u :* (v :* w)) (opIndex TimesAssocL) a
assocR a@(T ((u :+ v) :+ w) _) =
  extract (u :+ (v :+ w)) (opIndex PlusAssocL) a
assocR _ = error "assocR: argument has wrong shape"
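We also have \[\mathbb{R}^{a \otimes (b \oplus c)} \cong \mathbb{R}^{(a \otimes b) \oplus (a \otimes c)} \cong \mathbb{R}^{a \otimes b} \times \mathbb{R}^{a \otimes c},\] and likewise \(\mathbb{R}^{(a \oplus b) \otimes c} \cong \mathbb{R}^{a \otimes c} \times \mathbb{R}^{b \otimes c}\). We’ll define a couple of operators to canonically “undistribute” \(\otimes\) over \(\oplus\).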
vcat, (~-~) :: Tensor r -> Tensor r -> Tensor r
vcat a@(T (u :* h) _) b@(T (v :* k) _) =
  if h == k
    then extract ((u :+ v) :* k) (opIndex DistR) (oplus a b)
    else error "vcat: size mismatch"
vcat a b = error $ "vcat: sizes with wrong shape: " ++ show (size a)
  ++ " and " ++ show (size b)
(~-~) = vcat
hcat, (~|~) :: Tensor r -> Tensor r -> Tensor r
hcat a@(T (h :* u) _) b@(T (k :* v) _) =
  if h == k
    then extract (k :* (u :+ v)) (opIndex DistL) (oplus a b)
    else error "hcat: size mismatch"
hcat a b = error $ "hcat: sizes with wrong shape: " ++ show (size a)
  ++ " and " ++ show (size b)
(~|~) = hcat
We give vcat and hcat symbolic synonyms, meant to evoke what they do on matrices: vcat concatenates matrices vertically, and hcat concatenates them horizontally.
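For intuition, here is the same pair of operations on matrices represented as lists of rows, a hypothetical representation that is not the post's Tensor. As with vcat and hcat, stacking vertically requires matching column counts and joining horizontally requires matching row counts.

```haskell
-- Sketch with matrices as lists of rows (a hypothetical representation,
-- not the post's Tensor). The size checks mirror the h == k tests.
vcatL :: [[r]] -> [[r]] -> [[r]]
vcatL a b
  | numCols a == numCols b = a ++ b
  | otherwise = error "vcatL: column counts differ"

hcatL :: [[r]] -> [[r]] -> [[r]]
hcatL a b
  | length a == length b = zipWith (++) a b
  | otherwise = error "hcatL: row counts differ"

-- Number of columns, read off the first row (0 for no rows).
numCols :: [[r]] -> Int
numCols [] = 0
numCols (firstRow : _) = length firstRow
```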
Matrix Operations
Now for a couple of matrix-specific operations. First the identity matrix.
Here kronecker represents the Kronecker delta function \[\delta_{i,j} = \left\{ \begin{array}{ll} 1 & \mathrm{if}\ i = j \\ 0 & \mathrm{otherwise}. \end{array} \right.\]
And we can “diagonalize” any tensor.
diag :: (Num r) => Tensor r -> Tensor r
diag a@(T u _) = tensor (u :* u) $
  \(i :& j) -> (kronecker i j) * (a`at`i)
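A standalone sketch of the Kronecker delta, the identity matrix, and diag on Data.Array (kronecker, identity and diagA here are illustrative definitions, not the post's code): the entry at (i, j) is \(\delta_{i,j}\) for the identity and \(\delta_{i,j} a_i\) for the diagonalized vector.

```haskell
import Data.Array

-- Kronecker delta: 1 on equal indices, 0 otherwise.
kronecker :: (Eq i, Num r) => i -> i -> r
kronecker i j = if i == j then 1 else 0

-- n x n identity matrix on Data.Array (a stand-in for Tensor).
identity :: (Num r) => Int -> Array (Int, Int) r
identity n = array ((0, 0), (n - 1, n - 1))
  [ ((i, j), kronecker i j) | i <- [0 .. n - 1], j <- [0 .. n - 1] ]

-- Analogue of diag: entry (i, j) is delta_{i,j} * a_i.
diagA :: (Num r) => Array Int r -> Array (Int, Int) r
diagA a = array ((lo, lo), (hi, hi))
  [ ((i, j), kronecker i j * (a ! i)) | i <- range (lo, hi), j <- range (lo, hi) ]
  where
    (lo, hi) = bounds a
```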
The tensor generalization of matrix multiplication is sometimes called contraction. We’ll mostly be interested in plain matrix multiplication. We’ll define it as a matrix-matrix operation, a matrix-vector operation, and a vector-matrix operation using slightly different symbols. Surely this won’t come back to haunt us.
(***) :: (Num r) => Tensor r -> Tensor r -> Tensor r
a@(T (m :* n) _) *** b@(T (u :* v) _) =
  if u == n
    then tensor (m*v)
      (\ (i :& j) -> sum
        [ (a`at`(i :& k))*(b`at`(k :& j)) | k <- indicesOf n ])
    else error "(***): inner sizes must match."
_ *** _ = error "(***): expected mat/mat."
(**>) :: (Num r) => Tensor r -> Tensor r -> Tensor r
a@(T (m :* n) _) **> b@(T u _) =
  if u == n
    then tensor m
      (\i -> sum
        [ (a`at`(i :& k))*(b`at`k) | k <- indicesOf n ])
    else error "(**>): inner sizes must match."
_ **> _ = error "(**>): expected mat/vec."
(<**) :: (Num r) => Tensor r -> Tensor r -> Tensor r
a@(T u _) <** b@(T (n :* m) _) =
  if u == n
    then tensor m
      (\i -> sum
        [ (a`at`k)*(b`at`(k :& i)) | k <- indicesOf n ])
    else error "(<**): inner sizes must match."
_ <** _ = error "(<**): expected vec/mat."
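To check the index bookkeeping, here is a sketch of (***) and (**>) on Data.Array, where Matrix, mmul and mvec are hypothetical names. As in the tensor versions, the inner index ranges must agree and each entry sums over the shared index k; (<**) would be the mirror image of mvec.

```haskell
import Data.Array

type Matrix r = Array (Int, Int) r

-- Analogue of (***): entry (i, j) sums a(i, k) * b(k, j) over the
-- shared index k.
mmul :: (Num r) => Matrix r -> Matrix r -> Matrix r
mmul a b
  | (klo, khi) /= (klo', khi') = error "mmul: inner sizes must match."
  | otherwise = array ((ilo, jlo), (ihi, jhi))
      [ ((i, j), sum [ a ! (i, k) * b ! (k, j) | k <- range (klo, khi) ])
      | i <- range (ilo, ihi), j <- range (jlo, jhi) ]
  where
    ((ilo, klo), (ihi, khi)) = bounds a
    ((klo', jlo), (khi', jhi)) = bounds b

-- Analogue of (**>): matrix times vector, summing over k.
mvec :: (Num r) => Matrix r -> Array Int r -> Array Int r
mvec a x
  | (klo, khi) /= bounds x = error "mvec: inner sizes must match."
  | otherwise = listArray (ilo, ihi)
      [ sum [ a ! (i, k) * x ! k | k <- range (klo, khi) ] | i <- range (ilo, ihi) ]
  where
    ((ilo, klo), (ihi, khi)) = bounds a
```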
Pretty Printing
We’ll end this post with the Show instance for tensors; we’ll build it on top of the pretty printing combinator library by John Hughes and Simon Peyton Jones. (The original paper on that library is a nice case study in DSL design.)
First we convert a tensor of strings to a Doc (in the pretty printer parlance), doing more or less the obvious thing.
toDoc :: Tensor String -> PP.Doc
toDoc a@(T s _) =
  case s of
    Size k -> PP.hsep $ map PP.text [ a`at'`i | i <- indicesOf s ]
    u :+ v -> (toDoc $ termL a) PP.$$ (toDoc $ termR a)
    u :* v -> PP.vcat [
      PP.hsep [
        PP.text $ a `at'` (i :& j)
        | j <- indicesOf v ]
      | i <- indicesOf u ]
To actually show the tensor, we show the entries (pointwise) and pad to the maximum entry width (so the cells line up), then show the corresponding Doc.
instance (Show r) => Show (Tensor r) where
  show a =
    let
      -- width of the widest entry; the 0 keeps maximum total on empty tensors
      cellWidth = maximum $ 0 : map (length . show) (toList a)
      m = fmap (padLeft cellWidth . show) a
    in
      PP.render $ toDoc m
    where
      -- left-pad a string with spaces to a given length
      padLeft :: Int -> String -> String
      padLeft k = reverse . take k . (++ (repeat ' ')) . reverse
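Padding matters because hsep separates cells by a single space, so unpadded entries of different widths would not line up in columns. A small standalone sketch using the same pretty library, where renderRows is a hypothetical helper that renders a list of rows the way toDoc renders a product-shaped tensor, and padLeft is copied from the instance above:

```haskell
import qualified Text.PrettyPrint as PP

-- padLeft as in the Show instance: left-pad with spaces to width k.
padLeft :: Int -> String -> String
padLeft k = reverse . take k . (++ repeat ' ') . reverse

-- Render rows like toDoc renders a product-shaped tensor: one hsep per
-- row, rows stacked with vcat, every cell padded to the widest entry.
renderRows :: (Show r) => [[r]] -> String
renderRows rows = PP.render $ PP.vcat
  [ PP.hsep [ PP.text (padLeft w (show x)) | x <- row ] | row <- rows ]
  where
    w = maximum (0 : [ length (show x) | row <- rows, x <- row ])
```

Note that padLeft keeps only the last k characters of a string longer than k, which is why the Show instance pads to the widest entry rather than to a fixed width.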
This method for displaying tensors is not perfect, but it has the advantage of being simple and doing mostly the right thing in the most common cases of \(k\) and \(m \otimes n\) tensors (i.e. vectors and matrices). Apropos of nothing: further support for this method is that tensors with shape \(k_1 \oplus k_2 \oplus \cdots \oplus k_n\) look like Young tableaux.
Tests
In future posts we’ll be writing tests involving tensors, so I’ll put an Arbitrary instance here.
instance (Arbitrary r) => Arbitrary (Tensor r) where
  arbitrary = arbitrary >>= (arbTensorOf undefined)
  shrink a@(T u _) = case u of
    Size k ->
      if k <= 0
        then []
        else
          [ tensor (Size $ k-1) (\i -> a`at'`i)
          , uniform (Size $ k-1) (a`at'`0)
          ]
    _ :+ _ -> concat
      [ [ h <|> k | h <- shrink $ termL a, k <- shrink $ termR a ]
      , [ termL a, termR a ]
      ]
    _ -> []
arbTensorOf :: (Arbitrary r) => r -> Size -> Gen (Tensor r)
arbTensorOf _ s = do
  as <- vectorOf (fromIntegral $ dimOf s) arbitrary
  return $ tensor s (\i -> as !! (fromIntegral $ flatten s i))
arbBinaryTensorOf :: (Arbitrary r, Num r) => r -> Size -> Gen (Tensor r)
arbBinaryTensorOf _ s = do
  as <- vectorOf (fromIntegral $ dimOf s) $ elements [0,1]
  return $ tensor s (\i -> as !! (fromIntegral $ flatten s i))