Adapt implementation of ONNX operators to code from OnnxToFeld. #511

Open · wants to merge 1 commit into base: master
87 changes: 51 additions & 36 deletions src/Feldspar/Onnx/Operators.hs
@@ -194,29 +194,29 @@ bcZipWith f xs ys = zipWith f (uniBCast ext xs) (uniBCast ext ys)
where ext = unionExt (extent xs) (extent ys)
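
bcZipWith relies on unionExt to compute the broadcast extent of the two argument shapes. As a standalone sketch of that shape union (plain Int lists standing in for Feldspar shapes; unionShape is an illustrative name, not part of this module), the ONNX/NumPy-style rule aligns shapes from the right and takes the larger extent in each dimension:

```haskell
-- Illustrative model of the broadcast shape union: align from the right,
-- take the larger extent per dimension (an extent of 1 broadcasts), and
-- keep the leading dimensions of the longer shape.
unionShape :: [Int] -> [Int] -> [Int]
unionShape xs ys = reverse (go (reverse xs) (reverse ys))
  where
    go (a:as) (b:bs) = max a b : go as bs
    go as     []     = as
    go []     bs     = bs
```
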

-- | Implementation of ONNX tensor addition
onnxAdd :: (Num a, ShapelyU sh1 sh2)
=> Attrs -> Pull sh1 a -> Pull sh2 a -> Pull (UnionShape sh1 sh2) a
onnxAdd _ = bcAdd
onnxAdd_2 :: (Num a, ShapelyU sh1 sh2)
Member:
It would be nice for the handwritten code to be readable, and these are binary operators; do we really need the _2 on them? Same question for the _5 on onnxBatchNormalization, the _3 on onnxGemm, the _1 on onnxFlatten, and so on.

Contributor Author:
The arity information is necessary because some ONNX operators can be used with different numbers of arguments. To avoid it, onnxToFeld needs to know about those operators and treat them as special cases, something I have tried to avoid, at least for now.

Member:
Do we get a silent failure if we get the special cases wrong? If not, I would prefer an Operators.hs without extra suffixes and two additional lines (| op `elem` varargOps = "" and | otherwise = "_" <> show (length args)) in OnnxToFeld.hs.
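
For concreteness, the naming rule proposed in this thread could be sketched as below. The names varargOps and mkSuffix (and the example list entry) are illustrative, not from OnnxToFeld; only the two guards mirror the comment above.

```haskell
-- Sketch of the proposed suffix rule for generated operator names:
-- operators in `varargOps` get no arity suffix, all others are
-- suffixed with their argument count.
varargOps :: [String]
varargOps = ["Conv"]  -- example entry only

mkSuffix :: String -> [String] -> String
mkSuffix op args
  | op `elem` varargOps = ""
  | otherwise           = "_" <> show (length args)
```
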

=> Attrs -> Pull sh1 a -> Pull sh2 a -> Pull (UnionShape sh1 sh2) a
onnxAdd_2 _ = bcAdd

-- | Implementation of ONNX tensor subtraction
onnxSub :: (Num a, ShapelyU sh1 sh2)
=> Attrs -> Pull sh1 a -> Pull sh2 a -> Pull (UnionShape sh1 sh2) a
onnxSub _ = bcSub
onnxSub_2 :: (Num a, ShapelyU sh1 sh2)
=> Attrs -> Pull sh1 a -> Pull sh2 a -> Pull (UnionShape sh1 sh2) a
onnxSub_2 _ = bcSub

-- | Implementation of ONNX tensor multiplication
onnxMul :: (Num a, ShapelyU sh1 sh2)
=> Attrs -> Pull sh1 a -> Pull sh2 a -> Pull (UnionShape sh1 sh2) a
onnxMul _ = bcMul
onnxMul_2 :: (Num a, ShapelyU sh1 sh2)
=> Attrs -> Pull sh1 a -> Pull sh2 a -> Pull (UnionShape sh1 sh2) a
onnxMul_2 _ = bcMul

-- | Implementation of ONNX tensor fractional division
onnxDivF :: (Fractional a, ShapelyU sh1 sh2)
=> Attrs -> Pull sh1 a -> Pull sh2 a -> Pull (UnionShape sh1 sh2) a
onnxDivF _ = bcDivF
onnxDivF_2 :: (Fractional a, ShapelyU sh1 sh2)
=> Attrs -> Pull sh1 a -> Pull sh2 a -> Pull (UnionShape sh1 sh2) a
onnxDivF_2 _ = bcDivF

-- | Implementation of ONNX tensor integral division
onnxDivI :: (Integral a, ShapelyU sh1 sh2)
=> Attrs -> DPull sh1 a -> DPull sh2 a -> DPull (UnionShape sh1 sh2) a
onnxDivI _ = bcDivI
onnxDivI_2 :: (Integral a, ShapelyU sh1 sh2)
=> Attrs -> DPull sh1 a -> DPull sh2 a -> DPull (UnionShape sh1 sh2) a
onnxDivI_2 _ = bcDivI

-- | Elementwise add with broadcasting
bcAdd :: (Num a, ShapelyU sh1 sh2)
@@ -244,23 +244,23 @@ bcDivI :: (Integral a, ShapelyU sh1 sh2)
bcDivI = bcZipWith div -- Or quot???

-- | Implementation of ONNX batch normalization
onnxBatchNormalization :: Floating a
=> Attrs -- ^ attributes (including epsilon)
-> DPull DIM4 a -- ^ data
-> DPull DIM1 a -- ^ gamma
-> DPull DIM1 a -- ^ beta
-> DPull DIM1 a -- ^ mean
-> DPull DIM1 a -- ^ var
-> DPull DIM4 a
onnxBatchNormalization attrs xs gamma beta mean var = ys
onnxBatchNormalization_5 :: Floating a
=> Attrs -- ^ attributes (including epsilon)
-> DPull DIM4 a -- ^ data
-> DPull DIM1 a -- ^ gamma
-> DPull DIM1 a -- ^ beta
-> DPull DIM1 a -- ^ mean
-> DPull DIM1 a -- ^ var
-> DPull DIM4 a
onnxBatchNormalization_5 attrs xs gamma beta mean var = ys
where invDev = map (\ v -> 1.0 / sqrt (v + epsilon)) var <! 1 <! 1
xsHat = bcMul invDev $ bcSub xs $ mean <! 1 <! 1
ys = bcAdd (bcMul xsHat $ gamma <! 1 <! 1) $ beta <! 1 <! 1
epsilon = value $ P.realToFrac $ getAttr attrs aaFloat 1e-5 "epsilon"
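
Written out on plain Doubles, the invDev/xsHat/ys pipeline above is the standard batch-normalization formula y = gamma * (x - mean) / sqrt (var + epsilon) + beta. A scalar sketch (batchNorm1 is an ad-hoc name for illustration):

```haskell
-- Per-element batch normalization on plain Doubles, mirroring the
-- invDev / xsHat / ys steps of onnxBatchNormalization_5.
batchNorm1 :: Double -> Double -> Double -> Double -> Double -> Double -> Double
batchNorm1 epsilon gamma beta mean var x =
  gamma * (x - mean) * invDev + beta
  where invDev = 1.0 / sqrt (var + epsilon)
```
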

-- | Flatten a tensor to a matrix
onnxFlatten :: Pushy vec => Attrs -> vec a -> Push DIM2 a
onnxFlatten attrs vec = flatPush (P.fromIntegral $ getAttr attrs aaInt 1 "axis") $ toPush vec
onnxFlatten_1 :: (Pushy vec, Syntax a) => Attrs -> vec a -> Pull DIM2 a
Member:
Everything seems to be Pully in this module except this stuff. Why do we need to store a reshape (index transformation) to memory?

Contributor Author:
This could in principle be done with a Pull vector and consequently be fusible. The problem is that the index expressions would be horrible, using division and modulus, since the result has fewer dimensions than the argument, so a single integer index must be split into several. We have no optimization in place that eliminates them, so they would typically end up in the innermost loop.
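
The division/modulus index expressions described here can be illustrated on plain lists (unflatten is an ad-hoc name): recovering a multi-dimensional index from a single flat one costs a div and a mod per dimension, and those operations would land in the innermost loop of the fused code.

```haskell
-- Map a flat index back into a multi-dimensional index for the given
-- shape (outermost dimension first), using repeated div/mod.
unflatten :: [Int] -> Int -> [Int]
unflatten dims i = go (reverse dims) i
  where
    go []     _ = []
    go [_]    k = [k]
    go (d:ds) k = go ds (k `div` d) ++ [k `mod` d]
```
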

Member:
Would be great to have that kind of information in a comment in the source code.

onnxFlatten_1 attrs vec = toPull $ store $ flatPush (P.fromIntegral $ getAttr attrs aaInt 1 "axis") $ toPush vec
Member:
Is idxExp something from class Ix?

Contributor Author:
No.


-- | Flattening a Push vector to two dimensions
flatPush :: forall sh a . Int -> Push sh a -> Push DIM2 a
@@ -284,9 +284,9 @@ takeDropShape i sh = P.splitAt j es
j = if i P.< 0 then i + P.length es else i
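
The negative-axis handling in takeDropShape follows the ONNX convention of counting from the end of the shape. On plain lists (splitAxis is an ad-hoc name for this sketch) the normalization amounts to:

```haskell
-- Split a shape at an axis, where a negative axis counts from the end,
-- as in ONNX/NumPy.
splitAxis :: Int -> [Int] -> ([Int], [Int])
splitAxis i es = splitAt j es
  where j = if i < 0 then i + length es else i
```
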

-- | Matrix multiplication of two-dimensional tensors
onnxGemm3 :: (RealFloat a, Numeric a, Pully vec, VecShape vec ~ DIM2)
=> Attrs -> vec (Data a) -> vec (Data a) -> vec (Data a) -> DPull DIM2 a
onnxGemm3 attrs vA vB vC = bcZipWith (+) (mmT vAT vBnT) $ toPull vC
onnxGemm_3 :: (RealFloat a, Numeric a, Pully vec, VecShape vec ~ DIM2, Pully vec2, UnionShape DIM2 (VecShape vec2) ~ DIM2)
=> Attrs -> vec (Data a) -> vec (Data a) -> vec2 (Data a) -> DPull DIM2 a
onnxGemm_3 attrs vA vB vC = bcZipWith (+) (mmT vAT vBnT) $ toPull vC
where vA' = if alpha P.== 1.0 then toPull vA else map (* value alpha) $ toPull vA
vAT = if transA P.== 1 then transpose vA' else vA'
vB' = toPull vB
@@ -314,10 +314,17 @@ infixl 5 <!
(<!) :: Pull sh a -> Data Length -> Pull (sh :. Data Length) a
Pull ixf ext <! n = Pull (\ (ix :. _) -> ixf ix) (ext :. n)

-- | Implementation of ONNX convolution operator
onnxConv :: (Num a, Syntax a)
=> Attrs -> Pull DIM4 a -> Pull DIM4 a -> Pull DIM1 a -> Pull DIM4 a
onnxConv attrs xs = onnxConvNP (value $ map fromInteger strides) pXs
-- | Implementation of ONNX convolution operator for 2 inputs
onnxConv_2 :: (Num a, Syntax a)
=> Attrs -> Pull DIM4 a -> Pull DIM4 a -> Pull DIM4 a
onnxConv_2 attrs xs ws = onnxConv_3 attrs xs ws bs
where bs = Pull (const 0) (Z :. m)
Z :. m :. _ :. _ :. _ = extent ws

-- | Implementation of ONNX convolution operator for 3 inputs
onnxConv_3 :: (Num a, Syntax a)
=> Attrs -> Pull DIM4 a -> Pull DIM4 a -> Pull DIM1 a -> Pull DIM4 a
onnxConv_3 attrs xs = onnxConvNP (value $ map fromInteger strides) pXs
where -- dilations = getAttr attrs aaInts [1, 1] "dilations" -- Currently unused
-- group = getAttr attrs aaInt 1 "group" -- Currently unused
-- kernel_shape = getAttrM attrs aaInts "kernel_shape" -- Currently unused
@@ -342,9 +349,17 @@ onnxConvNP ss xs ws bs = Pull ixf (Z :. nLen :. mLen :. h1 :. w1) `bcAdd` (bs <!
w1 = (w - kW) `div` sX + 1
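
The h1/w1 extents above use the standard convolution output-size formula applied to the already-padded input: out = (in - kernel) `div` stride + 1. As a standalone sketch (convOutDim is not part of the module):

```haskell
-- Output extent of a strided convolution along one dimension,
-- assuming padding has already been applied to the input.
convOutDim :: Int -> Int -> Int -> Int
convOutDim inDim kDim stride = (inDim - kDim) `div` stride + 1
```
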

-- | Implementation of ONNX global average pooling
onnxGlobalAveragePool :: Fraction a => Attrs -> DPull DIM4 a -> DPull DIM4 a
onnxGlobalAveragePool _ = vvmap dim2 avgF
where avgF vec = map (/ (i2n $ size $ extent vec)) (sum $ sum vec) <! 1 <! 1
onnxGlobalAveragePool_1 :: Fraction a => Attrs -> DPull DIM4 a -> DPull DIM4 a
onnxGlobalAveragePool_1 _ = vvmap dim2 avgF
where avgF vec = map (/ (i2n $ size $ extent vec)) (sum $ sum vec) <! 1 <! 1
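
Per channel, avgF computes the mean over all spatial elements. On a plain list standing in for the two spatial dimensions (globalAvg is an ad-hoc name), this is simply:

```haskell
-- Global average pooling of one channel: the mean of all its elements.
globalAvg :: [Double] -> Double
globalAvg xs = sum xs / fromIntegral (length xs)
```
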

-- | Implementation of ONNX Relu
onnxRelu_1 :: (Numeric a, Ord a) => Attrs -> DPull sh a -> DPull sh a
onnxRelu_1 _ = fmap (max 0)

-- | Implementation of ONNX MaxPool
onnxMaxPool_1 :: (Numeric a, Ord a) => Attrs -> DPull DIM4 a -> DPull DIM4 a
onnxMaxPool_1 _ xs = xs -- TODO!

-- | Padding a multi dimensional vector
pad :: forall a vec sh . (Syntax a, Num a, Pushy vec,