Skip to content

Commit

Permalink
feat: refactor List/Array.mapFinIdx to unbundle the Fin argument (#6697)
Browse files Browse the repository at this point in the history
This PR changes the arguments of `List/Array.mapFinIdx` from `(f : Fin
as.size → α → β)` to `(f : (i : Nat) → α → (h : i < as.size) → β)`, in
line with the API design elsewhere for `List/Array`.
  • Loading branch information
kim-em authored Jan 19, 2025
1 parent b289b66 commit 35bbb48
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 93 deletions.
8 changes: 4 additions & 4 deletions src/Init/Data/Array/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ def mapM {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] (f : α
/-- Variant of `mapIdxM` which receives the index as a `Fin as.size`. -/
@[inline]
def mapFinIdxM {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m]
(as : Array α) (f : Fin as.size → α → m β) : m (Array β) :=
(as : Array α) (f : (i : Nat) → α → (h : i < as.size) → m β) : m (Array β) :=
let rec @[specialize] map (i : Nat) (j : Nat) (inv : i + j = as.size) (bs : Array β) : m (Array β) := do
match i, inv with
| 0, _ => pure bs
Expand All @@ -464,12 +464,12 @@ def mapFinIdxM {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m]
rw [← inv, Nat.add_assoc, Nat.add_comm 1 j, Nat.add_comm]
apply Nat.le_add_right
have : i + (j + 1) = as.size := by rw [← inv, Nat.add_comm j 1, Nat.add_assoc]
map i (j+1) this (bs.push (← f ⟨j, j_lt⟩ (as.get j j_lt)))
map i (j+1) this (bs.push (← f j (as.get j j_lt) j_lt))
map as.size 0 rfl (mkEmpty as.size)

@[inline]
def mapIdxM {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] (f : Nat → α → m β) (as : Array α) : m (Array β) :=
as.mapFinIdxM fun i a => f i a
as.mapFinIdxM fun i a _ => f i a

@[inline]
def findSomeM? {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] (f : α → m (Option β)) (as : Array α) : m (Option β) := do
Expand Down Expand Up @@ -588,7 +588,7 @@ def map {α : Type u} {β : Type v} (f : α → β) (as : Array α) : Array β :

/-- Variant of `mapIdx` which receives the index as a `Fin as.size`. -/
@[inline]
def mapFinIdx {α : Type u} {β : Type v} (as : Array α) (f : Fin as.size → α → β) : Array β :=
def mapFinIdx {α : Type u} {β : Type v} (as : Array α) (f : (i : Nat) → α → (h : i < as.size) → β) : Array β :=
Id.run <| as.mapFinIdxM f

@[inline]
Expand Down
65 changes: 33 additions & 32 deletions src/Init/Data/Array/MapIdx.lean
Original file line number Diff line number Diff line change
Expand Up @@ -12,81 +12,82 @@ namespace Array
/-! ### mapFinIdx -/

-- This could also be proved from `SatisfiesM_mapIdxM` in Batteries.
theorem mapFinIdx_induction (as : Array α) (f : Fin as.size → α → β)
theorem mapFinIdx_induction (as : Array α) (f : (i : Nat) → α → (h : i < as.size) → β)
(motive : Nat → Prop) (h0 : motive 0)
(p : Fin as.size → β → Prop)
(hs : ∀ i, motive i.1 → p i (f i as[i]) ∧ motive (i + 1)) :
(p : (i : Nat) → β → (h : i < as.size)Prop)
(hs : ∀ i h, motive i → p i (f i as[i] h) h ∧ motive (i + 1)) :
motive as.size ∧ ∃ eq : (Array.mapFinIdx as f).size = as.size,
∀ i h, p ⟨i, h⟩ ((Array.mapFinIdx as f)[i]) := by
let rec go {bs i j h} (h₁ : j = bs.size) (h₂ : ∀ i h h', p ⟨i, h⟩ bs[i]) (hm : motive j) :
∀ i h, p i ((Array.mapFinIdx as f)[i]) h := by
let rec go {bs i j h} (h₁ : j = bs.size) (h₂ : ∀ i h h', p i bs[i] h) (hm : motive j) :
let arr : Array β := Array.mapFinIdxM.map (m := Id) as f i j h bs
motive as.size ∧ ∃ eq : arr.size = as.size, ∀ i h, p ⟨i, h⟩ arr[i] := by
motive as.size ∧ ∃ eq : arr.size = as.size, ∀ i h, p i arr[i] h := by
induction i generalizing j bs with simp [mapFinIdxM.map]
| zero =>
have := (Nat.zero_add _).symm.trans h
exact ⟨this ▸ hm, h₁ ▸ this, fun _ _ => h₂ ..⟩
| succ i ih =>
apply @ih (bs.push (f ⟨j, by omega⟩ as[j])) (j + 1) (by omega) (by simp; omega)
apply @ih (bs.push (f j as[j] (by omega))) (j + 1) (by omega) (by simp; omega)
· intro i i_lt h'
rw [getElem_push]
split
· apply h₂
· simp only [size_push] at h'
obtain rfl : i = j := by omega
apply (hs ⟨i, by omega hm).1
· exact (hs ⟨j, by omega hm).2
apply (hs i (by omega) hm).1
· exact (hs j (by omega) hm).2
simp [mapFinIdx, mapFinIdxM]; exact go rfl nofun h0

theorem mapFinIdx_spec (as : Array α) (f : Fin as.size → α → β)
(p : Fin as.size → β → Prop) (hs : ∀ i, p i (f i as[i])) :
theorem mapFinIdx_spec (as : Array α) (f : (i : Nat) → α → (h : i < as.size) → β)
(p : (i : Nat) → β → (h : i < as.size) → Prop) (hs : ∀ i h, p i (f i as[i] h) h) :
∃ eq : (Array.mapFinIdx as f).size = as.size,
∀ i h, p ⟨i, h⟩ ((Array.mapFinIdx as f)[i]) :=
(mapFinIdx_induction _ _ (fun _ => True) trivial p fun _ _ => ⟨hs .., trivial⟩).2
∀ i h, p i ((Array.mapFinIdx as f)[i]) h :=
(mapFinIdx_induction _ _ (fun _ => True) trivial p fun _ _ _ => ⟨hs .., trivial⟩).2

@[simp] theorem size_mapFinIdx (a : Array α) (f : Fin a.size → α → β) : (a.mapFinIdx f).size = a.size :=
(mapFinIdx_spec (p := fun _ _ => True) (hs := fun _ => trivial)).1
@[simp] theorem size_mapFinIdx (a : Array α) (f : (i : Nat) → α → (h : i < a.size) → β) :
(a.mapFinIdx f).size = a.size :=
(mapFinIdx_spec (p := fun _ _ _ => True) (hs := fun _ _ => trivial)).1

@[simp] theorem size_zipWithIndex (as : Array α) : as.zipWithIndex.size = as.size :=
Array.size_mapFinIdx _ _

@[simp] theorem getElem_mapFinIdx (a : Array α) (f : Fin a.size → α → β) (i : Nat)
@[simp] theorem getElem_mapFinIdx (a : Array α) (f : (i : Nat) → α → (h : i < a.size) → β) (i : Nat)
(h : i < (mapFinIdx a f).size) :
(a.mapFinIdx f)[i] = f ⟨i, by simp_all⟩ (a[i]'(by simp_all)) :=
(mapFinIdx_spec _ _ (fun i b => b = f i a[i]) fun _ => rfl).2 i _
(a.mapFinIdx f)[i] = f i (a[i]'(by simp_all)) (by simp_all) :=
(mapFinIdx_spec _ _ (fun i b h => b = f i a[i] h) fun _ _ => rfl).2 i _

@[simp] theorem getElem?_mapFinIdx (a : Array α) (f : Fin a.size → α → β) (i : Nat) :
@[simp] theorem getElem?_mapFinIdx (a : Array α) (f : (i : Nat) → α → (h : i < a.size) → β) (i : Nat) :
(a.mapFinIdx f)[i]? =
a[i]?.pbind fun b h => f ⟨i, (getElem?_eq_some_iff.1 h).1⟩ b := by
a[i]?.pbind fun b h => f i b (getElem?_eq_some_iff.1 h).1 := by
simp only [getElem?_def, size_mapFinIdx, getElem_mapFinIdx]
split <;> simp_all

@[simp] theorem toList_mapFinIdx (a : Array α) (f : Fin a.size → α → β) :
(a.mapFinIdx f).toList = a.toList.mapFinIdx (fun i a => f ⟨i, by simp⟩ a) := by
@[simp] theorem toList_mapFinIdx (a : Array α) (f : (i : Nat) → α → (h : i < a.size) → β) :
(a.mapFinIdx f).toList = a.toList.mapFinIdx (fun i a h => f i a (by simpa)) := by
apply List.ext_getElem <;> simp

/-! ### mapIdx -/

theorem mapIdx_induction (f : Nat → α → β) (as : Array α)
(motive : Nat → Prop) (h0 : motive 0)
(p : Fin as.size → β → Prop)
(hs : ∀ i, motive i.1 → p i (f i as[i]) ∧ motive (i + 1)) :
(p : (i : Nat) → β → (h : i < as.size)Prop)
(hs : ∀ i h, motive i → p i (f i as[i]) h ∧ motive (i + 1)) :
motive as.size ∧ ∃ eq : (as.mapIdx f).size = as.size,
∀ i h, p ⟨i, h⟩ ((as.mapIdx f)[i]) :=
mapFinIdx_induction as (fun i a => f i a) motive h0 p hs
∀ i h, p i ((as.mapIdx f)[i]) h :=
mapFinIdx_induction as (fun i a _ => f i a) motive h0 p hs

theorem mapIdx_spec (f : Nat → α → β) (as : Array α)
(p : Fin as.size → β → Prop) (hs : ∀ i, p i (f i as[i])) :
(p : (i : Nat) → β → (h : i < as.size) → Prop) (hs : ∀ i h, p i (f i as[i]) h) :
∃ eq : (as.mapIdx f).size = as.size,
∀ i h, p ⟨i, h⟩ ((as.mapIdx f)[i]) :=
(mapIdx_induction _ _ (fun _ => True) trivial p fun _ _ => ⟨hs .., trivial⟩).2
∀ i h, p i ((as.mapIdx f)[i]) h :=
(mapIdx_induction _ _ (fun _ => True) trivial p fun _ _ _ => ⟨hs .., trivial⟩).2

@[simp] theorem size_mapIdx (f : Nat → α → β) (as : Array α) : (as.mapIdx f).size = as.size :=
(mapIdx_spec (p := fun _ _ => True) (hs := fun _ => trivial)).1
(mapIdx_spec (p := fun _ _ _ => True) (hs := fun _ _ => trivial)).1

@[simp] theorem getElem_mapIdx (f : Nat → α → β) (as : Array α) (i : Nat)
(h : i < (as.mapIdx f).size) :
(as.mapIdx f)[i] = f i (as[i]'(by simp_all)) :=
(mapIdx_spec _ _ (fun i b => b = f i as[i]) fun _ => rfl).2 i (by simp_all)
(mapIdx_spec _ _ (fun i b h => b = f i as[i]) fun _ _ => rfl).2 i (by simp_all)

@[simp] theorem getElem?_mapIdx (f : Nat → α → β) (as : Array α) (i : Nat) :
(as.mapIdx f)[i]? =
Expand All @@ -101,7 +102,7 @@ end Array

namespace List

@[simp] theorem mapFinIdx_toArray (l : List α) (f : Fin l.length → α → β) :
@[simp] theorem mapFinIdx_toArray (l : List α) (f : (i : Nat) → α → (h : i < l.length) → β) :
l.toArray.mapFinIdx f = (l.mapFinIdx f).toArray := by
ext <;> simp

Expand Down
Loading

0 comments on commit 35bbb48

Please sign in to comment.