-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adapt implementation of ONNX operators to code from OnnxToFeld. #511
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -194,29 +194,29 @@ bcZipWith f xs ys = zipWith f (uniBCast ext xs) (uniBCast ext ys) | |
where ext = unionExt (extent xs) (extent ys) | ||
|
||
-- | Implementation of ONNX tensor addition | ||
onnxAdd :: (Num a, ShapelyU sh1 sh2) | ||
=> Attrs -> Pull sh1 a -> Pull sh2 a -> Pull (UnionShape sh1 sh2) a | ||
onnxAdd _ = bcAdd | ||
onnxAdd_2 :: (Num a, ShapelyU sh1 sh2) | ||
=> Attrs -> Pull sh1 a -> Pull sh2 a -> Pull (UnionShape sh1 sh2) a | ||
onnxAdd_2 _ = bcAdd | ||
|
||
-- | Implementation of ONNX tensor subtraction | ||
onnxSub :: (Num a, ShapelyU sh1 sh2) | ||
=> Attrs -> Pull sh1 a -> Pull sh2 a -> Pull (UnionShape sh1 sh2) a | ||
onnxSub _ = bcSub | ||
onnxSub_2 :: (Num a, ShapelyU sh1 sh2) | ||
=> Attrs -> Pull sh1 a -> Pull sh2 a -> Pull (UnionShape sh1 sh2) a | ||
onnxSub_2 _ = bcSub | ||
|
||
-- | Implementation of ONNX tensor multiplication | ||
onnxMul :: (Num a, ShapelyU sh1 sh2) | ||
=> Attrs -> Pull sh1 a -> Pull sh2 a -> Pull (UnionShape sh1 sh2) a | ||
onnxMul _ = bcMul | ||
onnxMul_2 :: (Num a, ShapelyU sh1 sh2) | ||
=> Attrs -> Pull sh1 a -> Pull sh2 a -> Pull (UnionShape sh1 sh2) a | ||
onnxMul_2 _ = bcMul | ||
|
||
-- | Implementation of ONNX tensor fractional division | ||
onnxDivF :: (Fractional a, ShapelyU sh1 sh2) | ||
=> Attrs -> Pull sh1 a -> Pull sh2 a -> Pull (UnionShape sh1 sh2) a | ||
onnxDivF _ = bcDivF | ||
onnxDivF_2 :: (Fractional a, ShapelyU sh1 sh2) | ||
=> Attrs -> Pull sh1 a -> Pull sh2 a -> Pull (UnionShape sh1 sh2) a | ||
onnxDivF_2 _ = bcDivF | ||
|
||
-- | Implementation of ONNX tensor integral division | ||
onnxDivI :: (Integral a, ShapelyU sh1 sh2) | ||
=> Attrs -> DPull sh1 a -> DPull sh2 a -> DPull (UnionShape sh1 sh2) a | ||
onnxDivI _ = bcDivI | ||
onnxDivI_2 :: (Integral a, ShapelyU sh1 sh2) | ||
=> Attrs -> DPull sh1 a -> DPull sh2 a -> DPull (UnionShape sh1 sh2) a | ||
onnxDivI_2 _ = bcDivI | ||
|
||
-- | Elementwise add with broadcasting | ||
bcAdd :: (Num a, ShapelyU sh1 sh2) | ||
|
@@ -244,23 +244,23 @@ bcDivI :: (Integral a, ShapelyU sh1 sh2) | |
bcDivI = bcZipWith div -- Or quot??? | ||
|
||
-- | Implementation of ONNX batch normalization | ||
onnxBatchNormalization :: Floating a | ||
=> Attrs -- ^ attributes (including epsilon) | ||
-> DPull DIM4 a -- ^ data | ||
-> DPull DIM1 a -- ^ gamma | ||
-> DPull DIM1 a -- ^ beta | ||
-> DPull DIM1 a -- ^ mean | ||
-> DPull DIM1 a -- ^ var | ||
-> DPull DIM4 a | ||
onnxBatchNormalization attrs xs gamma beta mean var = ys | ||
onnxBatchNormalization_5 :: Floating a | ||
=> Attrs -- ^ attributes (including epsilon) | ||
-> DPull DIM4 a -- ^ data | ||
-> DPull DIM1 a -- ^ gamma | ||
-> DPull DIM1 a -- ^ beta | ||
-> DPull DIM1 a -- ^ mean | ||
-> DPull DIM1 a -- ^ var | ||
-> DPull DIM4 a | ||
onnxBatchNormalization_5 attrs xs gamma beta mean var = ys | ||
where invDev = map (\ v -> 1.0 / sqrt (v + epsilon)) var <! 1 <! 1 | ||
xsHat = bcMul invDev $ bcSub xs $ mean <! 1 <! 1 | ||
ys = bcAdd (bcMul xsHat $ gamma <! 1 <! 1) $ beta <! 1 <! 1 | ||
epsilon = value $ P.realToFrac $ getAttr attrs aaFloat 1e-5 "epsilon" | ||
|
||
-- | Flatten a tensor to a matrix | ||
onnxFlatten :: Pushy vec => Attrs -> vec a -> Push DIM2 a | ||
onnxFlatten attrs vec = flatPush (P.fromIntegral $ getAttr attrs aaInt 1 "axis") $ toPush vec | ||
onnxFlatten_1 :: (Pushy vec, Syntax a) => Attrs -> vec a -> Pull DIM2 a | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Everything seems to be Pully in this module except this stuff. Why do we need to store a reshape (index transformation) to memory? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This could in principle be done with a Pull vector and consequently be fuseable. The problem is that the index expressions would be horrible, using division and modulus since the result has fewer dimensions than the argument so that a single integer index must be split into several. We have no optimization in place that eliminates them, so they will typically end up in the innermost loop. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would be great to have that kind of information in a comment in the source code. |
||
onnxFlatten_1 attrs vec = toPull $ store $ flatPush (P.fromIntegral $ getAttr attrs aaInt 1 "axis") $ toPush vec | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No. |
||
|
||
-- | Flattening a Push vector to two dimensions | ||
flatPush :: forall sh a . Int -> Push sh a -> Push DIM2 a | ||
|
@@ -284,9 +284,9 @@ takeDropShape i sh = P.splitAt j es | |
j = if i P.< 0 then i + P.length es else i | ||
|
||
-- | Matrix multiplication of two dimensional temsors | ||
onnxGemm3 :: (RealFloat a, Numeric a, Pully vec, VecShape vec ~ DIM2) | ||
=> Attrs -> vec (Data a) -> vec (Data a) -> vec (Data a) -> DPull DIM2 a | ||
onnxGemm3 attrs vA vB vC = bcZipWith (+) (mmT vAT vBnT) $ toPull vC | ||
onnxGemm_3 :: (RealFloat a, Numeric a, Pully vec, VecShape vec ~ DIM2, Pully vec2, UnionShape DIM2 (VecShape vec2) ~ DIM2) | ||
=> Attrs -> vec (Data a) -> vec (Data a) -> vec2 (Data a) -> DPull DIM2 a | ||
onnxGemm_3 attrs vA vB vC = bcZipWith (+) (mmT vAT vBnT) $ toPull vC | ||
where vA' = if alpha P.== 1.0 then toPull vA else map (* value alpha) $ toPull vA | ||
vAT = if transA P.== 1 then transpose vA' else vA' | ||
vB' = toPull vB | ||
|
@@ -314,10 +314,17 @@ infixl 5 <! | |
(<!) :: Pull sh a -> Data Length -> Pull (sh :. Data Length) a | ||
Pull ixf ext <! n = Pull (\ (ix :. _) -> ixf ix) (ext :. n) | ||
|
||
-- | Implementation of ONNX convolution operator | ||
onnxConv :: (Num a, Syntax a) | ||
=> Attrs -> Pull DIM4 a -> Pull DIM4 a -> Pull DIM1 a -> Pull DIM4 a | ||
onnxConv attrs xs = onnxConvNP (value $ map fromInteger strides) pXs | ||
-- | Implementation of ONNX convolution operator for 2 inputs | ||
onnxConv_2 :: (Num a, Syntax a) | ||
=> Attrs -> Pull DIM4 a -> Pull DIM4 a -> Pull DIM4 a | ||
onnxConv_2 attrs xs ws = onnxConv_3 attrs xs ws bs | ||
where bs = Pull (const 0) (Z :. m) | ||
Z :. m :. _ :. _ :. _ = extent ws | ||
|
||
-- | Implementation of ONNX convolution operator for 3 inputs | ||
onnxConv_3 :: (Num a, Syntax a) | ||
=> Attrs -> Pull DIM4 a -> Pull DIM4 a -> Pull DIM1 a -> Pull DIM4 a | ||
onnxConv_3 attrs xs = onnxConvNP (value $ map fromInteger strides) pXs | ||
where -- dilations = getAttr attrs aaInts [1, 1] "dilations" -- Currently unused | ||
-- group = getAttr attrs aaInt 1 "group" -- Currently unused | ||
-- kernel_shape = getAttrM attrs aaInts "kernel_shape" -- Currently unused | ||
|
@@ -342,9 +349,17 @@ onnxConvNP ss xs ws bs = Pull ixf (Z :. nLen :. mLen :. h1 :. w1) `bcAdd` (bs <! | |
w1 = (w - kW) `div` sX + 1 | ||
|
||
-- | Implementation of ONNX global average pooling | ||
onnxGlobalAveragePool :: Fraction a => Attrs -> DPull DIM4 a -> DPull DIM4 a | ||
onnxGlobalAveragePool _ = vvmap dim2 avgF | ||
where avgF vec = map (/ (i2n $ size $ extent vec)) (sum $ sum vec) <! 1 <! 1 | ||
onnxGlobalAveragePool_1 :: Fraction a => Attrs -> DPull DIM4 a -> DPull DIM4 a | ||
onnxGlobalAveragePool_1 _ = vvmap dim2 avgF | ||
where avgF vec = map (/ (i2n $ size $ extent vec)) (sum $ sum vec) <! 1 <! 1 | ||
|
||
-- | Implementation of ONNX Relu | ||
onnxRelu_1 :: (Numeric a, Ord a) => Attrs -> DPull sh a -> DPull sh a | ||
onnxRelu_1 _ = fmap (max 0) | ||
|
||
-- | Implementation of ONNX MaxPool | ||
onnxMaxPool_1 :: (Numeric a, Ord a) => Attrs -> DPull DIM4 a -> DPull DIM4 a | ||
onnxMaxPool_1 _ xs = xs -- TODO! | ||
|
||
-- | Padding a multi dimensional vector | ||
pad :: forall a vec sh . (Syntax a, Num a, Pushy vec, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be nice if the handwritten code is readable and these are binary operators, do we really need the
_2
on them? Same question for the_5
ononnxBatchNormalization
,_3
ononnxGemm
,_1
ononnxFlatten
, and so on.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The arity information is necessary because some ONNX operators can be used with different numbers of arguments. To avoid it, onnxToFeld needs to know about those operators and treat them as special cases, something I have tried to avoid, at least for now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we get a silent failure if we get the special cases wrong? If not, I would prefer an Operators.hs without extra suffixes and two additional lines (
| op elem varargOps = ""
and| otherwise = "_" <> show (length args)
) in OnnxToFeld.hs.