Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add [grind ←=] attribute #6702

Merged
merged 1 commit into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/Init/Grind/Tactics.lean
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@ namespace Lean.Parser.Attr
syntax grindEq := "="
syntax grindEqBoth := atomic("_" "=" "_")
syntax grindEqRhs := atomic("=" "_")
syntax grindEqBwd := atomic("←" "=")
syntax grindBwd := "←"
syntax grindFwd := "→"

syntax grindThmMod := grindEqBoth <|> grindEqRhs <|> grindEq <|> grindBwd <|> grindFwd
syntax grindThmMod := grindEqBoth <|> grindEqRhs <|> grindEq <|> grindEqBwd <|> grindBwd <|> grindFwd

syntax (name := grind) "grind" (grindThmMod)? : attr

Expand Down
3 changes: 3 additions & 0 deletions src/Init/Grind/Util.lean
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ def doNotSimp {α : Sort u} (a : α) : α := a
/-- Gadget for representing offsets `t+k` in patterns. -/
def offset (a b : Nat) : Nat := a + b

/-- Gadget for representing `a = b` in patterns for backward propagation. -/
def eqBwdPattern (a b : α) : Prop := a = b

/--
Gadget for annotating the equalities in `match`-equations conclusions.
`_origin` is the term used to instantiate the `match`-equation using E-matching.
Expand Down
34 changes: 33 additions & 1 deletion src/Lean/Meta/Tactic/Grind/EMatch.lean
Original file line number Diff line number Diff line change
Expand Up @@ -310,12 +310,44 @@ private def main (p : Expr) (cnstrs : List Cnstr) : M Unit := do
modify fun s => { s with choiceStack := [c] }
processChoices

/--
Entry point for matching `lhs ←= rhs` patterns.
It traverses disequalities `a = b`, and tries to solve two matching problems:
1- match `lhs` with `a` and `rhs` with `b`
2- match `lhs` with `b` and `rhs` with `a`
-/
private def matchEqBwdPat (p : Expr) : M Unit := do
let_expr Grind.eqBwdPattern pα plhs prhs := p | return ()
let numParams := (← read).thm.numParams
let assignment := mkArray numParams unassigned
let useMT := (← read).useMT
let gmt := (← getThe Goal).gmt
let false ← getFalseExpr
let mut curr := false
repeat
if (← checkMaxInstancesExceeded) then return ()
let n ← getENode curr
if (n.heqProofs || n.isCongrRoot) &&
(!useMT || n.mt == gmt) then
let_expr Eq α lhs rhs := n.self | pure ()
if (← isDefEq α pα) then
let c₀ : Choice := { cnstrs := [], assignment, gen := n.generation }
let go (lhs rhs : Expr) : M Unit := do
let some c₁ ← matchArg? c₀ plhs lhs |>.run | return ()
let some c₂ ← matchArg? c₁ prhs rhs |>.run | return ()
modify fun s => { s with choiceStack := [c₂] }
processChoices
go lhs rhs
go rhs lhs
if isSameExpr n.next false then return ()
curr := n.next

def ematchTheorem (thm : EMatchTheorem) : M Unit := do
if (← checkMaxInstancesExceeded) then return ()
withReader (fun ctx => { ctx with thm }) do
let ps := thm.patterns
match ps, (← read).useMT with
| [p], _ => main p []
| [p], _ => if isEqBwdPattern p then matchEqBwdPat p else main p []
| p::ps, false => main p (ps.map (.continue ·))
| _::_, true => tryAll ps []
| _, _ => unreachable!
Expand Down
31 changes: 29 additions & 2 deletions src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,17 @@ def isOffsetPattern? (pat : Expr) : Option (Expr × Nat) := Id.run do
let .lit (.natVal k) := k | none
return some (pat, k)

def mkEqBwdPattern (u : List Level) (α : Expr) (lhs rhs : Expr) : Expr :=
mkApp3 (mkConst ``Grind.eqBwdPattern u) α lhs rhs

def isEqBwdPattern (e : Expr) : Bool :=
e.isAppOfArity ``Grind.eqBwdPattern 3

def isEqBwdPattern? (e : Expr) : Option (Expr × Expr) :=
let_expr Grind.eqBwdPattern _ lhs rhs := e
| none
some (lhs, rhs)

