Skip to content

Commit

Permalink
feat: async modes for environment access
Browse files Browse the repository at this point in the history
  • Loading branch information
Kha committed Jan 31, 2025
1 parent 6865329 commit d7ed395
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 23 deletions.
166 changes: 146 additions & 20 deletions src/Lean/Environment.lean
Original file line number Diff line number Diff line change
Expand Up @@ -404,10 +404,10 @@ def AsyncConsts.add (aconsts : AsyncConsts) (aconst : AsyncConst) : AsyncConsts
def AsyncConsts.find? (aconsts : AsyncConsts) (declName : Name) : Option AsyncConst :=
aconsts.map.find? declName

/-- Checks whether the name of any constant in the collection is a prefix of `declName`. -/
def AsyncConsts.hasPrefix (aconsts : AsyncConsts) (declName : Name) : Bool :=
/-- Finds the constant in the collection that is a prefix of `declName`, if any. -/
def AsyncConsts.findPrefix? (aconsts : AsyncConsts) (declName : Name) : Option AsyncConst :=
-- as macro scopes are a strict suffix,
aconsts.normalizedTrie.findLongestPrefix? (privateToUserName declName.eraseMacroScopes) |>.isSome
aconsts.normalizedTrie.findLongestPrefix? (privateToUserName declName.eraseMacroScopes)

