Skip to content

Commit

Permalink
feat: E-matching procedure for the grind tactic (#6481)
Browse files Browse the repository at this point in the history
This PR implements E-matching for the (WIP) `grind` tactic. We still
need to finalize and internalize the new instances.
  • Loading branch information
leodemoura authored Dec 31, 2024
1 parent 32dc165 commit 2c87905
Show file tree
Hide file tree
Showing 15 changed files with 404 additions and 78 deletions.
23 changes: 14 additions & 9 deletions src/Lean/Meta/Tactic/Grind.lean
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,23 @@ import Lean.Meta.Tactic.Grind.EMatchTheorem

namespace Lean

/-! Trace options for `grind` users -/
builtin_initialize registerTraceClass `grind
builtin_initialize registerTraceClass `grind.eq
builtin_initialize registerTraceClass `grind.assert
builtin_initialize registerTraceClass `grind.eqc
builtin_initialize registerTraceClass `grind.internalize
builtin_initialize registerTraceClass `grind.ematch
builtin_initialize registerTraceClass `grind.ematch.pattern
builtin_initialize registerTraceClass `grind.ematch.instance
builtin_initialize registerTraceClass `grind.issues
builtin_initialize registerTraceClass `grind.add
builtin_initialize registerTraceClass `grind.pre
builtin_initialize registerTraceClass `grind.simp

/-! Trace options for `grind` developers -/
builtin_initialize registerTraceClass `grind.debug
builtin_initialize registerTraceClass `grind.debug.proofs
builtin_initialize registerTraceClass `grind.simp
builtin_initialize registerTraceClass `grind.congr
builtin_initialize registerTraceClass `grind.proof
builtin_initialize registerTraceClass `grind.proof.detail
builtin_initialize registerTraceClass `grind.pattern
builtin_initialize registerTraceClass `grind.internalize
builtin_initialize registerTraceClass `grind.debug.congr
builtin_initialize registerTraceClass `grind.debug.pre
builtin_initialize registerTraceClass `grind.debug.proof
builtin_initialize registerTraceClass `grind.debug.proj

end Lean
5 changes: 2 additions & 3 deletions src/Lean/Meta/Tactic/Grind/Core.lean
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,13 @@ private partial def updateMT (root : Expr) : GoalM Unit := do
updateMT parent

private partial def addEqStep (lhs rhs proof : Expr) (isHEq : Bool) : GoalM Unit := do
trace[grind.eq] "{lhs} {if isHEq then "" else "="} {rhs}"
let lhsNode ← getENode lhs
let rhsNode ← getENode rhs
if isSameExpr lhsNode.root rhsNode.root then
-- `lhs` and `rhs` are already in the same equivalence class.
trace[grind.debug] "{← ppENodeRef lhs} and {← ppENodeRef rhs} are already in the same equivalence class"
return ()
trace[grind.eqc] "{lhs} {if isHEq then "" else "="} {rhs}"
let lhsRoot ← getENode lhsNode.root
let rhsRoot ← getENode rhsNode.root
let mut valueInconsistency := false
Expand Down Expand Up @@ -195,15 +195,14 @@ def addHEq (lhs rhs proof : Expr) : GoalM Unit := do
Adds a new `fact` justified by the given proof and using the given generation.
-/
def add (fact : Expr) (proof : Expr) (generation := 0) : GoalM Unit := do
trace[grind.add] "{proof} : {fact}"
trace[grind.assert] "{fact}"
if (← isInconsistent) then return ()
resetNewEqs
let_expr Not p := fact
| go fact false
go p true
where
go (p : Expr) (isNeg : Bool) : GoalM Unit := do
trace[grind.add] "isNeg: {isNeg}, {p}"
match_expr p with
| Eq _ lhs rhs => goEq p lhs rhs isNeg false
| HEq _ lhs _ rhs => goEq p lhs rhs isNeg true
Expand Down
241 changes: 241 additions & 0 deletions src/Lean/Meta/Tactic/Grind/EMatch.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Meta.Tactic.Grind.Types
import Lean.Meta.Tactic.Grind.Internalize

namespace Lean.Meta.Grind

/-- Returns maximum term generation that is considered during ematching -/
private def getMaxGeneration : GoalM Nat := do
return 10000 -- TODO

/-- Returns `true` if the maximum number of instances has been reached. -/
private def checkMaxInstancesExceeded : GoalM Bool := do
return false -- TODO

namespace EMatch
/-! This module implements a simple E-matching procedure as a backtracking search. -/

/-- We represent an `E-matching` problem as a list of constraints. -/
inductive Cnstr where
| /-- Matches pattern `pat` with term `e` -/
«match» (pat : Expr) (e : Expr)
| /-- This constraint is used to encode multi-patterns. -/
«continue» (pat : Expr)
deriving Inhabited

/--
Internal "marker" for representing unassigned elemens in the `assignment` field.
This is a small hack to avoid one extra level of indirection by using `Option Expr` at `assignment`.
-/
private def unassigned : Expr := mkConst (Name.mkSimple "[grind_unassigned]")

private def assignmentToMessageData (assignment : Array Expr) : Array MessageData :=
assignment.reverse.map fun e =>
if isSameExpr e unassigned then m!"_" else m!"{e}"

/--
Choice point for the backtracking search.
The state of the procedure contains a stack of choices.
-/
structure Choice where
/-- Contraints to be processed. -/
cnstrs : List Cnstr
/-- Maximum term generation found so far. -/
gen : Nat
/-- Partial assignment so far. Recall that pattern variables are encoded as de-Bruijn variables. -/
assignment : Array Expr
deriving Inhabited

/-- Theorem instances found so far. We only internalize them after we complete a full round of E-matching. -/
structure TheoremInstance where
prop : Expr
proof : Expr
generation : Nat
deriving Inhabited

/-- Context for the E-matching monad. -/
structure Context where
/-- `useMT` is `true` if we are using the mod-time optimization. It is always set to false for new `EMatchTheorem`s. -/
useMT : Bool := true
/-- `EMatchTheorem` being processed. -/
thm : EMatchTheorem := default
deriving Inhabited

/-- State for the E-matching monad -/
structure State where
/-- Choices that still have to be processed. -/
choiceStack : List Choice := []
newInstances : PArray TheoremInstance := {}
deriving Inhabited

abbrev M := ReaderT Context $ StateRefT State GoalM

def M.run' (x : M α) : GoalM α :=
x {} |>.run' {}

/--
Assigns `bidx := e` in `c`. If `bidx` is already assigned in `c`, we check whether
`e` and `c.assignment[bidx]` are in the same equivalence class.
This function assumes `bidx < c.assignment.size`.
Recall that we initialize the assignment array with the number of theorem parameters.
-/
private def assign? (c : Choice) (bidx : Nat) (e : Expr) : OptionT GoalM Choice := do
if h : bidx < c.assignment.size then
let v := c.assignment[bidx]
if isSameExpr v unassigned then
return { c with assignment := c.assignment.set bidx e }
else
guard (← isEqv v e)
return c
else
-- `Choice` was not properly initialized
unreachable!

/--
Returns `true` if the function `pFn` of a pattern is equivalent to the function `eFn`.
Recall that we ignore universe levels in patterns.
-/
private def eqvFunctions (pFn eFn : Expr) : Bool :=
(pFn.isFVar && pFn == eFn)
|| (pFn.isConst && eFn.isConstOf pFn.constName!)

/--
Matches arguments of pattern `p` with term `e`. Returns `some` if successful,
and `none` otherwise. It may update `c`s assignment and list of contraints to be
processed.
-/
private partial def matchArgs? (c : Choice) (p : Expr) (e : Expr) : OptionT GoalM Choice := do
if !p.isApp then return c -- Done
let pArg := p.appArg!
let eArg := e.appArg!
let goFn c := matchArgs? c p.appFn! e.appFn!
if isPatternDontCare pArg then
goFn c
else if pArg.isBVar then
goFn (← assign? c pArg.bvarIdx! eArg)
else if let some pArg := groundPattern? pArg then
guard (← isEqv pArg eArg)
goFn c
else
goFn { c with cnstrs := .match pArg eArg :: c.cnstrs }

/--
Matches pattern `p` with term `e` with respect to choice `c`.
We traverse the equivalence class of `e` looking for applications compatible with `p`.
For each candidate application, we match the arguments and may update `c`s assignments and contraints.
We add the updated choices to the choice stack.
-/
private partial def processMatch (c : Choice) (p : Expr) (e : Expr) : M Unit := do
let maxGeneration ← getMaxGeneration
let pFn := p.getAppFn
let numArgs := p.getAppNumArgs
let mut curr := e
repeat
let n ← getENode curr
if n.generation <= maxGeneration
-- uses heterogeneous equality or is the root of its congruence class
&& (n.heqProofs || isSameExpr curr n.cgRoot)
&& eqvFunctions pFn curr.getAppFn
&& curr.getAppNumArgs == numArgs then
if let some c ← matchArgs? c p curr |>.run then
let gen := n.generation
let c := { c with gen := Nat.max gen c.gen }
modify fun s => { s with choiceStack := c :: s.choiceStack }
curr ← getNext curr
if isSameExpr curr e then break

/-- Processes `continue` contraint used to implement multi-patterns. -/
private def processContinue (c : Choice) (p : Expr) : M Unit := do
let some apps := (← getThe Goal).appMap.find? p.toHeadIndex
| return ()
let maxGeneration ← getMaxGeneration
for app in apps do
let n ← getENode app
if n.generation <= maxGeneration
&& (n.heqProofs || isSameExpr n.cgRoot app) then
if let some c ← matchArgs? c p app |>.run then
let gen := n.generation
let c := { c with gen := Nat.max gen c.gen }
modify fun s => { s with choiceStack := c :: s.choiceStack }

private partial def instantiateTheorem (c : Choice) : M Unit := do
trace[grind.ematch.instance] "{(← read).thm.origin.key} : {assignmentToMessageData c.assignment}"
-- TODO
return ()

/-- Process choice stack until we don't have more choices to be processed. -/
private partial def processChoices : M Unit := do
unless (← get).choiceStack.isEmpty do
let c ← modifyGet fun s : State => (s.choiceStack.head!, { s with choiceStack := s.choiceStack.tail! })
match c.cnstrs with
| [] => instantiateTheorem c
| .match p e :: cnstrs => processMatch { c with cnstrs } p e
| .continue p :: cnstrs => processContinue { c with cnstrs } p
processChoices

private def main (p : Expr) (cnstrs : List Cnstr) : M Unit := do
let some apps := (← getThe Goal).appMap.find? p.toHeadIndex
| return ()
let numParams := (← read).thm.numParams
let assignment := mkArray numParams unassigned
let useMT := (← read).useMT
let gmt := (← getThe Goal).gmt
for app in apps do
let n ← getENode app
if (n.heqProofs || isSameExpr n.cgRoot app) &&
(!useMT || n.mt == gmt) then
if let some c ← matchArgs? { cnstrs, assignment, gen := n.generation } p app |>.run then
modify fun s => { s with choiceStack := [c] }
processChoices

def ematchTheorem (thm : EMatchTheorem) : M Unit := do
withReader (fun ctx => { ctx with thm }) do
let ps := thm.patterns
match ps, (← read).useMT with
| [p], _ => main p []
| p::ps, false => main p (ps.map (.continue ·))
| _::_, true => tryAll ps []
| _, _ => unreachable!
where
/--
When using the mod-time optimization with multi-patterns,
we must start ematching at each different pattern. That is,
if we have `[p₁, p₂, p₃]`, we must execute
- `main p₁ [.continue p₂, .continue p₃]`
- `main p₂ [.continue p₁, .continue p₃]`
- `main p₃ [.continue p₁, .continue p₂]`
-/
tryAll (ps : List Expr) (cs : List Cnstr) : M Unit := do
match ps with
| [] => return ()
| p::ps =>
main p (cs.reverse ++ (ps.map (.continue ·)))
tryAll ps (.continue p :: cs)

def ematchTheorems (thms : PArray EMatchTheorem) : M Unit := do
thms.forM ematchTheorem

def internalizeNewInstances : M Unit := do
-- TODO
return ()

end EMatch

open EMatch

/-- Performs one round of E-matching, and internalizes new instances. -/
def ematch : GoalM Unit := do
let go (thms newThms : PArray EMatchTheorem) : EMatch.M Unit := do
withReader (fun ctx => { ctx with useMT := true }) <| ematchTheorems thms
withReader (fun ctx => { ctx with useMT := false }) <| ematchTheorems newThms
internalizeNewInstances
unless (← checkMaxInstancesExceeded) do
go (← get).thms (← get).newThms |>.run'
modify fun s => { s with thms := s.thms ++ s.newThms, newThms := {}, gmt := s.gmt + 1 }

end Lean.Meta.Grind
31 changes: 18 additions & 13 deletions src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean
Original file line number Diff line number Diff line change
Expand Up @@ -125,15 +125,21 @@ private def getPatternFn? (pattern : Expr) : Option Expr :=
| f@(.fvar _) => some f
| _ => none

private structure PatternFunInfo where
instImplicitMask : Array Bool
typeMask : Array Bool
/--
Returns a bit-mask `mask` s.t. `mask[i]` is true if the the corresponding argument is
- a type or type former, or
- a proof, or
- an instance implicit argument
private def getPatternFunInfo (f : Expr) (numArgs : Nat) : MetaM PatternFunInfo := do
When `mask[i]`, we say the corresponding argument is a "support" argument.
-/
private def getPatternFunMask (f : Expr) (numArgs : Nat) : MetaM (Array Bool) := do
forallBoundedTelescope (← inferType f) numArgs fun xs _ => do
let typeMask ← xs.mapM fun x => isTypeFormer x
let instImplicitMask ← xs.mapM fun x => return (← x.fvarId!.getDecl).binderInfo matches .instImplicit
return { typeMask, instImplicitMask }
xs.mapM fun x => do
if (← isTypeFormer x <||> isProof x) then
return true
else
return (← x.fvarId!.getDecl).binderInfo matches .instImplicit

private partial def go (pattern : Expr) (root := false) : M Expr := do
if root && !pattern.hasLooseBVars then
Expand All @@ -143,25 +149,24 @@ private partial def go (pattern : Expr) (root := false) : M Expr := do
assert! f.isConst || f.isFVar
saveSymbol f.toHeadIndex
let mut args := pattern.getAppArgs
let { instImplicitMask, typeMask } ← getPatternFunInfo f args.size
let supportMask ← getPatternFunMask f args.size
for i in [:args.size] do
let arg := args[i]!
let isType := typeMask[i]?.getD false
let isInstImplicit := instImplicitMask[i]?.getD false
let isSupport := supportMask[i]?.getD false
let arg ← if !arg.hasLooseBVars then
if arg.hasMVar then
pure dontCare
else
pure <| mkGroundPattern arg
else match arg with
| .bvar idx =>
if (isType || isInstImplicit) && (← foundBVar idx) then
if isSupport && (← foundBVar idx) then
pure dontCare
else
saveBVar idx
pure arg
| _ =>
if isType || isInstImplicit then
if isSupport then
pure dontCare
else if let some _ := getPatternFn? arg then
go arg
Expand Down Expand Up @@ -305,7 +310,7 @@ def addEMatchTheorem (declName : Name) (numParams : Nat) (patterns : List Expr)
let proof := mkConst declName us
let (patterns, symbols, bvarFound) ← NormalizePattern.main patterns
assert! symbols.all fun s => s matches .const _
trace[grind.pattern] "{declName}: {patterns.map ppPattern}"
trace[grind.ematch.pattern] "{declName}: {patterns.map ppPattern}"
if let .missing pos ← checkCoverage proof numParams bvarFound then
let pats : MessageData := m!"{patterns.map ppPattern}"
throwError "invalid pattern(s) for `{declName}`{indentD pats}\nthe following theorem parameters cannot be instantiated:{indentD (← ppParamsAt proof numParams pos)}"
Expand Down
6 changes: 3 additions & 3 deletions src/Lean/Meta/Tactic/Grind/Internalize.lean
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def addCongrTable (e : Expr) : GoalM Unit := do
unless (← hasSameType f g) do
trace[grind.issues] "found congruence between{indentExpr e}\nand{indentExpr e'}\nbut functions have different types"
return ()
trace[grind.congr] "{e} = {e'}"
trace[grind.debug.congr] "{e} = {e'}"
pushEqHEq e e' congrPlaceholderProof
let node ← getENode e
setENode e { node with cgRoot := e' }
Expand Down Expand Up @@ -59,11 +59,11 @@ private partial def activateTheoremPatterns (fName : Name) (generation : Nat) :
let thm := { thm with symbols }
match symbols with
| [] =>
trace[grind.pattern] "activated `{thm.origin.key}`"
let thm := { thm with patterns := (← thm.patterns.mapM (internalizePattern · generation)) }
trace[grind.ematch] "activated `{thm.origin.key}`, {thm.patterns.map ppPattern}"
modify fun s => { s with newThms := s.newThms.push thm }
| _ =>
trace[grind.pattern] "reinsert `{thm.origin.key}`"
trace[grind.ematch] "reinsert `{thm.origin.key}`"
modify fun s => { s with thmMap := s.thmMap.insert thm }

partial def internalize (e : Expr) (generation : Nat) : GoalM Unit := do
Expand Down
Loading

0 comments on commit 2c87905

Please sign in to comment.