Skip to content

Commit

Permalink
feat: apply all function settings as transaction-scoped settings
Browse files Browse the repository at this point in the history
  • Loading branch information
taimoorzaeem committed Feb 5, 2024
1 parent 45cabac commit f354c88
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 39 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ This project adheres to [Semantic Versioning](http://semver.org/).
### Added

- #2887, Add Preference `max-affected` to limit affected resources - @taimoorzaeem
- #3061, Apply all function settings as transaction-scoped settings - @taimoorzaeem

### Fixed

Expand Down
6 changes: 3 additions & 3 deletions src/PostgREST/App.hs
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ handleRequest AuthResult{..} conf appState authenticated prepared pgVer apiReq@A

(ActionInvoke invMethod, TargetProc identifier _) -> do
(planTime', cPlan) <- withTiming $ liftEither $ Plan.callReadPlan identifier conf sCache apiReq invMethod
(txTime', resultSet) <- withTiming $ runQuery (fromMaybe roleIsoLvl $ pdIsoLvl (Plan.crProc cPlan)) (pdTimeout $ Plan.crProc cPlan) (Plan.crTxMode cPlan) $ Query.invokeQuery (Plan.crProc cPlan) cPlan apiReq conf pgVer
(txTime', resultSet) <- withTiming $ runQuery (fromMaybe roleIsoLvl $ pdIsoLvl (Plan.crProc cPlan)) (pdFuncSetting $ Plan.crProc cPlan) (Plan.crTxMode cPlan) $ Query.invokeQuery (Plan.crProc cPlan) cPlan apiReq conf pgVer
(respTime', pgrst) <- withTiming $ liftEither $ Response.invokeResponse cPlan invMethod (Plan.crProc cPlan) apiReq resultSet
return $ pgrstResponse (ServerTiming jwtTime parseTime planTime' txTime' respTime') pgrst

Expand Down Expand Up @@ -230,9 +230,9 @@ handleRequest AuthResult{..} conf appState authenticated prepared pgVer apiReq@A
where
roleSettings = fromMaybe mempty (HM.lookup authRole $ configRoleSettings conf)
roleIsoLvl = HM.findWithDefault SQL.ReadCommitted authRole $ configRoleIsoLvl conf
runQuery isoLvl timeout mode query =
runQuery isoLvl funcSet mode query =
runDbHandler appState conf isoLvl mode authenticated prepared $ do
Query.setPgLocals conf authClaims authRole (HM.toList roleSettings) apiReq timeout
Query.setPgLocals conf authClaims authRole (HM.toList roleSettings) funcSet apiReq
Query.runPreReq conf
query

Expand Down
8 changes: 4 additions & 4 deletions src/PostgREST/Query.hs
Original file line number Diff line number Diff line change
Expand Up @@ -247,12 +247,12 @@ optionalRollback AppConfig{..} ApiRequest{iPreferences=Preferences{..}} = do

-- | Set transaction scoped settings
setPgLocals :: AppConfig -> KM.KeyMap JSON.Value -> BS.ByteString -> [(ByteString, ByteString)] ->
ApiRequest -> Maybe Text -> DbHandler ()
setPgLocals AppConfig{..} claims role roleSettings ApiRequest{..} tout = lift $
Maybe (Text,Text) -> ApiRequest -> DbHandler ()
setPgLocals AppConfig{..} claims role roleSettings funcSetting ApiRequest{..} = lift $
SQL.statement mempty $ SQL.dynamicallyParameterized
-- To ensure `GRANT SET ON PARAMETER <superuser_setting> TO authenticator` works, the role settings must be set before the impersonated role.
-- Otherwise the GRANT SET would have to be applied to the impersonated role. See https://github.com/PostgREST/postgrest/issues/3045
("select " <> intercalateSnippet ", " (searchPathSql : roleSettingsSql ++ roleSql ++ claimsSql ++ [methodSql, pathSql] ++ headersSql ++ cookiesSql ++ timezoneSql ++ timeoutSql ++ appSettingsSql))
("select " <> intercalateSnippet ", " (searchPathSql : roleSettingsSql ++ roleSql ++ claimsSql ++ [methodSql, pathSql] ++ headersSql ++ cookiesSql ++ timezoneSql ++ funcSettingSql ++ appSettingsSql))
HD.noResult configDbPreparedStatements
where
methodSql = setConfigWithConstantName ("request.method", iMethod)
Expand All @@ -264,7 +264,7 @@ setPgLocals AppConfig{..} claims role roleSettings ApiRequest{..} tout = lift $
roleSettingsSql = setConfigWithDynamicName <$> roleSettings
appSettingsSql = setConfigWithDynamicName <$> (join bimap toUtf8 <$> configAppSettings)
timezoneSql = maybe mempty (\(PreferTimezone tz) -> [setConfigWithConstantName ("timezone", tz)]) $ preferTimezone iPreferences
timeoutSql = maybe mempty ((\t -> [setConfigWithConstantName ("statement_timeout", t)]) . encodeUtf8) tout
funcSettingSql = maybe mempty (\(key,val) -> [setConfigWithDynamicName (encodeUtf8 key, encodeUtf8 val)]) funcSetting
searchPathSql =
let schemas = escapeIdentList (iSchema : configDbExtraSearchPath) in
setConfigWithConstantName ("search_path", schemas)
Expand Down
19 changes: 14 additions & 5 deletions src/PostgREST/SchemaCache.hs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import qualified Data.Aeson.Types as JSON
import qualified Data.HashMap.Strict as HM
import qualified Data.HashMap.Strict.InsOrd as HMI
import qualified Data.Set as S
import qualified Data.Text as T
import qualified Hasql.Decoders as HD
import qualified Hasql.Encoders as HE
import qualified Hasql.Statement as SQL
Expand All @@ -59,7 +60,8 @@ import PostgREST.SchemaCache.Relationship (Cardinality (..),
RelationshipsMap)
import PostgREST.SchemaCache.Representations (DataRepresentation (..),
RepresentationsMap)
import PostgREST.SchemaCache.Routine (FuncVolatility (..),
import PostgREST.SchemaCache.Routine (FuncSetting,
FuncVolatility (..),
MediaHandler (..),
MediaHandlerMap,
PgType (..),
Expand All @@ -74,7 +76,6 @@ import qualified PostgREST.MediaType as MediaType

import Protolude


data SchemaCache = SchemaCache
{ dbTables :: TablesMap
, dbRelationships :: RelationshipsMap
Expand Down Expand Up @@ -297,7 +298,7 @@ decodeFuncs =
<*> (parseVolatility <$> column HD.char)
<*> column HD.bool
<*> nullableColumn (toIsolationLevel <$> HD.text)
<*> nullableColumn HD.text
<*> nullableColumn (parseFuncSetting <$> HD.text) -- function setting

addKey :: Routine -> (QualifiedIdentifier, Routine)
addKey pd = (QualifiedIdentifier (pdSchema pd) (pdName pd), pd)
Expand All @@ -317,6 +318,12 @@ decodeFuncs =
| v == 's' = Stable
| otherwise = Volatile -- only 'v' can happen here

parseFuncSetting :: Text -> FuncSetting
parseFuncSetting txt = toTuple $ T.splitOn "=" txt
where
toTuple [x,y] = (x,y)
toTuple _ = ("error","parsing")

decodeRepresentations :: HD.Result RepresentationsMap
decodeRepresentations =
HM.fromList . map (\rep@DataRepresentation{drSourceType, drTargetType} -> ((drSourceType, drTargetType), rep)) <$> HD.rowList row
Expand Down Expand Up @@ -432,7 +439,10 @@ funcsSqlQuery pgVer = [q|
p.provolatile,
p.provariadic > 0 as hasvariadic,
lower((regexp_split_to_array((regexp_split_to_array(iso_config, '='))[2], ','))[1]) AS transaction_isolation_level,
lower((regexp_split_to_array((regexp_split_to_array(timeout_config, '='))[2], ','))[1]) AS statement_timeout
-- the proconfig is returned as text[], which returns something like
-- e.g "\NUL\NUL\DB4statement_timeout=1s", this is solved by using
-- selecting first element
(p.proconfig)[1] AS func_setting
FROM pg_proc p
LEFT JOIN arguments a ON a.oid = p.oid
JOIN pg_namespace pn ON pn.oid = p.pronamespace
Expand All @@ -442,7 +452,6 @@ funcsSqlQuery pgVer = [q|
LEFT JOIN pg_class comp ON comp.oid = t.typrelid
LEFT JOIN pg_description as d ON d.objoid = p.oid
LEFT JOIN LATERAL unnest(proconfig) iso_config ON iso_config like 'default_transaction_isolation%'
LEFT JOIN LATERAL unnest(proconfig) timeout_config ON timeout_config like 'statement_timeout%'
WHERE t.oid <> 'trigger'::regtype AND COALESCE(a.callable, true)
|] <> (if pgVer >= pgVersion110 then "AND prokind = 'f'" else "AND NOT (proisagg OR proiswindow)")

Expand Down
13 changes: 8 additions & 5 deletions src/PostgREST/SchemaCache/Routine.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ module PostgREST.SchemaCache.Routine
, Routine(..)
, RoutineParam(..)
, FuncVolatility(..)
, FuncSetting
, RoutineMap
, RetType(..)
, funcReturnsScalar
Expand Down Expand Up @@ -49,6 +50,8 @@ data FuncVolatility
| Immutable
deriving (Eq, Show, Ord, Generic, JSON.ToJSON)

type FuncSetting = (Text,Text)

data Routine = Function
{ pdSchema :: Schema
, pdName :: Text
Expand All @@ -58,12 +61,12 @@ data Routine = Function
, pdVolatility :: FuncVolatility
, pdHasVariadic :: Bool
, pdIsoLvl :: Maybe SQL.IsolationLevel
, pdTimeout :: Maybe Text
, pdFuncSetting :: Maybe FuncSetting
}
deriving (Eq, Show, Generic)
-- need to define JSON manually bc SQL.IsolationLevel doesn't have a JSON instance(and we can't define one for that type without getting a compiler error)
instance JSON.ToJSON Routine where
toJSON (Function sch nam desc params ret vol hasVar _ tout) = JSON.object
toJSON (Function sch nam desc params ret vol hasVar _ set) = JSON.object
[
"pdSchema" .= sch
, "pdName" .= nam
Expand All @@ -72,7 +75,7 @@ instance JSON.ToJSON Routine where
, "pdReturnType" .= JSON.toJSON ret
, "pdVolatility" .= JSON.toJSON vol
, "pdHasVariadic" .= JSON.toJSON hasVar
, "pdTimeout" .= tout
, "pdFuncSetting" .= JSON.toJSON set
]

data RoutineParam = RoutineParam
Expand All @@ -86,10 +89,10 @@ data RoutineParam = RoutineParam

-- Order by least number of params in the case of overloaded functions
instance Ord Routine where
Function schema1 name1 des1 prms1 rt1 vol1 hasVar1 iso1 tout1 `compare` Function schema2 name2 des2 prms2 rt2 vol2 hasVar2 iso2 tout2
Function schema1 name1 des1 prms1 rt1 vol1 hasVar1 iso1 set1 `compare` Function schema2 name2 des2 prms2 rt2 vol2 hasVar2 iso2 set2
| schema1 == schema2 && name1 == name2 && length prms1 < length prms2 = LT
| schema2 == schema2 && name1 == name2 && length prms1 > length prms2 = GT
| otherwise = (schema1, name1, des1, prms1, rt1, vol1, hasVar1, iso1, tout1) `compare` (schema2, name2, des2, prms2, rt2, vol2, hasVar2, iso2, tout2)
| otherwise = (schema1, name1, des1, prms1, rt1, vol1, hasVar1, iso1, set1) `compare` (schema2, name2, des2, prms2, rt2, vol2, hasVar2, iso2, set2)

-- | A map of all procs, all of which can be overloaded(one entry will have more than one Routine).
-- | It uses a HashMap for a faster lookup.
Expand Down
6 changes: 6 additions & 0 deletions test/io/fixtures.sql
Original file line number Diff line number Diff line change
Expand Up @@ -198,3 +198,9 @@ $$ language sql set statement_timeout = '4s';
create function get_postgres_version() returns int as $$
select current_setting('server_version_num')::int;
$$ language sql;

GRANT SET ON PARAMETER log_min_duration_sample TO postgrest_test_anonymous;

create or replace function log_min_duration_test() returns text as $$
select current_setting('log_min_duration_sample',false);
$$ language sql set log_min_duration_sample = '5s';
53 changes: 31 additions & 22 deletions test/io/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -1327,28 +1327,6 @@ def test_no_preflight_request_with_CORS_config_should_not_return_header(defaulte
assert "Access-Control-Allow-Origin" not in response.headers


def test_fail_with_3_sec_statement_and_1_sec_statement_timeout(defaultenv):
"statement that takes three seconds to execute should fail with one second timeout"

with run(env=defaultenv) as postgrest:
response = postgrest.session.post("/rpc/one_sec_timeout")

assert response.status_code == 500
assert (
response.text
== '{"code":"57014","details":null,"hint":null,"message":"canceling statement due to statement timeout"}'
)


def test_passes_with_3_sec_statement_and_4_sec_statement_timeout(defaultenv):
"statement that takes three seconds to execute should succeed with four second timeout"

with run(env=defaultenv) as postgrest:
response = postgrest.session.post("/rpc/four_sec_timeout")

assert response.status_code == 204


@pytest.mark.parametrize("level", ["crit", "error", "warn", "info"])
def test_db_error_logging_to_stderr(level, defaultenv, metapostgrest):
"verify that DB errors are logged to stderr"
Expand All @@ -1375,3 +1353,34 @@ def test_db_error_logging_to_stderr(level, defaultenv, metapostgrest):
else:
assert " 500 " in output[0]
assert "canceling statement due to statement timeout" in output[1]


def test_function_setting_statement_timeout_fails(defaultenv):
"statement that takes three seconds to execute should fail with one second timeout"

with run(env=defaultenv) as postgrest:
response = postgrest.session.post("/rpc/one_sec_timeout")

assert response.status_code == 500
assert (
response.text
== '{"code":"57014","details":null,"hint":null,"message":"canceling statement due to statement timeout"}'
)


def test_function_setting_statement_timeout_passes(defaultenv):
"statement that takes three seconds to execute should succeed with four second timeout"

with run(env=defaultenv) as postgrest:
response = postgrest.session.post("/rpc/four_sec_timeout")

assert response.status_code == 204


def test_function_setting_log_min_duration_sample(defaultenv):
"check function setting log_min_duration_sample is applied"

with run(env=defaultenv) as postgrest:
response = postgrest.session.post("/rpc/log_min_duration_test")

assert response.text == '"5s"'

0 comments on commit f354c88

Please sign in to comment.