/--
Elaboration-specific extension of `Kernel.Environment` that adds tracking of asynchronously
Expand Down Expand Up @@ -463,6 +463,18 @@ private def modifyCheckedAsync (env : Environment) (f : Kernel.Environment → K
private def setCheckedSync (env : Environment) (newChecked : Kernel.Environment) : Environment :=
{ env with checked := .pure newChecked, checkedWithoutAsync := newChecked }

/--
Checks whether the given declaration name may potentially added, or have been added, to the current
environment branch, which is the case either if this is the main branch or if the declaration name
is a suffix (modulo privacy and hygiene information) of the top-level declaration name for which
this branch was created.
This function should always be checked before modifying an `AsyncMode.async` environment extension
to ensure `findStateAsync` will be able to find the modification from other branches.
-/
def asyncMayContain (env : Environment) (declName : Name) : Bool :=
env.asyncCtx?.all (·.mayContain declName)

@[extern "lean_elab_add_decl"]
private opaque addDeclCheck (env : Environment) (maxHeartbeats : USize) (decl : @& Declaration)
(cancelTk? : @& Option IO.CancelToken) : Except Kernel.Exception Environment
Expand Down Expand Up @@ -515,7 +527,7 @@ def addExtraName (env : Environment) (name : Name) : Environment :=

/-- Find base case: name did not match any asynchronous declaration. -/
private def findNoAsync (env : Environment) (n : Name) : Option ConstantInfo := do
if env.asyncConsts.hasPrefix n then
if let some _ := env.asyncConsts.findPrefix? n then
-- Constant generated in a different environment branch: wait for final kernel environment. Rare
-- case when only proofs are elaborated asynchronously as they are rarely inspected. Could be
-- optimized in the future by having the elaboration thread publish an (incremental?) map of
Expand Down Expand Up @@ -756,13 +768,76 @@ def instantiateValueLevelParams! (c : ConstantInfo) (ls : List Level) : Expr :=

end ConstantInfo

/--
Async access mode for environment extensions used in `EnvironmentExtension.get/set/modifyState`.
Depending on their specific uses, extensions may opt out of the strict `sync` access mode in order
to avoid blocking parallel elaboration and/or to optimize accesses. The access mode is set at
environment extension registration time but can be overriden at `EnvironmentExtension.getState` in
order to weaken it for specific accesses.
In all modes, the state stored into the `.olean` file for persistent environment extensions is the
result of `getState` called on the main environment branch at the end of the file, i.e. it
encompasses all modifications for all modes but `local`.
-/
inductive EnvExtension.AsyncMode where
/--
Default access mode, writing and reading the extension state to/from the full `checked`
environment. This mode ensures the observed state is identical independently of whether or how
parallel elaboration is used but `getState` will block on all prior environment branches by
waiting for `checked`. `setState` and `modifyState` do not block.
While a safe default, any extension that reasonably could be used in parallel elaboration contexts
should opt for a weaker mode to avoid blocking unless there is no way to access the correct state
without waiting for all prior environment branches, in which case its data management should be
restructured if at all possible.
-/
| sync
/--
Accesses only the state of the current environment branch. Modifications on other branches are not
visible and are ultimately discarded except for the main branch. Provides the fastest accessors,
will never block.
This mode is particularly suitable for extensions where state does not escape from lexical scopes
even without parallelism, e.g. `ScopedEnvExtension`s when setting local entries.
-/
| local
/--
Like `local` but panics when trying to modify the state on anything but the main environment
branch. For extensions that fulfill this requirement, all modes functionally coincide but this
is the safest and most efficient choice in that case, preventing accidental misuse.
This mode is suitable for extensions that are modified only at the command elaboration level
before any environment forks in the command, and in particular for extensions that are modified
only at the very beginning of the file.
-/
| mainOnly
/--
Accumulates modifications in the `checked` environment like `sync`, but `getState` will panic
instead of blocking. Instead `findStateAsync` should be used, which will access the state of the
environment branch corresponding to the passed declaration name, if any, or otherwise the state
of the current branch. In other words, at most one environment branch will be blocked on instead
of all prior branches. The local state can still be accessed by calling `getState` with mode
`local` explicitly.
This mode is suitable for extensions with map-like state where the key uniquely identifies the
top-level declaration where it could have been set, e.g. because the key on modification is always
the surrounding declaration's name. Any calls to `modifyState`/`setState` should assert
`asyncMayContain` with that key to ensure state is never accidentally stored in a branch where it
cannot be found by `findStateAsync`. In particular, this mode is closest to how the environment's
own constant map works which asserts the same predicate on modification and provides `findAsync?`
for block-avoiding access.
-/
| async
deriving Inhabited

/--
Environment extension, can only be generated by `registerEnvExtension` that allocates a unique index
for this extension into each environment's extension state's array.
-/
structure EnvExtension (σ : Type) where private mk ::
idx : Nat
mkInitial : IO σ
asyncMode : EnvExtension.AsyncMode
deriving Inhabited

namespace EnvExtension
Expand Down Expand Up @@ -817,23 +892,54 @@ def mkInitialExtStates : IO (Array EnvExtensionState) := do
let exts ← envExtensionsRef.get
exts.mapM fun ext => ext.mkInitial

-- TODO: store extension state in `checked`

def setState {σ : Type} (ext : EnvExtension σ) (env : Environment) (s : σ) : Environment :=
-- safety: `ext`'s constructor is private, so we can assume the entry at `ext.idx` is of type `σ`
{ env with checkedWithoutAsync.extensions := unsafe ext.setStateImpl env.checkedWithoutAsync.extensions s }
/--
Applies the given function to the extension state. See `AsyncMode` for details on how modifications
from different environment branches are reconciled.
Note that in modes `sync` and `async`, `f` will be called twice, on the local and on the `checked`
state.
-/
def modifyState {σ : Type} (ext : EnvExtension σ) (env : Environment) (f : σ → σ) : Environment :=
-- safety: `ext`'s constructor is private, so we can assume the entry at `ext.idx` is of type `σ`
{ env with checkedWithoutAsync.extensions := unsafe ext.modifyStateImpl env.checkedWithoutAsync.extensions f }
match ext.asyncMode with
| .mainOnly =>
if let some asyncCtx := env.asyncCtx? then
let _ : Inhabited Environment := ⟨env⟩
panic! s!"Environment.modifyState: environment extension is marked as `mainOnly` but used in \
async context '{asyncCtx.declPrefix}'"
else
{ env with checkedWithoutAsync.extensions := unsafe ext.modifyStateImpl env.checkedWithoutAsync.extensions f }
| .local =>
{ env with checkedWithoutAsync.extensions := unsafe ext.modifyStateImpl env.checkedWithoutAsync.extensions f }
| _ =>
env.modifyCheckedAsync fun env =>
{ env with extensions := unsafe ext.modifyStateImpl env.extensions f }

/--
Sets the extension state to the given value. See `AsyncMode` for details on how modifications from
different environment branches are reconciled.
-/
def setState {σ : Type} (ext : EnvExtension σ) (env : Environment) (s : σ) : Environment :=
inline <| modifyState ext env fun _ => s

-- `unsafe` fails to infer `Nonempty` here
unsafe def getStateUnsafe {σ : Type} [Inhabited σ] (ext : EnvExtension σ) (env : Environment) : σ :=
private unsafe def getStateUnsafe {σ : Type} [Inhabited σ] (ext : EnvExtension σ)
(env : Environment) (asyncMode := ext.asyncMode) : σ :=
-- safety: `ext`'s constructor is private, so we can assume the entry at `ext.idx` is of type `σ`
ext.getStateImpl env.checkedWithoutAsync.extensions
match asyncMode with
| .sync => ext.getStateImpl env.checked.get.extensions
| .async => panic! "EnvExtension.getState: called on `async` extension, use `findStateAsync` \
instead or pass `(asyncMode := .local)` to explicitly access local state"
| _ => ext.getStateImpl env.checkedWithoutAsync.extensions

/--
Returns the current extension state. See `AsyncMode` for details on how modifications from
different environment branches are reconciled. Panics if the extension is marked as `async`; see its
documentation for more details.
-/
@[implemented_by getStateUnsafe]
opaque getState {σ : Type} [Inhabited σ] (ext : EnvExtension σ) (env : Environment) : σ
opaque getState {σ : Type} [Inhabited σ] (ext : EnvExtension σ) (env : Environment)
(asyncMode := ext.asyncMode) : σ

end EnvExtension

Expand All @@ -844,15 +950,13 @@ end EnvExtension
Note that by default, extension state is *not* stored in .olean files and will not propagate across `import`s.
For that, you need to register a persistent environment extension. -/
def registerEnvExtension {σ : Type} (mkInitial : IO σ) : IO (EnvExtension σ) := do
def registerEnvExtension {σ : Type} (mkInitial : IO σ)
(asyncMode : EnvExtension.AsyncMode := .mainOnly) : IO (EnvExtension σ) := do
unless (← initializing) do
throw (IO.userError "failed to register environment, extensions can only be registered during initialization")
let exts ← EnvExtension.envExtensionsRef.get
let idx := exts.size
let ext : EnvExtension σ := {
idx := idx,
mkInitial := mkInitial,
}
let ext : EnvExtension σ := { idx, mkInitial, asyncMode }
-- safety: `EnvExtensionState` is opaque, so we can upcast to it
EnvExtension.envExtensionsRef.modify fun exts => exts.push (unsafe unsafeCast ext)
pure ext
Expand Down Expand Up @@ -953,7 +1057,8 @@ instance {α β σ} [Inhabited σ] : Inhabited (PersistentEnvExtension α β σ)
namespace PersistentEnvExtension

def getModuleEntries {α β σ : Type} [Inhabited σ] (ext : PersistentEnvExtension α β σ) (env : Environment) (m : ModuleIdx) : Array α :=
(ext.toEnvExtension.getState env).importedEntries.get! m
-- `importedEntries` is identical on all environment branches, so `local` is always sufficient
(ext.toEnvExtension.getState (asyncMode := .local) env).importedEntries.get! m

def addEntry {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : Environment) (b : β) : Environment :=
ext.toEnvExtension.modifyState env fun s =>
Expand All @@ -972,6 +1077,24 @@ def setState {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : En
def modifyState {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : Environment) (f : σ → σ) : Environment :=
ext.toEnvExtension.modifyState env fun ps => { ps with state := f (ps.state) }

-- `unsafe` fails to infer `Nonempty` here
private unsafe def findStateAsyncUnsafe {α β σ : Type} [Inhabited σ]
(ext : PersistentEnvExtension α β σ) (env : Environment) (declPrefix : Name) : σ :=
-- safety: `ext`'s constructor is private, so we can assume the entry at `ext.idx` is of type `σ`
if let some { exts? := some exts, .. } := env.asyncConsts.findPrefix? declPrefix then
ext.toEnvExtension.getStateImpl exts.get |>.state
else
ext.toEnvExtension.getStateImpl env.checkedWithoutAsync.extensions |>.state

/--
Returns the final extension state on the environment branch corresponding to the passed declaration
name, if any, or otherwise the state on the current branch. In other words, at most one environment
branch will be blocked on.
-/
@[implemented_by findStateAsyncUnsafe]
opaque findStateAsync {α β σ : Type} [Inhabited σ] (ext : PersistentEnvExtension α β σ)
(env : Environment) (declPrefix : Name) : σ

end PersistentEnvExtension

builtin_initialize persistentEnvExtensionsRef : IO.Ref (Array (PersistentEnvExtension EnvExtensionEntry EnvExtensionEntry EnvExtensionState)) ← IO.mkRef #[]
Expand All @@ -983,11 +1106,12 @@ structure PersistentEnvExtensionDescr (α β σ : Type) where
addEntryFn : σ → β → σ
exportEntriesFn : σ → Array α
statsFn : σ → Format := fun _ => Format.nil
asyncMode : EnvExtension.AsyncMode := .mainOnly

unsafe def registerPersistentEnvExtensionUnsafe {α β σ : Type} [Inhabited σ] (descr : PersistentEnvExtensionDescr α β σ) : IO (PersistentEnvExtension α β σ) := do
let pExts ← persistentEnvExtensionsRef.get
if pExts.any (fun ext => ext.name == descr.name) then throw (IO.userError s!"invalid environment extension, '{descr.name}' has already been used")
let ext ← registerEnvExtension do
let ext ← registerEnvExtension (asyncMode := descr.asyncMode) do
let initial ← descr.mkInitial
let s : PersistentEnvExtensionState α σ := {
importedEntries := #[],
Expand Down Expand Up @@ -1019,6 +1143,7 @@ structure SimplePersistentEnvExtensionDescr (α σ : Type) where
addEntryFn : σ → α → σ
addImportedFn : Array (Array α) → σ
toArrayFn : List α → Array α := fun es => es.toArray
asyncMode : EnvExtension.AsyncMode := .mainOnly

def registerSimplePersistentEnvExtension {α σ : Type} [Inhabited σ] (descr : SimplePersistentEnvExtensionDescr α σ) : IO (SimplePersistentEnvExtension α σ) :=
registerPersistentEnvExtension {
Expand All @@ -1029,6 +1154,7 @@ def registerSimplePersistentEnvExtension {α σ : Type} [Inhabited σ] (descr :
| (entries, s) => (e::entries, descr.addEntryFn s e),
exportEntriesFn := fun s => descr.toArrayFn s.1.reverse,
statsFn := fun s => format "number of local entries: " ++ format s.1.length
asyncMode := descr.asyncMode
}

namespace SimplePersistentEnvExtension
Expand Down
4 changes: 2 additions & 2 deletions src/Lean/Meta/LazyDiscrTree.lean
Original file line number Diff line number Diff line change
Expand Up @@ -980,8 +980,8 @@ def findImportMatches
let ngen ← getNGen
let (cNGen, ngen) := ngen.mkChild
setNGen ngen
let dummy : IO.Ref (Option (LazyDiscrTree α)) ← IO.mkRef none
let ref := @EnvExtension.getState _ ⟨dummy⟩ ext (←getEnv)
let _ : Inhabited (IO.Ref (Option (LazyDiscrTree α))) := ⟨← IO.mkRef none
let ref := ext.getState (←getEnv)
let importTree ← (←ref.get).getDM $ do
profileitM Exception "lazy discriminator import initialization" (←getOptions) $ do
let t ← createImportedDiscrTree (createTreeCtx cctx) cNGen (←getEnv) addEntry
Expand Down
3 changes: 2 additions & 1 deletion src/Lean/Parser/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1705,7 +1705,8 @@ builtin_initialize categoryParserFnRef : IO.Ref CategoryParserFn ← IO.mkRef fu
builtin_initialize categoryParserFnExtension : EnvExtension CategoryParserFn ← registerEnvExtension $ categoryParserFnRef.get

def categoryParserFn (catName : Name) : ParserFn := fun ctx s =>
categoryParserFnExtension.getState ctx.env catName ctx s
let fn := categoryParserFnExtension.getState ctx.env
fn catName ctx s

def categoryParser (catName : Name) (prec : Nat) : Parser where
fn := adaptCacheableContextFn ({ · with prec }) (withCacheFn catName (categoryParserFn catName))
Expand Down

0 comments on commit d7ed395

Please sign in to comment.