def preprocessPattern (pat : Expr) (normalizePattern := true) : MetaM Expr := do
let pat ← instantiateMVars pat
let pat ← unfoldReducible pat
Expand Down Expand Up @@ -314,7 +325,8 @@ private partial def go (pattern : Expr) (root := false) : M Expr := do
let some f := getPatternFn? pattern
| throwError "invalid pattern, (non-forbidden) application expected{indentExpr pattern}"
assert! f.isConst || f.isFVar
saveSymbol f.toHeadIndex
unless f.isConstOf ``Grind.eqBwdPattern do
saveSymbol f.toHeadIndex
let mut args := pattern.getAppArgs.toVector
let supportMask ← getPatternSupportMask f args.size
for h : i in [:args.size] do
Expand Down Expand Up @@ -481,6 +493,8 @@ Pattern variables are represented using de Bruijn indices.
-/
def mkEMatchTheoremCore (origin : Origin) (levelParams : Array Name) (numParams : Nat) (proof : Expr) (patterns : List Expr) : MetaM EMatchTheorem := do
let (patterns, symbols, bvarFound) ← NormalizePattern.main patterns
if symbols.isEmpty then
throwError "invalid pattern for `{← origin.pp}`{indentD (patterns.map ppPattern)}\nthe pattern does not contain constant symbols for indexing"
trace[grind.ematch.pattern] "{MessageData.ofConst proof}: {patterns.map ppPattern}"
if let .missing pos ← checkCoverage proof numParams bvarFound then
let pats : MessageData := m!"{patterns.map ppPattern}"
Expand Down Expand Up @@ -523,6 +537,14 @@ def mkEMatchEqTheoremCore (origin : Origin) (levelParams : Array Name) (proof :
return (xs.size, pats)
mkEMatchTheoremCore origin levelParams numParams proof patterns

def mkEMatchEqBwdTheoremCore (origin : Origin) (levelParams : Array Name) (proof : Expr) : MetaM EMatchTheorem := do
let (numParams, patterns) ← forallTelescopeReducing (← inferType proof) fun xs type => do
let_expr f@Eq α lhs rhs := type
| throwError "invalid E-matching `≠` theorem, conclusion must be an equality{indentExpr type}"
let pat ← preprocessPattern (mkEqBwdPattern f.constLevels! α lhs rhs)
return (xs.size, [pat.abstract xs])
mkEMatchTheoremCore origin levelParams numParams proof patterns

/--
Given theorem with name `declName` and type of the form `∀ (a_1 ... a_n), lhs = rhs`,
creates an E-matching pattern for it using `addEMatchTheorem n [lhs]`
Expand Down Expand Up @@ -552,13 +574,14 @@ def getEMatchTheorems : CoreM EMatchTheorems :=
return ematchTheoremsExt.getState (← getEnv)

inductive TheoremKind where
| eqLhs | eqRhs | eqBoth | fwd | bwd | default
| eqLhs | eqRhs | eqBoth | eqBwd | fwd | bwd | default
deriving Inhabited, BEq, Repr

private def TheoremKind.toAttribute : TheoremKind → String
| .eqLhs => "[grind =]"
| .eqRhs => "[grind =_]"
| .eqBoth => "[grind _=_]"
| .eqBwd => "[grind ←=]"
| .fwd => "[grind →]"
| .bwd => "[grind ←]"
| .default => "[grind]"
Expand All @@ -567,6 +590,7 @@ private def TheoremKind.explainFailure : TheoremKind → String
| .eqLhs => "failed to find pattern in the left-hand side of the theorem's conclusion"
| .eqRhs => "failed to find pattern in the right-hand side of the theorem's conclusion"
| .eqBoth => unreachable! -- eqBoth is a macro
| .eqBwd => "failed to use theorem's conclusion as a pattern"
| .fwd => "failed to find patterns in the antecedents of the theorem"
| .bwd => "failed to find patterns in the theorem's conclusion"
| .default => "failed to find patterns"
Expand Down Expand Up @@ -656,6 +680,8 @@ def mkEMatchTheoremWithKind? (origin : Origin) (levelParams : Array Name) (proof
return (← mkEMatchEqTheoremCore origin levelParams proof (normalizePattern := true) (useLhs := true))
else if kind == .eqRhs then
return (← mkEMatchEqTheoremCore origin levelParams proof (normalizePattern := true) (useLhs := false))
else if kind == .eqBwd then
return (← mkEMatchEqBwdTheoremCore origin levelParams proof)
let type ← inferType proof
forallTelescopeReducing type fun xs type => do
let searchPlaces ← match kind with
Expand Down Expand Up @@ -687,6 +713,7 @@ def getTheoremKindCore (stx : Syntax) : CoreM TheoremKind := do
| `(Parser.Attr.grindThmMod| ←) => return .bwd
| `(Parser.Attr.grindThmMod| =_) => return .eqRhs
| `(Parser.Attr.grindThmMod| _=_) => return .eqBoth
| `(Parser.Attr.grindThmMod| ←=) => return .eqBwd
| _ => throwError "unexpected `grind` theorem kind: `{stx}`"

/-- Return theorem kind for `stx` of the form `(Attr.grindThmMod)?` -/
Expand Down
55 changes: 55 additions & 0 deletions tests/lean/run/grind_eq_bwd.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
theorem dummy (x : Nat) : x = x :=
rfl

/--
error: invalid pattern for `dummy`
[@Lean.Grind.eqBwdPattern `[Nat] #0 #0]
the pattern does not contain constant symbols for indexing
-/
#guard_msgs in
attribute [grind ←=] dummy

def α : Type := sorry
def inv : α → α := sorry
def mul : α → α → α := sorry
def one : α := sorry

theorem inv_eq {a b : α} (w : mul a b = one) : inv a = b := sorry

/--
info: [grind.ematch.pattern] inv_eq: [@Lean.Grind.eqBwdPattern `[α] (inv #2) #1]
-/
#guard_msgs in
set_option trace.grind.ematch.pattern true in
attribute [grind ←=] inv_eq

example {a b : α} (w : mul a b = one) : inv a = b := by
grind

structure S where
f : Bool → α
h : mul (f true) (f false) = one
h' : mul (f false) (f true) = one

attribute [grind =] S.h S.h'

example (s : S) : inv (s.f true) = s.f false := by
grind

example (s : S) : s.f false = inv (s.f true) := by
grind

example (s : S) : a = false → s.f a = inv (s.f true) := by
grind

example (s : S) : a ≠ s.f false → a = inv (s.f true) → False := by
grind

/--
info: [grind.ematch.instance] inv_eq: mul (s.f true) (s.f false) = one → inv (s.f true) = s.f false
[grind.ematch.instance] S.h: mul (s.f true) (s.f false) = one
-/
#guard_msgs (info) in
set_option trace.grind.ematch.instance true in
example (s : S) : inv (s.f true) = s.f false := by
grind
Loading