diff --git a/src/Init/Grind/Tactics.lean b/src/Init/Grind/Tactics.lean index 1ed632d93031..0a234217a1a3 100644 --- a/src/Init/Grind/Tactics.lean +++ b/src/Init/Grind/Tactics.lean @@ -11,10 +11,11 @@ namespace Lean.Parser.Attr syntax grindEq := "=" syntax grindEqBoth := atomic("_" "=" "_") syntax grindEqRhs := atomic("=" "_") +syntax grindEqBwd := atomic("←" "=") syntax grindBwd := "←" syntax grindFwd := "→" -syntax grindThmMod := grindEqBoth <|> grindEqRhs <|> grindEq <|> grindBwd <|> grindFwd +syntax grindThmMod := grindEqBoth <|> grindEqRhs <|> grindEq <|> grindEqBwd <|> grindBwd <|> grindFwd syntax (name := grind) "grind" (grindThmMod)? : attr diff --git a/src/Init/Grind/Util.lean b/src/Init/Grind/Util.lean index 7001755ee455..34bc7ec2df73 100644 --- a/src/Init/Grind/Util.lean +++ b/src/Init/Grind/Util.lean @@ -21,6 +21,9 @@ def doNotSimp {α : Sort u} (a : α) : α := a /-- Gadget for representing offsets `t+k` in patterns. -/ def offset (a b : Nat) : Nat := a + b +/-- Gadget for representing `a = b` in patterns for backward propagation. -/ +def eqBwdPattern (a b : α) : Prop := a = b + /-- Gadget for annotating the equalities in `match`-equations conclusions. `_origin` is the term used to instantiate the `match`-equation using E-matching. diff --git a/src/Lean/Meta/Tactic/Grind/EMatch.lean b/src/Lean/Meta/Tactic/Grind/EMatch.lean index 7d01def92554..8341d44add46 100644 --- a/src/Lean/Meta/Tactic/Grind/EMatch.lean +++ b/src/Lean/Meta/Tactic/Grind/EMatch.lean @@ -310,12 +310,44 @@ private def main (p : Expr) (cnstrs : List Cnstr) : M Unit := do modify fun s => { s with choiceStack := [c] } processChoices +/-- +Entry point for matching `lhs ←= rhs` patterns. +It traverses disequalities `a = b`, and tries to solve two matching problems: +1- match `lhs` with `a` and `rhs` with `b` +2- match `lhs` with `b` and `rhs` with `a` +-/ +private def matchEqBwdPat (p : Expr) : M Unit := do + let_expr Grind.eqBwdPattern pα plhs prhs := p | return () + let numParams := (← read).thm.numParams + let assignment := mkArray numParams unassigned + let useMT := (← read).useMT + let gmt := (← getThe Goal).gmt + let false ← getFalseExpr + let mut curr := false + repeat + if (← checkMaxInstancesExceeded) then return () + let n ← getENode curr + if (n.heqProofs || n.isCongrRoot) && + (!useMT || n.mt == gmt) then + let_expr Eq α lhs rhs := n.self | pure () + if (← isDefEq α pα) then + let c₀ : Choice := { cnstrs := [], assignment, gen := n.generation } + let go (lhs rhs : Expr) : M Unit := do + let some c₁ ← matchArg? c₀ plhs lhs |>.run | return () + let some c₂ ← matchArg? c₁ prhs rhs |>.run | return () + modify fun s => { s with choiceStack := [c₂] } + processChoices + go lhs rhs + go rhs lhs + if isSameExpr n.next false then return () + curr := n.next + def ematchTheorem (thm : EMatchTheorem) : M Unit := do if (← checkMaxInstancesExceeded) then return () withReader (fun ctx => { ctx with thm }) do let ps := thm.patterns match ps, (← read).useMT with - | [p], _ => main p [] + | [p], _ => if isEqBwdPattern p then matchEqBwdPat p else main p [] | p::ps, false => main p (ps.map (.continue ·)) | _::_, true => tryAll ps [] | _, _ => unreachable! diff --git a/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean b/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean index eb0e94ea40b1..d1cb334d351f 100644 --- a/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean +++ b/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean @@ -39,6 +39,17 @@ def isOffsetPattern? (pat : Expr) : Option (Expr × Nat) := Id.run do let .lit (.natVal k) := k | none return some (pat, k) +def mkEqBwdPattern (u : List Level) (α : Expr) (lhs rhs : Expr) : Expr := + mkApp3 (mkConst ``Grind.eqBwdPattern u) α lhs rhs + +def isEqBwdPattern (e : Expr) : Bool := + e.isAppOfArity ``Grind.eqBwdPattern 3 + +def isEqBwdPattern? (e : Expr) : Option (Expr × Expr) := + let_expr Grind.eqBwdPattern _ lhs rhs := e + | none + some (lhs, rhs) + def preprocessPattern (pat : Expr) (normalizePattern := true) : MetaM Expr := do let pat ← instantiateMVars pat let pat ← unfoldReducible pat @@ -314,7 +325,8 @@ private partial def go (pattern : Expr) (root := false) : M Expr := do let some f := getPatternFn? pattern | throwError "invalid pattern, (non-forbidden) application expected{indentExpr pattern}" assert! f.isConst || f.isFVar - saveSymbol f.toHeadIndex + unless f.isConstOf ``Grind.eqBwdPattern do + saveSymbol f.toHeadIndex let mut args := pattern.getAppArgs.toVector let supportMask ← getPatternSupportMask f args.size for h : i in [:args.size] do @@ -481,6 +493,8 @@ Pattern variables are represented using de Bruijn indices. -/ def mkEMatchTheoremCore (origin : Origin) (levelParams : Array Name) (numParams : Nat) (proof : Expr) (patterns : List Expr) : MetaM EMatchTheorem := do let (patterns, symbols, bvarFound) ← NormalizePattern.main patterns + if symbols.isEmpty then + throwError "invalid pattern for `{← origin.pp}`{indentD (patterns.map ppPattern)}\nthe pattern does not contain constant symbols for indexing" trace[grind.ematch.pattern] "{MessageData.ofConst proof}: {patterns.map ppPattern}" if let .missing pos ← checkCoverage proof numParams bvarFound then let pats : MessageData := m!"{patterns.map ppPattern}" @@ -523,6 +537,14 @@ def mkEMatchEqTheoremCore (origin : Origin) (levelParams : Array Name) (proof : return (xs.size, pats) mkEMatchTheoremCore origin levelParams numParams proof patterns +def mkEMatchEqBwdTheoremCore (origin : Origin) (levelParams : Array Name) (proof : Expr) : MetaM EMatchTheorem := do + let (numParams, patterns) ← forallTelescopeReducing (← inferType proof) fun xs type => do + let_expr f@Eq α lhs rhs := type + | throwError "invalid E-matching `≠` theorem, conclusion must be an equality{indentExpr type}" + let pat ← preprocessPattern (mkEqBwdPattern f.constLevels! α lhs rhs) + return (xs.size, [pat.abstract xs]) + mkEMatchTheoremCore origin levelParams numParams proof patterns + /-- Given theorem with name `declName` and type of the form `∀ (a_1 ... a_n), lhs = rhs`, creates an E-matching pattern for it using `addEMatchTheorem n [lhs]` @@ -552,13 +574,14 @@ def getEMatchTheorems : CoreM EMatchTheorems := return ematchTheoremsExt.getState (← getEnv) inductive TheoremKind where - | eqLhs | eqRhs | eqBoth | fwd | bwd | default + | eqLhs | eqRhs | eqBoth | eqBwd | fwd | bwd | default deriving Inhabited, BEq, Repr private def TheoremKind.toAttribute : TheoremKind → String | .eqLhs => "[grind =]" | .eqRhs => "[grind =_]" | .eqBoth => "[grind _=_]" + | .eqBwd => "[grind ←=]" | .fwd => "[grind →]" | .bwd => "[grind ←]" | .default => "[grind]" @@ -567,6 +590,7 @@ private def TheoremKind.explainFailure : TheoremKind → String | .eqLhs => "failed to find pattern in the left-hand side of the theorem's conclusion" | .eqRhs => "failed to find pattern in the right-hand side of the theorem's conclusion" | .eqBoth => unreachable! -- eqBoth is a macro + | .eqBwd => "failed to use theorem's conclusion as a pattern" | .fwd => "failed to find patterns in the antecedents of the theorem" | .bwd => "failed to find patterns in the theorem's conclusion" | .default => "failed to find patterns" @@ -656,6 +680,8 @@ def mkEMatchTheoremWithKind? (origin : Origin) (levelParams : Array Name) (proof return (← mkEMatchEqTheoremCore origin levelParams proof (normalizePattern := true) (useLhs := true)) else if kind == .eqRhs then return (← mkEMatchEqTheoremCore origin levelParams proof (normalizePattern := true) (useLhs := false)) + else if kind == .eqBwd then + return (← mkEMatchEqBwdTheoremCore origin levelParams proof) let type ← inferType proof forallTelescopeReducing type fun xs type => do let searchPlaces ← match kind with @@ -687,6 +713,7 @@ def getTheoremKindCore (stx : Syntax) : CoreM TheoremKind := do | `(Parser.Attr.grindThmMod| ←) => return .bwd | `(Parser.Attr.grindThmMod| =_) => return .eqRhs | `(Parser.Attr.grindThmMod| _=_) => return .eqBoth + | `(Parser.Attr.grindThmMod| ←=) => return .eqBwd | _ => throwError "unexpected `grind` theorem kind: `{stx}`" /-- Return theorem kind for `stx` of the form `(Attr.grindThmMod)?` -/ diff --git a/tests/lean/run/grind_eq_bwd.lean b/tests/lean/run/grind_eq_bwd.lean new file mode 100644 index 000000000000..16214c52b7b1 --- /dev/null +++ b/tests/lean/run/grind_eq_bwd.lean @@ -0,0 +1,55 @@ +theorem dummy (x : Nat) : x = x := + rfl + +/-- +error: invalid pattern for `dummy` + [@Lean.Grind.eqBwdPattern `[Nat] #0 #0] +the pattern does not contain constant symbols for indexing +-/ +#guard_msgs in +attribute [grind ←=] dummy + +def α : Type := sorry +def inv : α → α := sorry +def mul : α → α → α := sorry +def one : α := sorry + +theorem inv_eq {a b : α} (w : mul a b = one) : inv a = b := sorry + +/-- +info: [grind.ematch.pattern] inv_eq: [@Lean.Grind.eqBwdPattern `[α] (inv #2) #1] +-/ +#guard_msgs in +set_option trace.grind.ematch.pattern true in +attribute [grind ←=] inv_eq + +example {a b : α} (w : mul a b = one) : inv a = b := by + grind + +structure S where + f : Bool → α + h : mul (f true) (f false) = one + h' : mul (f false) (f true) = one + +attribute [grind =] S.h S.h' + +example (s : S) : inv (s.f true) = s.f false := by + grind + +example (s : S) : s.f false = inv (s.f true) := by + grind + +example (s : S) : a = false → s.f a = inv (s.f true) := by + grind + +example (s : S) : a ≠ s.f false → a = inv (s.f true) → False := by + grind + +/-- +info: [grind.ematch.instance] inv_eq: mul (s.f true) (s.f false) = one → inv (s.f true) = s.f false +[grind.ematch.instance] S.h: mul (s.f true) (s.f false) = one +-/ +#guard_msgs (info) in +set_option trace.grind.ematch.instance true in +example (s : S) : inv (s.f true) = s.f false := by + grind