Skip to content

Commit

Permalink
feat: improve monadic Array lemmas (#6982)
Browse files Browse the repository at this point in the history
This PR improves some lemmas about monads and monadic operations on
Array/Vector, using @Rob23oa's work in
leanprover-community/batteries#1109, and
adding/generalizing some additional lemmas.
  • Loading branch information
kim-em authored Feb 7, 2025
1 parent 92f0d31 commit af385d7
Show file tree
Hide file tree
Showing 6 changed files with 256 additions and 44 deletions.
64 changes: 44 additions & 20 deletions src/Init/Control/Lawful/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,49 @@ import Init.RCases
import Init.ByCases

-- Mapping by a function with a left inverse is injective.
theorem map_inj_of_left_inverse [Applicative m] [LawfulApplicative m] {f : α → β}
(w : ∃ g : β → α, ∀ x, g (f x) = x) {x y : m α}
(h : f <$> x = f <$> y) : x = y := by
rcases w with ⟨g, w⟩
replace h := congrArg (g <$> ·) h
simpa [w] using h
theorem map_inj_of_left_inverse [Functor m] [LawfulFunctor m] {f : α → β}

This comment has been minimized.

Copy link
@alok

alok Feb 7, 2025

Contributor

This had me looking at the def of LawfulFunctor. I notice it doesn't extend Functor, instead taking it as an instance. Would it be better for LawfulFunctor to extend Functor since then it saves the need to write [Functor m] every time you have to write [LawfulFunctor m]?

(w : ∃ g : β → α, ∀ x, g (f x) = x) {x y : m α} :
f <$> x = f <$> y ↔ x = y := by
constructor
· intro h
rcases w with ⟨g, w⟩
replace h := congrArg (g <$> ·) h
simpa [w] using h
· rintro rfl
rfl

-- Mapping by an injective function is injective, as long as the domain is nonempty.
theorem map_inj_of_inj [Applicative m] [LawfulApplicative m] [Nonempty α] {f : α → β}
(w : ∀ x y, f x = f y → x = y) {x y : m α}
(h : f <$> x = f <$> y) : x = y := by
apply map_inj_of_left_inverse ?_ h
let ⟨a⟩ := ‹Nonempty α›
refine ⟨?_, ?_⟩
· intro b
by_cases p : ∃ a, f a = b
· exact Exists.choose p
· exact a
· intro b
simp only [exists_apply_eq_apply, ↓reduceDIte]
apply w
apply Exists.choose_spec (p := fun a => f a = f b)
@[simp] theorem map_inj_right_of_nonempty [Functor m] [LawfulFunctor m] [Nonempty α] {f : α → β}
(w : ∀ {x y}, f x = f y → x = y) {x y : m α} :
f <$> x = f <$> y ↔ x = y := by
constructor
· intro h
apply (map_inj_of_left_inverse ?_).mp h
let ⟨a⟩ := ‹Nonempty α›
refine ⟨?_, ?_⟩
· intro b
by_cases p : ∃ a, f a = b
· exact Exists.choose p
· exact a
· intro b
simp only [exists_apply_eq_apply, ↓reduceDIte]
apply w
apply Exists.choose_spec (p := fun a => f a = f b)
· rintro rfl
rfl

@[simp] theorem map_inj_right [Monad m] [LawfulMonad m]
{f : α → β} (h : ∀ {x y : α}, f x = f y → x = y) {x y : m α} :
f <$> x = f <$> y ↔ x = y := by
by_cases hempty : Nonempty α
· exact map_inj_right_of_nonempty h
· constructor
· intro h'
have (z : m α) : z = (do let a ← z; let b ← pure (f a); x) := by
conv => lhs; rw [← bind_pure z]
congr; funext a
exact (hempty ⟨a⟩).elim
rw [this x, this y]
rw [← bind_assoc, ← map_eq_pure_bind, h', map_eq_pure_bind, bind_assoc]
· intro h'
rw [h']
54 changes: 54 additions & 0 deletions src/Init/Data/Array/MapIdx.lean
Original file line number Diff line number Diff line change
Expand Up @@ -434,3 +434,57 @@ theorem mapIdx_eq_mkArray_iff {l : Array α} {f : Nat → α → β} {b : β} :
simp [List.mapIdx_reverse]

end Array

namespace List

theorem mapFinIdxM_toArray [Monad m] [LawfulMonad m] (l : List α)
(f : (i : Nat) → α → (h : i < l.length) → m β) :
l.toArray.mapFinIdxM f = toArray <$> l.mapFinIdxM f := by
let rec go (i : Nat) (acc : Array β) (inv : i + acc.size = l.length) :
Array.mapFinIdxM.map l.toArray f i acc.size inv acc
= toArray <$> mapFinIdxM.go l f (l.drop acc.size) acc
(by simp [Nat.sub_add_cancel (Nat.le.intro (Nat.add_comm _ _ ▸ inv))]) := by
match i with
| 0 =>
rw [Nat.zero_add] at inv
simp only [Array.mapFinIdxM.map, inv, drop_length, mapFinIdxM.go, map_pure]
| k + 1 =>
conv => enter [2, 2, 3]; rw [← getElem_cons_drop l acc.size (by omega)]
simp only [Array.mapFinIdxM.map, mapFinIdxM.go, _root_.map_bind]
congr; funext x
conv => enter [1, 4]; rw [← Array.size_push _ x]
conv => enter [2, 2, 3]; rw [← Array.size_push _ x]
refine go k (acc.push x) _
simp only [Array.mapFinIdxM, mapFinIdxM]
exact go _ #[] _

theorem mapIdxM_toArray [Monad m] [LawfulMonad m] (l : List α)
(f : Nat → α → m β) :
l.toArray.mapIdxM f = toArray <$> l.mapIdxM f := by
let rec go (bs : List α) (acc : Array β) (inv : bs.length + acc.size = l.length) :
mapFinIdxM.go l (fun i a h => f i a) bs acc inv = mapIdxM.go f bs acc := by
match bs with
| [] => simp only [mapFinIdxM.go, mapIdxM.go]
| x :: xs => simp only [mapFinIdxM.go, mapIdxM.go, go]
unfold Array.mapIdxM
rw [mapFinIdxM_toArray]
simp only [mapFinIdxM, mapIdxM]
rw [go]

end List

namespace Array

theorem toList_mapFinIdxM [Monad m] [LawfulMonad m] (l : Array α)
(f : (i : Nat) → α → (h : i < l.size) → m β) :
toList <$> l.mapFinIdxM f = l.toList.mapFinIdxM f := by
rw [List.mapFinIdxM_toArray]
simp only [Functor.map_map, id_map']

theorem toList_mapIdxM [Monad m] [LawfulMonad m] (l : Array α)
(f : Nat → α → m β) :
toList <$> l.mapIdxM f = l.toList.mapIdxM f := by
rw [List.mapIdxM_toArray]
simp only [Functor.map_map, id_map']

end Array
145 changes: 130 additions & 15 deletions src/Init/Data/Array/Monadic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,121 @@ theorem forIn_pure_yield_eq_foldl [Monad m] [LawfulMonad m]
cases l
simp

end Array

namespace List

theorem filterM_toArray [Monad m] [LawfulMonad m] (l : List α) (p : α → m Bool) :
l.toArray.filterM p = toArray <$> l.filterM p := by
simp only [Array.filterM, filterM, foldlM_toArray, bind_pure_comp, Functor.map_map]
conv => lhs; rw [← reverse_nil]
generalize [] = acc
induction l generalizing acc with simp
| cons x xs ih =>
congr; funext b
cases b
· simp only [Bool.false_eq_true, ↓reduceIte, pure_bind, cond_false]
exact ih acc
· simp only [↓reduceIte, ← reverse_cons, pure_bind, cond_true]
exact ih (x :: acc)

/-- Variant of `filterM_toArray` with a side condition for the stop position. -/
@[simp] theorem filterM_toArray' [Monad m] [LawfulMonad m] (l : List α) (p : α → m Bool) (w : stop = l.length) :
l.toArray.filterM p 0 stop = toArray <$> l.filterM p := by
subst w
rw [filterM_toArray]

theorem filterRevM_toArray [Monad m] [LawfulMonad m] (l : List α) (p : α → m Bool) :
l.toArray.filterRevM p = toArray <$> l.filterRevM p := by
simp [Array.filterRevM, filterRevM]
rw [← foldlM_reverse, ← foldlM_toArray, ← Array.filterM, filterM_toArray]
simp only [filterM, bind_pure_comp, Functor.map_map, reverse_toArray, reverse_reverse]

/-- Variant of `filterRevM_toArray` with a side condition for the start position. -/
@[simp] theorem filterRevM_toArray' [Monad m] [LawfulMonad m] (l : List α) (p : α → m Bool) (w : start = l.length) :
l.toArray.filterRevM p start 0 = toArray <$> l.filterRevM p := by
subst w
rw [filterRevM_toArray]

theorem filterMapM_toArray [Monad m] [LawfulMonad m] (l : List α) (f : α → m (Option β)) :
l.toArray.filterMapM f = toArray <$> l.filterMapM f := by
simp [Array.filterMapM, filterMapM]
conv => lhs; rw [← reverse_nil]
generalize [] = acc
induction l generalizing acc with simp [filterMapM.loop]
| cons x xs ih =>
congr; funext o
cases o
· simp only [pure_bind]; exact ih acc
· simp only [pure_bind]; rw [← List.reverse_cons]; exact ih _

/-- Variant of `filterMapM_toArray` with a side condition for the stop position. -/
@[simp] theorem filterMapM_toArray' [Monad m] [LawfulMonad m] (l : List α) (f : α → m (Option β)) (w : stop = l.length) :
l.toArray.filterMapM f 0 stop = toArray <$> l.filterMapM f := by
subst w
rw [filterMapM_toArray]

@[simp] theorem flatMapM_toArray [Monad m] [LawfulMonad m] (l : List α) (f : α → m (Array β)) :
l.toArray.flatMapM f = toArray <$> l.flatMapM (fun a => Array.toList <$> f a) := by
simp only [Array.flatMapM, bind_pure_comp, foldlM_toArray, flatMapM]
conv => lhs; arg 2; change [].reverse.flatten.toArray
generalize [] = acc
induction l generalizing acc with
| nil => simp only [foldlM_nil, flatMapM.loop, map_pure]
| cons x xs ih =>
simp only [foldlM_cons, bind_map_left, flatMapM.loop, _root_.map_bind]
congr; funext a
conv => lhs; rw [Array.toArray_append, ← flatten_concat, ← reverse_cons]
exact ih _

end List

namespace Array

@[congr] theorem filterM_congr [Monad m] {as bs : Array α} (w : as = bs)
{p : α → m Bool} {q : α → m Bool} (h : ∀ a, p a = q a) :
as.filterM p = bs.filterM q := by
subst w
simp [filterM, h]

@[congr] theorem filterRevM_congr [Monad m] {as bs : Array α} (w : as = bs)
{p : α → m Bool} {q : α → m Bool} (h : ∀ a, p a = q a) :
as.filterRevM p = bs.filterRevM q := by
subst w
simp [filterRevM, h]

@[congr] theorem filterMapM_congr [Monad m] {as bs : Array α} (w : as = bs)
{f : α → m (Option β)} {g : α → m (Option β)} (h : ∀ a, f a = g a) :
as.filterMapM f = bs.filterMapM g := by
subst w
simp [filterMapM, h]

@[congr] theorem flatMapM_congr [Monad m] {as bs : Array α} (w : as = bs)
{f : α → m (Array β)} {g : α → m (Array β)} (h : ∀ a, f a = g a) :
as.flatMapM f = bs.flatMapM g := by
subst w
simp [flatMapM, h]

theorem toList_filterM [Monad m] [LawfulMonad m] (a : Array α) (p : α → m Bool) :
toList <$> a.filterM p = a.toList.filterM p := by
rw [List.filterM_toArray]
simp only [Functor.map_map, id_map']

theorem toList_filterRevM [Monad m] [LawfulMonad m] (a : Array α) (p : α → m Bool) :
toList <$> a.filterRevM p = a.toList.filterRevM p := by
rw [List.filterRevM_toArray]
simp only [Functor.map_map, id_map']

theorem toList_filterMapM [Monad m] [LawfulMonad m] (a : Array α) (f : α → m (Option β)) :
toList <$> a.filterMapM f = a.toList.filterMapM f := by
rw [List.filterMapM_toArray]
simp only [Functor.map_map, id_map']

theorem toList_flatMapM [Monad m] [LawfulMonad m] (a : Array α) (f : α → m (Array β)) :
toList <$> a.flatMapM f = a.toList.flatMapM (fun a => toList <$> f a) := by
rw [List.flatMapM_toArray]
simp only [Functor.map_map, id_map']

/-! ### Recognizing higher order functions using a function that only depends on the value. -/

/--
Expand Down Expand Up @@ -260,20 +375,20 @@ and simplifies these to the function directly taking the value.
simp
rw [List.mapM_subtype hf]

-- Without `filterMapM_toArray` relating `filterMapM` on `List` and `Array` we can't prove this yet:
-- @[simp] theorem filterMapM_subtype [Monad m] [LawfulMonad m] {p : α → Prop} {l : Array { x // p x }}
-- {f : { x // p x } → m (Option β)} {g : α → m (Option β)} (hf : ∀ x h, f ⟨x, h⟩ = g x) :
-- l.filterMapM f = l.unattach.filterMapM g := by
-- rcases l with ⟨l⟩
-- simp
-- rw [List.filterMapM_subtype hf]

-- Without `flatMapM_toArray` relating `flatMapM` on `List` and `Array` we can't prove this yet:
-- @[simp] theorem flatMapM_subtype [Monad m] [LawfulMonad m] {p : α → Prop} {l : Array { x // p x }}
-- {f : { x // p x } → m (Array β)} {g : α → m (Array β)} (hf : ∀ x h, f ⟨x, h⟩ = g x) :
-- (l.flatMapM f) = l.unattach.flatMapM g := by
-- rcases l with ⟨l⟩
-- simp
-- rw [List.flatMapM_subtype hf]
@[simp] theorem filterMapM_subtype [Monad m] [LawfulMonad m] {p : α → Prop} {l : Array { x // p x }}
{f : { x // p x } → m (Option β)} {g : α → m (Option β)} (hf : ∀ x h, f ⟨x, h⟩ = g x) (w : stop = l.size) :
l.filterMapM f 0 stop = l.unattach.filterMapM g := by
subst w
rcases l with ⟨l⟩
simp
rw [List.filterMapM_subtype hf]

@[simp] theorem flatMapM_subtype [Monad m] [LawfulMonad m] {p : α → Prop} {l : Array { x // p x }}
{f : { x // p x } → m (Array β)} {g : α → m (Array β)} (hf : ∀ x h, f ⟨x, h⟩ = g x) :
(l.flatMapM f) = l.unattach.flatMapM g := by
rcases l with ⟨l⟩
simp
rw [List.flatMapM_subtype]
simp [hf]

end Array
2 changes: 1 addition & 1 deletion src/Init/Data/List/Control.lean
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ Applies the monadic function `f` on every element `x` in the list, left-to-right
results `y` for which `f x` returns `some y`.
-/
@[inline]
def filterMapM {m : Type u → Type v} [Monad m] {α β : Type u} (f : α → m (Option β)) (as : List α) : m (List β) :=
def filterMapM {m : Type u → Type v} [Monad m] {α : Type w} {β : Type u} (f : α → m (Option β)) (as : List α) : m (List β) :=
let rec @[specialize] loop
| [], bs => pure bs.reverse
| a :: as, bs => do
Expand Down
21 changes: 21 additions & 0 deletions src/Init/Data/Vector/MapIdx.lean
Original file line number Diff line number Diff line change
Expand Up @@ -363,4 +363,25 @@ theorem mapIdx_eq_mkVector_iff {l : Vector α n} {f : Nat → α → β} {b : β
rcases l with ⟨l, rfl⟩
simp [Array.mapIdx_reverse]

theorem toArray_mapFinIdxM [Monad m] [LawfulMonad m]
(a : Vector α n) (f : (i : Nat) → α → (h : i < n) → m β) :
toArray <$> a.mapFinIdxM f = a.toArray.mapFinIdxM
(fun i x h => f i x (size_toArray a ▸ h)) := by
let rec go (i j : Nat) (inv : i + j = n) (bs : Vector β (n - i)) :
toArray <$> mapFinIdxM.map a f i j inv bs
= Array.mapFinIdxM.map a.toArray (fun i x h => f i x (size_toArray a ▸ h))
i j (size_toArray _ ▸ inv) bs.toArray := by
match i with
| 0 => simp only [mapFinIdxM.map, map_pure, Array.mapFinIdxM.map, Nat.sub_zero]
| k + 1 =>
simp only [mapFinIdxM.map, map_bind, Array.mapFinIdxM.map, getElem_toArray]
conv => lhs; arg 2; intro; rw [go]
rfl
simp only [mapFinIdxM, Array.mapFinIdxM, size_toArray]
exact go _ _ _ _

theorem toArray_mapIdxM [Monad m] [LawfulMonad m] (a : Vector α n) (f : Nat → α → m β) :
toArray <$> a.mapIdxM f = a.toArray.mapIdxM f := by
exact toArray_mapFinIdxM _ _

end Vector
14 changes: 6 additions & 8 deletions src/Init/Data/Vector/Monadic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@ open Nat

/-! ## Monadic operations -/

theorem map_toArray_inj [Monad m] [LawfulMonad m] [Nonempty α]
{v₁ : m (Vector α n)} {v₂ : m (Vector α n)} (w : toArray <$> v₁ = toArray <$> v₂) :
v₁ = v₂ := by
apply map_inj_of_inj ?_ w
simp
@[simp] theorem map_toArray_inj [Monad m] [LawfulMonad m]
{v₁ : m (Vector α n)} {v₂ : m (Vector α n)} :
toArray <$> v₁ = toArray <$> v₂ ↔ v₁ = v₂ :=
_root_.map_inj_right (by simp)

/-! ### mapM -/

Expand All @@ -39,11 +38,10 @@ theorem map_toArray_inj [Monad m] [LawfulMonad m] [Nonempty α]
unfold mapM.go
simp

-- The `[Nonempty β]` hypothesis should be avoidable by unfolding `mapM` directly.
@[simp] theorem mapM_append [Monad m] [LawfulMonad m] [Nonempty β]
@[simp] theorem mapM_append [Monad m] [LawfulMonad m]
(f : α → m β) {l₁ : Vector α n} {l₂ : Vector α n'} :
(l₁ ++ l₂).mapM f = (return (← l₁.mapM f) ++ (← l₂.mapM f)) := by
apply map_toArray_inj
apply map_toArray_inj.mp
suffices toArray <$> (l₁ ++ l₂).mapM f = (return (← toArray <$> l₁.mapM f) ++ (← toArray <$> l₂.mapM f)) by
rw [this]
simp only [bind_pure_comp, Functor.map_map, bind_map_left, map_bind, toArray_append]
Expand Down

0 comments on commit af385d7

Please sign in to comment.