Skip to content

Commit

Permalink
Added generic FromField and ToField classes
Browse files Browse the repository at this point in the history
  • Loading branch information
zohl committed Oct 26, 2016
1 parent a8f6a90 commit c41b9f8
Show file tree
Hide file tree
Showing 4 changed files with 197 additions and 40 deletions.
3 changes: 2 additions & 1 deletion postgresql-simple.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ Library

if !impl(ghc >= 7.6)
Build-depends:
ghc-prim
ghc-prim,
tagged >= 0.8

extensions: DoAndIfThenElse, OverloadedStrings, BangPatterns, ViewPatterns
TypeOperators
Expand Down
103 changes: 98 additions & 5 deletions src/Database/PostgreSQL/Simple/FromField.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
{-# LANGUAGE FlexibleInstances, TypeSynonymInstances #-}
{-# LANGUAGE PatternGuards, ScopedTypeVariables #-}
{-# LANGUAGE RecordWildCards, TemplateHaskell #-}
{-# LANGUAGE MultiWayIf, DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}

{- |
Module: Database.PostgreSQL.Simple.FromField
Expand Down Expand Up @@ -83,6 +85,7 @@ instances use 'typename' instead.
module Database.PostgreSQL.Simple.FromField
(
FromField(..)
, genericFromField
, FieldParser
, Conversion()

Expand Down Expand Up @@ -113,16 +116,19 @@ module Database.PostgreSQL.Simple.FromField

#include "MachDeps.h"

import Control.Applicative ( (<|>), (<$>), pure, (*>), (<*) )
import Control.Applicative ( Alternative(..), (<|>), (<$>), pure, (*>), (<*), liftA2 )
import Control.Concurrent.MVar (MVar, newMVar)
import Control.Exception (Exception)
import qualified Data.Aeson as JSON
import qualified Data.Aeson.Parser as JSON (value')
import Data.Attoparsec.ByteString.Char8 hiding (Result)
import Data.ByteString (ByteString)
import Data.ByteString.Builder (Builder, toLazyByteString, byteString)
import qualified Data.ByteString.Char8 as B
import Data.Char (toLower)
import Data.Int (Int16, Int32, Int64)
import Data.IORef (IORef, newIORef)
import Data.Proxy (Proxy(..))
import Data.Ratio (Ratio)
import Data.Time ( UTCTime, ZonedTime, LocalTime, Day, TimeOfDay )
import Data.Typeable (Typeable, typeOf)
Expand Down Expand Up @@ -150,6 +156,7 @@ import qualified Data.CaseInsensitive as CI
import Data.UUID.Types (UUID)
import qualified Data.UUID.Types as UUID
import Data.Scientific (Scientific)
import GHC.Generics (Generic, Rep, M1(..), K1(..), D1, C1, S1, Rec0, Constructor, (:*:)(..), to, conName)
import GHC.Real (infinity, notANumber)

-- | Exception thrown if conversion from a SQL value to a Haskell
Expand Down Expand Up @@ -188,6 +195,8 @@ type FieldParser a = Field -> Maybe ByteString -> Conversion a
-- | A type that may be converted from a SQL type.
class FromField a where
fromField :: FieldParser a
default fromField :: (Generic a, Typeable a, GFromField (Rep a)) => FieldParser a
fromField = genericFromField (map toLower)
-- ^ Convert a SQL value to a Haskell value.
--
-- Returns a list of exceptions if the conversion fails. In the case of
Expand Down Expand Up @@ -292,7 +301,8 @@ instance FromField Null where
-- | bool
instance FromField Bool where
fromField f bs
| typeOid f /= $(inlineTypoid TI.bool) = returnError Incompatible f ""
| typeOid f /= $(inlineTypoid TI.bool)
&& typeOid f /= $(inlineTypoid TI.unknown) = returnError Incompatible f ""
| bs == Nothing = returnError UnexpectedNull f ""
| bs == Just "t" = pure True
| bs == Just "f" = pure False
Expand Down Expand Up @@ -404,9 +414,9 @@ instance FromField (Binary SB.ByteString) where
instance FromField (Binary LB.ByteString) where
fromField f dat = Binary . LB.fromChunks . (:[]) . unBinary <$> fromField f dat

-- | name, text, \"char\", bpchar, varchar
-- | name, text, \"char\", bpchar, varchar, unknown
instance FromField ST.Text where
fromField f = doFromField f okText $ (either left pure . ST.decodeUtf8')
fromField f = doFromField f okText' $ (either left pure . ST.decodeUtf8')
-- FIXME: check character encoding

-- | name, text, \"char\", bpchar, varchar
Expand Down Expand Up @@ -645,10 +655,93 @@ returnError mkErr f msg = do
atto :: forall a. (Typeable a)
=> Compat -> Parser a -> Field -> Maybe ByteString
-> Conversion a
atto types p0 f dat = doFromField f types (go p0) dat
atto types p0 f dat = doFromField f (\t -> types t || (t == $(inlineTypoid TI.unknown))) (go p0) dat
where
go :: Parser a -> ByteString -> Conversion a
go p s =
case parseOnly p s of
Left err -> returnError ConversionFailed f err
Right v -> pure v


-- | Type class for default implementation of FromField using generics.
class GFromField f where
gfromField :: (Typeable p)
=> Proxy p
-> (String -> String)
-> Field
-> [Maybe ByteString]
-> Conversion (f p)

instance (GFromField f) => GFromField (D1 i f) where
gfromField w t f v = M1 <$> gfromField w t f v

instance (GFromField f, Typeable f, Constructor i) => GFromField (C1 i f) where
gfromField w t f (v:[]) = let
tname = B8.pack . t . conName $ (undefined::(C1 i f t))
tcheck = (\t -> t /= "record" && t /= tname)
in tcheck <$> typename f >>= \b -> M1 <$> case b of
True -> returnError Incompatible f ""
False -> maybe
(returnError UnexpectedNull f "")
(either
(returnError ConversionFailed f)
(gfromField w t f)
. (parseOnly record)) v
gfromField _ _ f _ = M1 <$> returnError ConversionFailed f errUnexpectedArgs

instance (GFromField f, Typeable f, GFromField g) => GFromField (f :*: g) where
gfromField _ _ f [] = liftA2 (:*:) (returnError ConversionFailed f errTooFewValues) empty
gfromField w t f (v:vs) = liftA2 (:*:) (gfromField w t f [v]) (gfromField w t f vs)

instance (GFromField f, Typeable f) => GFromField (S1 i f) where
gfromField _ _ f [] = M1 <$> returnError ConversionFailed f errTooFewValues
gfromField w t f (v:[]) = M1 <$> gfromField w t f [v]
gfromField _ _ f _ = M1 <$> returnError ConversionFailed f errTooManyValues

instance (FromField f, Typeable f) => GFromField (Rec0 f) where
gfromField _ _ f [v] = K1 <$> fromField (f {typeOid = typoid TI.unknown}) v
gfromField _ _ f _ = K1 <$> returnError ConversionFailed f errUnexpectedArgs


-- | Common error messages for GFromField instances.
errTooFewValues, errTooManyValues, errUnexpectedArgs :: String
errTooFewValues = "too few values"
errTooManyValues = "too many values"
errUnexpectedArgs = "unexpected arguments"

-- | Parser of a postgresql record.
record :: Parser [Maybe ByteString]
record = (char '(') *> (recordField `sepBy` (char ',')) <* (char ')')

-- | Parser of a postgresql record's field.
recordField :: Parser (Maybe ByteString)
recordField = (Just <$> quotedString) <|> (Just <$> unquotedString) <|> (pure Nothing) where
quotedString = unescape <$> (char '"' *> scan False updateState) where
updateState isBalanced c = if
| c == '"' -> Just . not $ isBalanced
| not isBalanced -> Just False
| c == ',' || c == ')' -> Nothing
| otherwise -> fail $ "unexpected symbol: " ++ [c]

unescape = unescape' '\\' . unescape' '"' . B8.init where
unescape' c = halve c (byteString SB.empty) . groupByChar c

groupByChar c = B8.groupBy $ \a b -> (a == c) == (b == c)

halve :: Char -> Builder -> [ByteString] -> ByteString
halve _ b [] = LB.toStrict . toLazyByteString $ b
halve c b (s:ss) = halve c (b <> b') ss where
b' = if
| (/= c) . B8.head $ s -> byteString s
| otherwise -> byteString . SB.take ((SB.length s) `div` 2) $ s

unquotedString = takeWhile1 (\c -> c /= ',' && c /= ')')

-- | Function that creates fromField for a given type.
genericFromField :: forall a. (Generic a, Typeable a, GFromField (Rep a))
=> (String -> String) -- ^ How to transform constructor's name to match
-- postgresql type's name.
-> FieldParser a
genericFromField t f v = (to <$> (gfromField (Proxy :: Proxy a) t f [v]))

27 changes: 27 additions & 0 deletions src/Database/PostgreSQL/Simple/ToField.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{-# LANGUAGE CPP, DeriveDataTypeable, DeriveFunctor #-}
{-# LANGUAGE FlexibleInstances, TypeSynonymInstances #-}
{-# LANGUAGE DefaultSignatures, FlexibleContexts #-}

------------------------------------------------------------------------------
-- |
Expand Down Expand Up @@ -39,6 +40,7 @@ import Data.Word (Word, Word8, Word16, Word32, Word64)
import {-# SOURCE #-} Database.PostgreSQL.Simple.ToRow
import Database.PostgreSQL.Simple.Types
import Database.PostgreSQL.Simple.Compat (toByteString)
import GHC.Generics (Generic, Rep, D1, C1, S1, (:*:)(..), Rec0, from, unM1, unK1)

import qualified Data.ByteString as SB
import qualified Data.ByteString.Lazy as LB
Expand Down Expand Up @@ -92,6 +94,8 @@ instance Show Action where
-- | A type that may be used as a single parameter to a SQL query.
class ToField a where
toField :: a -> Action
default toField :: (Generic a, GToField (Rep a)) => a -> Action
toField = head . gtoField . from
-- ^ Prepare a value for substitution into a query string.

instance ToField Action where
Expand Down Expand Up @@ -369,3 +373,26 @@ instance ToRow a => ToField (Values a) where
(litC ',')
rest
vals

-- Type class for default implementation of ToField using generics.
class GToField f where
gtoField :: f p -> [Action]

instance GToField f => GToField (D1 i f) where
gtoField = gtoField . unM1

instance GToField f => GToField (C1 i f) where
gtoField = (:[]) . Many . tupleWrap . gtoField . unM1

instance (GToField f, GToField g) => GToField (f :*: g) where
gtoField (f :*: g) = gtoField f ++ gtoField g

instance (GToField f) => GToField (S1 i f) where
gtoField = gtoField . unM1

instance (ToField f) => GToField (Rec0 f) where
gtoField = (:[]) . toField . unK1

tupleWrap :: [Action] -> [Action]
tupleWrap xs = (Plain "("): (intersperse (Plain ",") xs) ++ [Plain ")"]

104 changes: 70 additions & 34 deletions test/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DoAndIfThenElse #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE QuasiQuotes #-}

import Common
import Database.PostgreSQL.Simple.FromField (FromField)
import Database.PostgreSQL.Simple.Types(Query(..),Values(..))
import Database.PostgreSQL.Simple.ToField (ToField)
import Database.PostgreSQL.Simple.Types (Query(..), Values(..))
import Database.PostgreSQL.Simple.HStore
import Database.PostgreSQL.Simple.Copy
import Database.PostgreSQL.Simple.SqlQQ (sql)
import qualified Database.PostgreSQL.Simple.Transaction as ST

import Control.Applicative
Expand Down Expand Up @@ -42,25 +46,28 @@ tests :: TestEnv -> TestTree
tests env = testGroup "tests"
$ map ($ env)
[ testBytea
, testCase "ExecuteMany" . testExecuteMany
, testCase "Fold" . testFold
, testCase "Notify" . testNotify
, testCase "Serializable" . testSerializable
, testCase "Time" . testTime
, testCase "Array" . testArray
, testCase "Array of nullables" . testNullableArray
, testCase "HStore" . testHStore
, testCase "JSON" . testJSON
, testCase "Savepoint" . testSavepoint
, testCase "Unicode" . testUnicode
, testCase "Values" . testValues
, testCase "Copy" . testCopy
, testCase "ExecuteMany" . testExecuteMany
, testCase "Fold" . testFold
, testCase "Notify" . testNotify
, testCase "Serializable" . testSerializable
, testCase "Time" . testTime
, testCase "Array" . testArray
, testCase "Array of nullables" . testNullableArray
, testCase "HStore" . testHStore
, testCase "JSON" . testJSON
, testCase "Savepoint" . testSavepoint
, testCase "Unicode" . testUnicode
, testCase "Values" . testValues
, testCase "Copy" . testCopy
, testCopyFailures
, testCase "Double" . testDouble
, testCase "1-ary generic" . testGeneric1
, testCase "2-ary generic" . testGeneric2
, testCase "3-ary generic" . testGeneric3
, testCase "Timeout" . testTimeout
, testCase "Double" . testDouble
, testCase "1-ary generic row" . testGeneric1Row
, testCase "2-ary generic row" . testGeneric2Row
, testCase "3-ary generic row" . testGeneric3Row
, testCase "1-ary generic field" . testGeneric1Field
, testCase "2-ary generic field" . testGeneric2Field
, testCase "3-ary generic field" . testGeneric3Field
, testCase "Timeout" . testTimeout
]

testBytea :: TestEnv -> TestTree
Expand Down Expand Up @@ -406,44 +413,73 @@ testDouble TestEnv{..} = do
x @?= (-1 / 0)


testGeneric1 :: TestEnv -> Assertion
testGeneric1 TestEnv{..} = do
testGeneric1Row :: TestEnv -> Assertion
testGeneric1Row TestEnv{..} = do
roundTrip conn (Gen1 123)
where
roundTrip conn x0 = do
r <- query conn "SELECT ?::int" (x0 :: Gen1)
r @?= [x0]

testGeneric2 :: TestEnv -> Assertion
testGeneric2 TestEnv{..} = do
testGeneric2Row :: TestEnv -> Assertion
testGeneric2Row TestEnv{..} = do
roundTrip conn (Gen2 123 "asdf")
where
roundTrip conn x0 = do
r <- query conn "SELECT ?::int, ?::text" x0
r @?= [x0]

testGeneric3 :: TestEnv -> Assertion
testGeneric3 TestEnv{..} = do
testGeneric3Row :: TestEnv -> Assertion
testGeneric3Row TestEnv{..} = do
roundTrip conn (Gen3 123 "asdf" True)
where
roundTrip conn x0 = do
r <- query conn "SELECT ?::int, ?::text, ?::bool" x0
r @?= [x0]

testGeneric1Field :: TestEnv -> Assertion
testGeneric1Field TestEnv{..} = withTransaction conn $ do
-- It's not possible to simply roundtrip a 1-ary tuple
-- as PostgreSQL will treat it as a scalar value.
-- Therefore we will create a separate type for it.
execute_ conn "CREATE TYPE gen1 AS (x bigint)"
execute_ conn [sql|
CREATE FUNCTION test_gen1() RETURNS SETOF gen1 AS $$
(SELECT 1::bigint) UNION ALL (SELECT 2) UNION ALL (SELECT 3)
$$ LANGUAGE sql
|]
query_ conn "SELECT test_gen1()" >>= (@?= [Only (Gen1 1), Only (Gen1 2), Only (Gen1 3)])
rollback conn

testGeneric2Field :: TestEnv -> Assertion
testGeneric2Field TestEnv{..} = roundTripField conn (Gen2 123 "asdf")

testGeneric3Field :: TestEnv -> Assertion
testGeneric3Field TestEnv{..} = roundTripField conn (Gen3 123 "asdf" True)

roundTripField :: (Show a, Eq a, FromField a, ToField a) => Connection -> a -> Assertion
roundTripField conn x0 = query conn "SELECT ?" (Only x0) >>= (@?= [Only x0])

data Gen1 = Gen1 Int
deriving (Show,Eq,Generic)
instance FromRow Gen1
instance ToRow Gen1
deriving (Show, Eq, Generic, Typeable)
instance FromRow Gen1
instance ToRow Gen1
instance FromField Gen1
instance ToField Gen1

data Gen2 = Gen2 Int Text
deriving (Show,Eq,Generic)
instance FromRow Gen2
instance ToRow Gen2
deriving (Show, Eq, Generic, Typeable)
instance FromRow Gen2
instance ToRow Gen2
instance FromField Gen2
instance ToField Gen2

data Gen3 = Gen3 Int Text Bool
deriving (Show,Eq,Generic)
instance FromRow Gen3
instance ToRow Gen3
deriving (Show, Eq, Generic, Typeable)
instance FromRow Gen3
instance ToRow Gen3
instance FromField Gen3
instance ToField Gen3

data TestException
= TestException
Expand Down

0 comments on commit c41b9f8

Please sign in to comment.