diff --git a/CHANGELOG.md b/CHANGELOG.md index a2c1a7b..dd32bc5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,6 @@ +# hpqtypes-extras-1.15.0.0 (2022-??-??) +* Add support for triggers and trigger functions. + # hpqtypes-extras-1.14.2.0 (2022-??-??) * Add support for GHC 9.2. * Drop support for GHC < 8.8. diff --git a/hpqtypes-extras.cabal b/hpqtypes-extras.cabal index 3cbe203..c7a1615 100644 --- a/hpqtypes-extras.cabal +++ b/hpqtypes-extras.cabal @@ -1,6 +1,6 @@ cabal-version: 2.2 name: hpqtypes-extras -version: 1.14.2.0 +version: 1.15.0.0 synopsis: Extra utilities for hpqtypes library description: The following extras for hpqtypes library: . @@ -68,6 +68,7 @@ library , Database.PostgreSQL.PQTypes.Model.Migration , Database.PostgreSQL.PQTypes.Model.PrimaryKey , Database.PostgreSQL.PQTypes.Model.Table + , Database.PostgreSQL.PQTypes.Model.Trigger , Database.PostgreSQL.PQTypes.SQL.Builder , Database.PostgreSQL.PQTypes.Versions @@ -111,6 +112,7 @@ test-suite hpqtypes-extras-tests ghc-options: -Wall build-depends: base + , containers , exceptions , hpqtypes , hpqtypes-extras diff --git a/src/Database/PostgreSQL/PQTypes/Checks.hs b/src/Database/PostgreSQL/PQTypes/Checks.hs index c426117..662b458 100644 --- a/src/Database/PostgreSQL/PQTypes/Checks.hs +++ b/src/Database/PostgreSQL/PQTypes/Checks.hs @@ -419,12 +419,14 @@ checkDBStructure options tables = fmap mconcat . forM tables $ \(table, version) indexes <- fetchMany fetchTableIndex runQuery_ $ sqlGetForeignKeys table fkeys <- fetchMany fetchForeignKey + triggers <- getDBTriggers tblName return $ mconcat [ checkColumns 1 tblColumns desc , checkPrimaryKey tblPrimaryKey pk , checkChecks tblChecks checks , checkIndexes tblIndexes indexes , checkForeignKeys tblForeignKeys fkeys + , checkTriggers tblTriggers triggers ] where fetchTableColumn @@ -541,6 +543,17 @@ checkDBStructure options tables = fmap mconcat . forM tables $ \(table, version) , checkNames (fkName tblName) fkeys ] + checkTriggers :: [Trigger] -> [Trigger] -> ValidationResult + checkTriggers defs triggers = + mapValidationResult id mapErrs $ checkEquality "TRIGGERs" defs triggers + where + mapErrs [] = [] + mapErrs errmsgs = errmsgs <> + [ "(HINT: If WHEN clauses are equal modulo number of parentheses, whitespace, \ + \case of variables or type casts used in conditions, just copy and paste \ + \expected output into source code.)" + ] + -- | Checks whether database is consistent, performing migrations if -- necessary. Requires all table names to be in lower case. -- @@ -601,6 +614,7 @@ checkDBConsistency options domains tablesWithVersions migrations = do validateMigrations :: m () validateMigrations = forM_ tables $ \table -> do + -- FIXME: https://github.com/scrive/hpqtypes-extras/issues/73 let presentMigrationVersions = [ mgrFrom | Migration{..} <- migrations , mgrTableName == tblName table ] diff --git a/src/Database/PostgreSQL/PQTypes/Migrate.hs b/src/Database/PostgreSQL/PQTypes/Migrate.hs index f2cf818..60a031e 100644 --- a/src/Database/PostgreSQL/PQTypes/Migrate.hs +++ b/src/Database/PostgreSQL/PQTypes/Migrate.hs @@ -1,7 +1,8 @@ module Database.PostgreSQL.PQTypes.Migrate ( createDomain, createTable, - createTableConstraints + createTableConstraints, + createTableTriggers ) where import Control.Monad @@ -28,6 +29,8 @@ createTable withConstraints table@Table{..} = do forM_ tblIndexes $ runQuery_ . sqlCreateIndexMaybeDowntime tblName -- Add all the other constraints if applicable. when withConstraints $ createTableConstraints table + -- Create triggers. + createTableTriggers table -- Register the table along with its version. runQuery_ . sqlInsert "table_versions" $ do sqlSet "name" (tblNameText table) @@ -42,3 +45,6 @@ createTableConstraints Table{..} = when (not $ null addConstraints) $ do , map sqlAddValidCheckMaybeDowntime tblChecks , map (sqlAddValidFKMaybeDowntime tblName) tblForeignKeys ] + +createTableTriggers :: MonadDB m => Table -> m () +createTableTriggers = mapM_ createTrigger . tblTriggers diff --git a/src/Database/PostgreSQL/PQTypes/Model.hs b/src/Database/PostgreSQL/PQTypes/Model.hs index f1e0aa3..978ea3c 100644 --- a/src/Database/PostgreSQL/PQTypes/Model.hs +++ b/src/Database/PostgreSQL/PQTypes/Model.hs @@ -9,6 +9,7 @@ module Database.PostgreSQL.PQTypes.Model ( , module Database.PostgreSQL.PQTypes.Model.Migration , module Database.PostgreSQL.PQTypes.Model.PrimaryKey , module Database.PostgreSQL.PQTypes.Model.Table + , module Database.PostgreSQL.PQTypes.Model.Trigger ) where import Database.PostgreSQL.PQTypes.Model.Check @@ -21,3 +22,4 @@ import Database.PostgreSQL.PQTypes.Model.Index import Database.PostgreSQL.PQTypes.Model.Migration import Database.PostgreSQL.PQTypes.Model.PrimaryKey import Database.PostgreSQL.PQTypes.Model.Table +import Database.PostgreSQL.PQTypes.Model.Trigger diff --git a/src/Database/PostgreSQL/PQTypes/Model/Table.hs b/src/Database/PostgreSQL/PQTypes/Model/Table.hs index 266e6f2..55ee4d8 100644 --- a/src/Database/PostgreSQL/PQTypes/Model/Table.hs +++ b/src/Database/PostgreSQL/PQTypes/Model/Table.hs @@ -25,6 +25,7 @@ import Database.PostgreSQL.PQTypes.Model.ColumnType import Database.PostgreSQL.PQTypes.Model.ForeignKey import Database.PostgreSQL.PQTypes.Model.Index import Database.PostgreSQL.PQTypes.Model.PrimaryKey +import Database.PostgreSQL.PQTypes.Model.Trigger data TableColumn = TableColumn { colName :: RawSQL () @@ -69,6 +70,7 @@ data Table = , tblChecks :: [Check] , tblForeignKeys :: [ForeignKey] , tblIndexes :: [TableIndex] +, tblTriggers :: [Trigger] , tblInitialSetup :: Maybe TableInitialSetup } @@ -86,6 +88,7 @@ tblTable = Table { , tblChecks = [] , tblForeignKeys = [] , tblIndexes = [] +, tblTriggers = [] , tblInitialSetup = Nothing } diff --git a/src/Database/PostgreSQL/PQTypes/Model/Trigger.hs b/src/Database/PostgreSQL/PQTypes/Model/Trigger.hs new file mode 100644 index 0000000..f134abe --- /dev/null +++ b/src/Database/PostgreSQL/PQTypes/Model/Trigger.hs @@ -0,0 +1,302 @@ +-- | +-- Module: Database.PostgreSQL.PQTypes.Model.Trigger +-- +-- Trigger name must be unique among triggers of same table. Only @CONSTRAINT@ triggers are +-- supported. They can only be run @AFTER@ an event. The associated functions are always +-- created with no arguments and always @RETURN TRIGGER@. +-- +-- For details, see . + +module Database.PostgreSQL.PQTypes.Model.Trigger ( + -- * Trigger functions + TriggerFunction(..) + , sqlCreateTriggerFunction + , sqlDropTriggerFunction + -- * Triggers + , TriggerEvent(..) + , Trigger(..) + , triggerMakeName + , triggerBaseName + , sqlCreateTrigger + , sqlDropTrigger + , createTrigger + , dropTrigger + , getDBTriggers + ) where + +import Data.Bits (testBit) +import Data.Foldable (foldl') +import Data.Int +import Data.Monoid.Utils +import Data.Set (Set) +import Data.Text (Text) +import Database.PostgreSQL.PQTypes +import Database.PostgreSQL.PQTypes.SQL.Builder +import qualified Data.Set as Set +import qualified Data.Text as Text + +-- | Function associated with a trigger. +-- +-- @since 1.15.0.0 +data TriggerFunction = TriggerFunction { + tfName :: RawSQL () + -- ^ The function's name. + , tfSource :: RawSQL () + -- ^ The functions's body source code. +} deriving (Show) + +instance Eq TriggerFunction where + -- Since the functions have no arguments, it's impossible to create two functions with + -- the same name. Therefore comparing functions only by their names is enough in this + -- case. The assumption is, of course, that the database schema is only changed using + -- this framework. + f1 == f2 = tfName f1 == tfName f2 + +-- | Build an SQL statement for creating a trigger function. +-- +-- Since we only support @CONSTRAINT@ triggers, the function will always @RETURN TRIGGER@ +-- and will have no parameters. +-- +-- @since 1.15.0.0 +sqlCreateTriggerFunction :: TriggerFunction -> RawSQL () +sqlCreateTriggerFunction TriggerFunction{..} = + "CREATE FUNCTION" + <+> tfName + <> "()" + <+> "RETURNS TRIGGER" + <+> "AS $$" + <+> tfSource + <+> "$$" + <+> "LANGUAGE PLPGSQL" + <+> "VOLATILE" + <+> "RETURNS NULL ON NULL INPUT" + +-- | Build an SQL statement for dropping a trigger function. +-- +-- @since 1.15.0.0 +sqlDropTriggerFunction :: TriggerFunction -> RawSQL () +sqlDropTriggerFunction TriggerFunction{..} = + "DROP FUNCTION" <+> tfName <+> "RESTRICT" + +-- | Trigger event name. +-- +-- @since 1.15.0.0 +data TriggerEvent + = TriggerInsert + -- ^ The @INSERT@ event. + | TriggerUpdate + -- ^ The @UPDATE@ event. + | TriggerUpdateOf [RawSQL ()] + -- ^ The @UPDATE OF column1 [, column2 ...]@ event. + | TriggerDelete + -- ^ The @DELETE@ event. + deriving (Eq, Ord, Show) + +-- | Trigger. +-- +-- @since 1.15.0.0 +data Trigger = Trigger { + triggerTable :: RawSQL () + -- ^ The table that the trigger is associated with. + , triggerName :: RawSQL () + -- ^ The internal name without any prefixes. Trigger name must be unique among + -- triggers of same table. See 'triggerMakeName'. + , triggerEvents :: Set TriggerEvent + -- ^ The set of events. Corresponds to the @{ __event__ [ OR ... ] }@ in the trigger + -- definition. The order in which they are defined doesn't matter and there can + -- only be one of each. + , triggerDeferrable :: Bool + -- ^ Is the trigger @DEFERRABLE@ or @NOT DEFERRABLE@ ? + , triggerInitiallyDeferred :: Bool + -- ^ Is the trigger @INITIALLY DEFERRED@ or @INITIALLY IMMEDIATE@ ? + , triggerWhen :: Maybe (RawSQL ()) + -- ^ The condition that specifies whether the trigger should fire. Corresponds to the + -- @WHEN ( __condition__ )@ in the trigger definition. + , triggerFunction :: TriggerFunction + -- ^ The function to execute when the trigger fires. +} deriving (Eq, Show) + +-- | Make a trigger name that can be used in SQL. +-- +-- Given a base @name@ and @tableName@, return a new name that will be used as the +-- actual name of the trigger in an SQL query. The returned name is in the format +-- @trg\__\\__\@. +-- +-- @since 1.15.0 +triggerMakeName :: RawSQL () -> RawSQL () -> RawSQL () +triggerMakeName name tableName = "trg__" <> tableName <> "__" <> name + +-- | Return the trigger's base name. +-- +-- Given the trigger's actual @name@ and @tableName@, return the base name of the +-- trigger. This is basically the reverse of what 'triggerMakeName' does. +-- +-- @since 1.15.0 +triggerBaseName :: RawSQL () -> RawSQL () -> RawSQL () +triggerBaseName name tableName = + rawSQL (snd . Text.breakOnEnd (unRawSQL tableName <> "__") $ unRawSQL name) () + +triggerEventName :: TriggerEvent -> RawSQL () +triggerEventName = \case + TriggerInsert -> "INSERT" + TriggerUpdate -> "UPDATE" + TriggerUpdateOf columns -> if null columns + then error "UPDATE OF must have columns." + else "UPDATE OF" <+> mintercalate ", " columns + TriggerDelete -> "DELETE" + +-- | Build an SQL statement that creates a trigger. +-- +-- Only supports @CONSTRAINT@ triggers which can only run @AFTER@. +-- +-- @since 1.15.0 +sqlCreateTrigger :: Trigger -> RawSQL () +sqlCreateTrigger Trigger{..} = + "CREATE CONSTRAINT TRIGGER" <+> trgName + <+> "AFTER" <+> trgEvents + <+> "ON" <+> triggerTable + <+> trgTiming + <+> "FOR EACH ROW" + <+> trgWhen + <+> "EXECUTE FUNCTION" <+> trgFunction + <+> "()" + where + trgName + | triggerName == "" = error "Trigger must have a name." + | otherwise = triggerMakeName triggerName triggerTable + trgEvents + | triggerEvents == Set.empty = error "Trigger must have at least one event." + | otherwise = mintercalate " OR " . map triggerEventName $ Set.toList triggerEvents + trgTiming = let deferrable = (if triggerDeferrable then "" else "NOT") <+> "DEFERRABLE" + deferred = if triggerInitiallyDeferred + then "INITIALLY DEFERRED" + else "INITIALLY IMMEDIATE" + in deferrable <+> deferred + trgWhen = maybe "" (\w -> "WHEN (" <+> w <+> ")") triggerWhen + trgFunction = tfName triggerFunction + + +-- | Build an SQL statement that drops a trigger. +-- +-- @since 1.15.0 +sqlDropTrigger :: Trigger -> RawSQL () +sqlDropTrigger Trigger{..} = + -- In theory, because the trigger is dependent on its function, it should be enough to + -- 'DROP FUNCTION triggerFunction CASCADE'. However, let's make this safe and go with + -- the default RESTRICT here. + "DROP TRIGGER" <+> trgName <+> "ON" <+> triggerTable <+> "RESTRICT" + where + trgName + | triggerName == "" = error "Trigger must have a name." + | otherwise = triggerMakeName triggerName triggerTable + +-- | Create the trigger in the database. +-- +-- First, create the trigger's associated function, then create the trigger itself. +-- +-- @since 1.15.0 +createTrigger :: MonadDB m => Trigger -> m () +createTrigger trigger = do + -- TODO: Use 'withTransaction' here? That would mean adding MonadMask... + runQuery_ . sqlCreateTriggerFunction $ triggerFunction trigger + runQuery_ $ sqlCreateTrigger trigger + +-- | Drop the trigger from the database. +-- +-- @since 1.15.0 +dropTrigger :: MonadDB m => Trigger -> m () +dropTrigger trigger = do + -- First, drop the trigger, as it is dependent on the function. See the comment in + -- 'sqlDropTrigger'. + -- TODO: Use 'withTransaction' here? That would mean adding MonadMask... + runQuery_ $ sqlDropTrigger trigger + runQuery_ . sqlDropTriggerFunction $ triggerFunction trigger + +-- | Get all noninternal triggers from the database. +-- +-- Run a query that returns all triggers associated with the given table and marked as +-- @tgisinternal = false@. +-- +-- Note that, in the background, to get the trigger's @WHEN@ clause and the source code of +-- the attached function, the entire query that had created the trigger is received using +-- @pg_get_triggerdef(t.oid, true)::text@ and then parsed. The result of that call will be +-- decompiled and normalized, which means that it's likely not what the user had +-- originally typed. +-- +-- @since 1.15.0 +getDBTriggers :: forall m. MonadDB m => RawSQL () -> m [Trigger] +getDBTriggers tableName = do + runQuery_ . sqlSelect "pg_trigger t" $ do + sqlResult "t.tgname::text" -- name + sqlResult "t.tgtype" -- smallint == int2 => (2 bytes) + sqlResult "t.tgdeferrable" -- boolean + sqlResult "t.tginitdeferred"-- boolean + -- This gets the entire query that created this trigger. Note that it's decompiled and + -- normalized, which means that it's likely not what the user actually typed. For + -- example, if the original query had excessive whitespace in it, it won't be in this + -- result. + sqlResult "pg_get_triggerdef(t.oid, true)::text" + sqlResult "p.proname::text" -- name + sqlResult "p.prosrc" -- text + sqlResult "c.relname::text" + sqlJoinOn "pg_proc p" "t.tgfoid = p.oid" + sqlJoinOn "pg_class c" "c.oid = t.tgrelid" + sqlWhereEq "t.tgisinternal" False + sqlWhereEq "c.relname" $ unRawSQL tableName + fetchMany getTrigger + where + getTrigger :: (String, Int16, Bool, Bool, String, String, String, String) -> Trigger + getTrigger (tgname, tgtype, tgdeferrable, tginitdeferrable, triggerdef, proname, prosrc, tblName) = + Trigger { triggerTable = tableName' + , triggerName = triggerBaseName (unsafeSQL tgname) tableName' + , triggerEvents = trgEvents + , triggerDeferrable = tgdeferrable + , triggerInitiallyDeferred = tginitdeferrable + , triggerWhen = tgrWhen + , triggerFunction = TriggerFunction (unsafeSQL proname) (unsafeSQL prosrc) + } + where + tableName' :: RawSQL () + tableName' = unsafeSQL tblName + + parseBetween :: Text -> Text -> Maybe (RawSQL ()) + parseBetween left right = + let (prefix, match) = Text.breakOnEnd left $ Text.pack triggerdef + in if Text.null prefix + then Nothing + else Just $ (rawSQL . fst $ Text.breakOn right match) () + + -- Get the WHEN part of the query. Anything between WHEN and EXECUTE is what we + -- want. The Postgres' grammar guarantees that WHEN and EXECUTE are always next to + -- each other and in that order. + tgrWhen :: Maybe (RawSQL ()) + tgrWhen = parseBetween "WHEN (" ") EXECUTE" + + -- Similarly, in case of UPDATE OF, the columns can be simply parsed from the + -- original query. Note that UPDATE and UPDATE OF are mutually exclusive and have + -- the same bit set in the underlying tgtype bit field. + trgEvents :: Set TriggerEvent + trgEvents = + foldl' (\set (mask, event) -> + if testBit tgtype mask + then + Set.insert + (if event == TriggerUpdate + then maybe event trgUpdateOf $ parseBetween "UPDATE OF " " ON" + else event + ) + set + else set + ) + Set.empty + -- Taken from PostgreSQL sources: src/include/catalog/pg_trigger.h: + [ (2, TriggerInsert) -- #define TRIGGER_TYPE_INSERT (1 << 2) + , (3, TriggerDelete) -- #define TRIGGER_TYPE_DELETE (1 << 3) + , (4, TriggerUpdate) -- #define TRIGGER_TYPE_UPDATE (1 << 4) + ] + + trgUpdateOf :: RawSQL () -> TriggerEvent + trgUpdateOf columnsSQL = + let columns = map (unsafeSQL . Text.unpack) . Text.splitOn ", " $ unRawSQL columnsSQL + in TriggerUpdateOf columns + diff --git a/test/Main.hs b/test/Main.hs index 98f1383..58c7575 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -5,8 +5,10 @@ import Control.Monad.IO.Class import Data.Either import Data.Typeable import Data.UUID.Types +import qualified Data.Set as Set import qualified Data.Text as T +import Data.Monoid.Utils import Database.PostgreSQL.PQTypes import Database.PostgreSQL.PQTypes.Checks import Database.PostgreSQL.PQTypes.Model.ColumnType @@ -16,6 +18,7 @@ import Database.PostgreSQL.PQTypes.Model.Index import Database.PostgreSQL.PQTypes.Model.Migration import Database.PostgreSQL.PQTypes.Model.PrimaryKey import Database.PostgreSQL.PQTypes.Model.Table +import Database.PostgreSQL.PQTypes.Model.Trigger import Database.PostgreSQL.PQTypes.SQL.Builder import Log import Log.Backend.StandardOutput @@ -69,6 +72,7 @@ tableBankSchema1 = , colNullable = False } ] , tblPrimaryKey = pkOnColumn "id" + , tblTriggers = [] } tableBankSchema2 :: Table @@ -368,7 +372,8 @@ schema2Migrations :: (MonadDB m) => [Migration m] schema2Migrations = schema1Migrations ++ [ dropTableMigration tableWitnessedRobberySchema1 , dropTableMigration tableWitnessSchema1 - , createTableMigration tableUnderArrestSchema2 ] + , createTableMigration tableUnderArrestSchema2 + ] schema3Tables :: [Table] schema3Tables = [ tableBankSchema3 @@ -826,6 +831,325 @@ migrationTest1Body step = do migrateDBToSchema5 step testDBSchema5 step +bankTrigger1 :: Trigger +bankTrigger1 = + Trigger { triggerTable = "bank" + , triggerName = "trigger_1" + , triggerEvents = Set.fromList [TriggerInsert] + , triggerDeferrable = False + , triggerInitiallyDeferred = False + , triggerWhen = Nothing + , triggerFunction = TriggerFunction "function_1" $ + "begin" + <+> " perform true;" + <+> " return null;" + <+> "end;" + } + +bankTrigger2 :: Trigger +bankTrigger2 = + bankTrigger1 + { triggerFunction = TriggerFunction "function_2" $ + "begin" + <+> " return null;" + <+> "end;" + } + +bankTrigger3 :: Trigger +bankTrigger3 = + Trigger { triggerTable = "bank" + , triggerName = "trigger_3" + , triggerEvents = Set.fromList [TriggerInsert, TriggerUpdateOf [unsafeSQL "location"]] + , triggerDeferrable = True + , triggerInitiallyDeferred = True + , triggerWhen = Nothing + , triggerFunction = TriggerFunction "function_3" $ + "begin" + <+> " perform true;" + <+> " return null;" + <+> "end;" + } + +bankTrigger2Proper :: Trigger +bankTrigger2Proper = + bankTrigger2 { triggerName = "trigger_2" } + +testTriggers :: HasCallStack => (String -> TestM ()) -> TestM () +testTriggers step = do + step "Running trigger tests..." + + step "create the initial database" + migrate [tableBankSchema1] [createTableMigration tableBankSchema1] + + do + let msg = "checkDatabase fails if there are triggers in the database but not in the schema" + ts = [ tableBankSchema1 { tblVersion = 2 + , tblTriggers = [] + } + ] + ms = [ createTriggerMigration 1 bankTrigger1 ] + step msg + assertException msg $ migrate ts ms + + do + let msg = "checkDatabase fails if there are triggers in the schema but not in the database" + ts = [ tableBankSchema1 { tblVersion = 2 + , tblTriggers = [bankTrigger1] + } + ] + ms = [] + triggerStep msg $ do + assertException msg $ migrate ts ms + + do + let msg = "test succeeds when creating a single trigger" + ts = [ tableBankSchema1 { tblVersion = 2 + , tblTriggers = [bankTrigger1] + } + ] + ms = [ createTriggerMigration 1 bankTrigger1 ] + triggerStep msg $ do + assertNoException msg $ migrate ts ms + verify [bankTrigger1] True + + do + let msg = "checkDatabase fails if triggers differ in function name" + ts = [ tableBankSchema1 { tblVersion = 2 + , tblTriggers = [bankTrigger1] + } + ] + ms = [ createTriggerMigration 1 bankTrigger2 ] + triggerStep msg $ do + assertException msg $ migrate ts ms + + do + -- Attempt to create the same triggers twice. Should fail with a DBException saying + -- that function already exists. + let msg = "database exception is raised if trigger is created twice" + ts = [ tableBankSchema1 { tblVersion = 3 + , tblTriggers = [bankTrigger1] + } + ] + ms = [ createTriggerMigration 1 bankTrigger1 + , createTriggerMigration 2 bankTrigger1 + ] + triggerStep msg $ do + assertDBException msg $ migrate ts ms + + do + let msg = "database exception is raised if triggers only differ in function name" + ts = [ tableBankSchema1 { tblVersion = 3 + , tblTriggers = [bankTrigger1, bankTrigger2] + } + ] + ms = [ createTriggerMigration 1 bankTrigger1 + , createTriggerMigration 2 bankTrigger2 + ] + triggerStep msg $ do + assertDBException msg $ migrate ts ms + + do + let msg = "successfully migrate two triggers" + ts = [ tableBankSchema1 { tblVersion = 3 + , tblTriggers = [bankTrigger1, bankTrigger2Proper] + } + ] + ms = [ createTriggerMigration 1 bankTrigger1 + , createTriggerMigration 2 bankTrigger2Proper + ] + triggerStep msg $ do + assertNoException msg $ migrate ts ms + verify [bankTrigger1, bankTrigger2Proper] True + + do + let msg = "database exception is raised if trigger's WHEN is syntactically incorrect" + trg = bankTrigger1 { triggerWhen = Just "WILL FAIL" } + ts = [ tableBankSchema1 { tblVersion = 2 + , tblTriggers = [trg] + } + ] + ms = [ createTriggerMigration 1 trg ] + triggerStep msg $ do + assertDBException msg $ migrate ts ms + + do + let msg = "database exception is raised if trigger's WHEN uses undefined column" + trg = bankTrigger1 { triggerWhen = Just "NEW.foobar = 1" } + ts = [ tableBankSchema1 { tblVersion = 2 + , tblTriggers = [trg] + } + ] + ms = [ createTriggerMigration 1 trg ] + triggerStep msg $ do + assertDBException msg $ migrate ts ms + + do + -- This trigger is valid. However, the WHEN clause specified in triggerWhen is not + -- what gets returned from the database. The decompiled and normalized WHEN clause + -- from the database looks like this: + -- new.name <> 'foobar'::text + -- We simply assert an exception, which presumably comes from the migration framework, + -- while it should actually be a deeper check for just the differing WHEN + -- clauses. On the other hand, it's probably good enough as it is. + -- See the comment for 'getDBTriggers' in src/Database/PostgreSQL/PQTypes/Model/Trigger.hs. + let msg = "checkDatabase fails if WHEN clauses from database and code differ" + trg = bankTrigger1 { triggerWhen = Just "NEW.name != 'foobar'" } + ts = [ tableBankSchema1 { tblVersion = 2 + , tblTriggers = [trg] + } + ] + ms = [ createTriggerMigration 1 trg ] + triggerStep msg $ do + assertException msg $ migrate ts ms + + do + let msg = "successfully migrate trigger with valid WHEN" + trg = bankTrigger1 { triggerWhen = Just "new.name <> 'foobar'::text" } + ts = [ tableBankSchema1 { tblVersion = 2 + , tblTriggers = [trg] + } + ] + ms = [ createTriggerMigration 1 trg ] + triggerStep msg $ do + assertNoException msg $ migrate ts ms + verify [trg] True + + do + let msg = "successfully migrate trigger that is deferrable" + trg = bankTrigger1 { triggerDeferrable = True } + ts = [ tableBankSchema1 { tblVersion = 2 + , tblTriggers = [trg] + } + ] + ms = [ createTriggerMigration 1 trg ] + triggerStep msg $ do + assertNoException msg $ migrate ts ms + verify [trg] True + + do + let msg = "successfully migrate trigger that is deferrable and initially deferred" + trg = bankTrigger1 { triggerDeferrable = True + , triggerInitiallyDeferred = True + } + ts = [ tableBankSchema1 { tblVersion = 2 + , tblTriggers = [trg] + } + ] + ms = [ createTriggerMigration 1 trg ] + triggerStep msg $ do + assertNoException msg $ migrate ts ms + verify [trg] True + + do + let msg = "database exception is raised if trigger is initially deferred but not deferrable" + trg = bankTrigger1 { triggerDeferrable = False + , triggerInitiallyDeferred = True + } + ts = [ tableBankSchema1 { tblVersion = 2 + , tblTriggers = [trg] + } + ] + ms = [ createTriggerMigration 1 trg ] + triggerStep msg $ do + assertDBException msg $ migrate ts ms + + do + let msg = "database exception is raised if dropping trigger that does not exist" + trg = bankTrigger1 + ts = [ tableBankSchema1 { tblVersion = 2 + , tblTriggers = [trg] + } + ] + ms = [ dropTriggerMigration 1 trg ] + triggerStep msg $ do + assertDBException msg $ migrate ts ms + + do + let msg = "database exception is raised if dropping trigger function of which does not exist" + trg = bankTrigger2 + ts = [ tableBankSchema1 { tblVersion = 2 + , tblTriggers = [trg] + } + ] + ms = [ dropTriggerMigration 1 trg ] + triggerStep msg $ do + assertDBException msg $ migrate ts ms + + do + let msg = "successfully drop trigger" + trg = bankTrigger1 + ts = [ tableBankSchema1 { tblVersion = 3 + , tblTriggers = [] + } + ] + ms = [ createTriggerMigration 1 trg, dropTriggerMigration 2 trg ] + triggerStep msg $ do + assertNoException msg $ migrate ts ms + verify [trg] False + + do + let msg = "database exception is raised if dropping trigger twice" + trg = bankTrigger2 + ts = [ tableBankSchema1 { tblVersion = 3 + , tblTriggers = [trg] + } + ] + ms = [ dropTriggerMigration 1 trg, dropTriggerMigration 2 trg ] + triggerStep msg $ do + assertDBException msg $ migrate ts ms + + do + let msg = "successfully create trigger with multiple events" + trg = bankTrigger3 + ts = [ tableBankSchema1 { tblVersion = 2 + , tblTriggers = [trg] + } + ] + ms = [ createTriggerMigration 1 trg ] + triggerStep msg $ do + assertNoException msg $ migrate ts ms + verify [trg] True + + where + triggerStep msg rest = do + recreateTriggerDB + step msg + rest + + migrate tables migrations = do + migrateDatabase defaultExtrasOptions ["pgcrypto"] [] [] tables migrations + checkDatabase defaultExtrasOptions [] [] tables + + -- Verify that the given triggers are (not) present in the database. + verify :: (MonadIO m, MonadDB m, HasCallStack) => [Trigger] -> Bool -> m () + verify triggers present = do + dbTriggers <- getDBTriggers "bank" + let ok = and $ map (`elem` dbTriggers) triggers + err = "Triggers " <> (if present then "" else "not ") <> "present in the database." + trans = if present then id else not + liftIO . assertBool err $ trans ok + + triggerMigration :: MonadDB m => (Trigger -> m ()) -> Int -> Trigger -> Migration m + triggerMigration fn from trg = Migration + { mgrTableName = tblName tableBankSchema1 + , mgrFrom = fromIntegral from + , mgrAction = StandardMigration $ fn trg + } + + createTriggerMigration :: MonadDB m => Int -> Trigger -> Migration m + createTriggerMigration = triggerMigration createTrigger + + dropTriggerMigration :: MonadDB m => Int -> Trigger -> Migration m + dropTriggerMigration = triggerMigration dropTrigger + + recreateTriggerDB = do + runSQL_ "DROP TRIGGER IF EXISTS trg__bank__trigger_1 ON bank;" + runSQL_ "DROP TRIGGER IF EXISTS trg__bank__trigger_2 ON bank;" + runSQL_ "DROP FUNCTION IF EXISTS function_1;" + runSQL_ "DROP FUNCTION IF EXISTS function_2;" + runSQL_ "DROP TABLE IF EXISTS bank;" + runSQL_ "DELETE FROM table_versions WHERE name = 'bank'"; + migrate [tableBankSchema1] [createTableMigration tableBankSchema1] migrationTest1 :: ConnectionSourceM (LogT IO) -> TestTree migrationTest1 connSource = @@ -834,8 +1158,6 @@ migrationTest1 connSource = migrationTest1Body step - -- freshTestDB step - -- | Test for behaviour of 'checkDatabase' and 'checkDatabaseAllowUnknownObjects' migrationTest2 :: ConnectionSourceM (LogT IO) -> TestTree migrationTest2 connSource = @@ -957,6 +1279,13 @@ migrationTest4 connSource = freshTestDB step +-- | Test triggers. +triggerTests :: ConnectionSourceM (LogT IO) -> TestTree +triggerTests connSource = + testCaseSteps' "Trigger tests" connSource $ \step -> do + freshTestDB step + testTriggers step + eitherExc :: MonadCatch m => (SomeException -> m ()) -> (a -> m ()) -> m a -> m () eitherExc left right c = try c >>= either left right @@ -966,10 +1295,15 @@ assertNoException t c = eitherExc (const $ return ()) c assertException :: String -> TestM () -> TestM () -assertException t c = eitherExc +assertException t c = eitherExc (const $ return ()) (const $ liftIO $ assertFailure ("No exception thrown for: " ++ t)) c +assertDBException :: String -> TestM () -> TestM () +assertDBException t c = + try c >>= either (\DBException{} -> pure ()) + (const . liftIO . assertFailure $ "No DBException thrown for: " ++ t) + -- | A variant of testCaseSteps that works in TestM monad. testCaseSteps' :: TestName -> ConnectionSourceM (LogT IO) -> ((String -> TestM ()) -> TestM ()) @@ -994,6 +1328,7 @@ main = do , migrationTest2 connSource , migrationTest3 connSource , migrationTest4 connSource + , triggerTests connSource ] where ings =