Skip to content

Commit

Permalink
refactor: WF.Eqns: rewrite fix without duplicating F (#6859)
Browse files Browse the repository at this point in the history
This PR changes how WF.Eqns unfolds the fixpoint. Instead of delta'ing
until we have `fix`, and then blindly applying `fix_eq`, we delta one
step less and preserve the function on the right hand side. This leads
to smaller terms in the next step, so easier to debug, possibly faster,
possibly more robust.
  • Loading branch information
nomeata authored Jan 30, 2025
1 parent dc445d7 commit cd62b8c
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 24 deletions.
48 changes: 24 additions & 24 deletions src/Lean/Elab/PreDefinition/WF/Eqns.lean
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import Lean.Elab.PreDefinition.Basic
import Lean.Elab.PreDefinition.Eqns
import Lean.Meta.ArgsPacker.Basic
import Init.Data.Array.Basic
import Init.Internal.Order.Basic

namespace Lean.Elab.WF
open Meta
Expand All @@ -23,35 +22,33 @@ structure EqnInfo extends EqnInfoCore where
argsPacker : ArgsPacker
deriving Inhabited

private partial def deltaLHSUntilFix (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do
let target ← mvarId.getType'
let some (_, lhs, _) := target.eq? | throwTacticEx `deltaLHSUntilFix mvarId "equality expected"
if lhs.isAppOf ``WellFounded.fix then
return mvarId
else if lhs.isAppOf ``Order.fix then
return mvarId
else
deltaLHSUntilFix (← deltaLHS mvarId)

private def rwFixEq (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do
let target ← mvarId.getType'
let some (_, lhs, rhs) := target.eq? | unreachable!
let h ←
if lhs.isAppOf ``WellFounded.fix then
pure <| mkAppN (mkConst ``WellFounded.fix_eq lhs.getAppFn.constLevels!) lhs.getAppArgs
else if lhs.isAppOf ``Order.fix then
let x := lhs.getAppArgs.back!
let args := lhs.getAppArgs.pop
mkAppM ``congrFun #[mkAppN (mkConst ``Order.fix_eq lhs.getAppFn.constLevels!) args, x]
else
throwTacticEx `rwFixEq mvarId "expected fixed-point application"
let some (_, _, lhsNew) := (← inferType h).eq? | unreachable!

-- lhs should be an application of the declNameNonrec, which unfolds to an
-- application of fix in one step
let some lhs' ← delta? lhs | throwError "rwFixEq: cannot delta-reduce {lhs}"
let_expr WellFounded.fix _α _C _r _hwf F x := lhs'
| throwTacticEx `rwFixEq mvarId "expected saturated fixed-point application in {lhs'}"
let h := mkAppN (mkConst ``WellFounded.fix_eq lhs'.getAppFn.constLevels!) lhs'.getAppArgs

-- We used to just rewrite with `fix_eq` and continue with whatever RHS that produces, but that
-- would include more copies of `fix` resulting in large and confusing terms.
-- Instead we manually construct the new term in terms of the current functions,
-- which should be headed by the `declNameNonRec`, and should be defeq to the expected type

-- if lhs == e x and lhs' == fix .., then lhsNew := e x = F x (fun y _ => e y)
let ftype := (← inferType (mkApp F x)).bindingDomain!
let f' ← forallBoundedTelescope ftype (some 2) fun ys _ => do
mkLambdaFVars ys (.app lhs.appFn! ys[0]!)
let lhsNew := mkApp2 F x f'
let targetNew ← mkEq lhsNew rhs
let mvarNew ← mkFreshExprSyntheticOpaqueMVar targetNew
mvarId.assign (← mkEqTrans h mvarNew)
return mvarNew.mvarId!

private partial def mkProof (declName : Name) (type : Expr) : MetaM Expr := do
private partial def mkProof (declName declNameNonRec : Name) (type : Expr) : MetaM Expr := do
trace[Elab.definition.wf.eqns] "proving: {type}"
withNewMCtxDepth do
let main ← mkFreshExprSyntheticOpaqueMVar type
Expand Down Expand Up @@ -83,7 +80,10 @@ private partial def mkProof (declName : Name) (type : Expr) : MetaM Expr := do
-- LHS (introduced in 096e4eb), but it seems that code path was never used,
-- so #3133 removed it again (and can be recovered from there if this was premature).
throwError "failed to generate equational theorem for '{declName}'\n{MessageData.ofGoal mvarId}"
go (← rwFixEq (← deltaLHSUntilFix mvarId))

let mvarId ← if declName != declNameNonRec then deltaLHS mvarId else pure mvarId
let mvarId ← rwFixEq mvarId
go mvarId
instantiateMVars main

def mkEqns (declName : Name) (info : EqnInfo) : MetaM (Array Name) :=
Expand All @@ -101,7 +101,7 @@ def mkEqns (declName : Name) (info : EqnInfo) : MetaM (Array Name) :=
trace[Elab.definition.wf.eqns] "{eqnTypes[i]}"
let name := (Name.str baseName eqnThmSuffixBase).appendIndexAfter (i+1)
thmNames := thmNames.push name
let value ← mkProof declName type
let value ← mkProof declName info.declNameNonRec type
let (type, value) ← removeUnusedEqnHypotheses type value
addDecl <| Declaration.thmDecl {
name, type, value
Expand Down
6 changes: 6 additions & 0 deletions tests/lean/run/simpDiag.lean
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ info: [simp] Diagnostics
[simp] ack.eq_1 ↦ 768, succeeded: 768
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
---
info: [diag] Diagnostics
[kernel] unfolded declarations (max: 29, num: 2):
[kernel] Nat.casesOn ↦ 29
[kernel] Nat.rec ↦ 29
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
---
error: tactic 'simp' failed, nested error:
maximum recursion depth has been reached
use `set_option maxRecDepth <num>` to increase limit
Expand Down

0 comments on commit cd62b8c

Please sign in to comment.