diff --git a/SampCert/Foundations/UniformByte.lean b/SampCert/Foundations/UniformByte.lean index b3fb22a4..7a91a03e 100644 --- a/SampCert/Foundations/UniformByte.lean +++ b/SampCert/Foundations/UniformByte.lean @@ -6,19 +6,42 @@ Authors: Jean-Baptiste Tristan import SampCert.SLang import SampCert.Util.Util import SampCert.Foundations.Auto +import Mathlib.Probability.Distributions.Uniform +import Mathlib.Data.Nat.Log +import SampCert.Foundations.Monad + /-! # ``probUniformByte`` Properties This file contains lemmas about ``probUniformByte``, a ``SLang`` sampler for the uniform distribution on bytes. + +It also contains the derivation that ``probUniformP2`` is a uniform distribution. -/ open Classical Nat PMF + + namespace SLang + + + +local instance : Finite UInt8 := by + constructor + · apply Equiv.ofBijective (fun v => v.val) + apply Function.bijective_iff_has_inverse.mpr + exists (fun v => {val := v : UInt8}) + simp [Function.RightInverse, Function.LeftInverse] + + + +/-- +ProbUniformByte is a proper distribution +-/ def probUniformByte_normalizes : HasSum probUniformByte 1 := by rw [Summable.hasSum_iff ENNReal.summable] unfold SLang.probUniformByte @@ -38,6 +61,422 @@ def probUniformByte_normalizes : HasSum probUniformByte 1 := by simp [Function.RightInverse, Function.LeftInverse] simp [ENNReal.toReal_inv] +/-- +ProbUniformByte as a PMF +-/ def probUniformByte_PMF : PMF UInt8 := ⟨ probUniformByte, probUniformByte_normalizes ⟩ +/-- +Evaluation of ``probUniformByteUpperBits`` for inside the support +-/ +def probUniformByteUpperBits_eval_support {i x : ℕ} (Hx : x < 2 ^ (min 8 i)) : + probUniformByteUpperBits i x = 2^(8 - i) / UInt8.size := by + simp [probUniformByteUpperBits] + rw [Nat.sub_eq_max_sub] + simp [SLang.probBind, SLang.probPure, probUniformByte] + cases (Classical.em (i < 8)) + + · -- Simplify body + rw [max_eq_left (by linarith)] + rw [min_eq_right (by linarith)] at Hx + conv => + enter [1, 1, a] + rw [Nat.shiftRight_eq_div_pow] + conv => + enter [1, 1, a] + rw [<- mul_one (256)⁻¹] + rw [<- mul_zero (256)⁻¹] + rw [<- mul_ite] + rw [ENNReal.tsum_mul_left] + rw [division_def] + rw [mul_comm] + congr 1 + + -- Restruct sum to type where body is constant + rw [<- (@tsum_subtype_eq_of_support_subset _ _ _ _ _ { i_1 : UInt8 | x = i_1.toNat / 2 ^ (8 - i) } ?G1)] + case G1 => simp [Function.support] + generalize HT : { i_1 : UInt8 | x = i_1.toNat / 2 ^ (8 - i) } = T + have H (x1 : T) : (@ite _ (x = (x1 : UInt8).toNat / 2 ^ (8 - i)) _ (1 : ENNReal) (0 : ENNReal)) = 1 := by + apply ite_eq_iff.mpr + simp + rcases x1 + rename_i h val property + subst HT + simp_all only + simp_all only [Set.mem_setOf_eq] + conv => + enter [1, 1, a] + rw [H a] + clear H + + -- Rewrite to real sum + -- Simplify me + suffices ENNReal.toReal (∑' (_ : T), 1) = ENNReal.toReal (2 ^ (8 - i)) by + refine (ENNReal.toReal_eq_toReal_iff' ?G1 ?G2).mp this + case G1 => + rw [tsum_eq_finsum ?G1] + case G1 => + simp [Function.support] + apply Set.finite_univ_iff.mpr + apply Subtype.finite + simp + have R := @finsum_induction ENNReal T _ (fun _ => 1) (fun z => z ≠ ⊤) (by simp) (by aesop) (by simp) + simp at R + trivial + case G2 => simp + + -- Rewrite to set cardinality + rw [ENNReal.tsum_toReal_eq ?G1] + case G1 => simp + simp [tsum_const] + + + -- Evaluate set cardinality using bijection + -- Simplify me! + rw [@Nat.card_eq_of_equiv_fin T (2^(8 - i)) ?G1] + case G1 => + rw [<- HT] + simp + apply Equiv.ofBijective + case f => + intro v + rcases v with ⟨ v', Hv' ⟩ + exact + ⟨ v'.toNat - x * (2 ^ (8 - i)), + by + have W := (Nat.le_div_iff_mul_le' (by simp)).mp (Eq.le Hv') + have W' := (Nat.div_lt_iff_lt_mul (by simp)).mp (Nat.lt_succ_iff.mpr (Eq.le (Eq.symm Hv'))) + have W'' : v'.toNat - x * 2 ^ (8 - i) < x.succ * 2 ^ (8 - i) - x * 2 ^ (8 - i) := by + exact Nat.sub_lt_sub_right W W' + suffices (x.succ * 2 ^ (8 - i) - x * 2 ^ (8 - i)) ≤ 2 ^ (8 - i) by + exact Nat.lt_of_lt_of_le W'' this + rw [← Nat.sub_mul] + simp ⟩ + apply Function.bijective_iff_has_inverse.mpr + + -- Set up the bijections + -- Is there a simpler way to do this? + exists ?G1 + case G1 => + intro f + rcases f with ⟨ f', Hf' ⟩ + exact + ⟨ UInt8.ofNatCore (f' + x * 2^(8-i)) + (by + rw [UInt8.size] + apply (@LT.lt.trans_le _ _ _ (2^(8-i) + x * 2^(8-i))) + · exact Nat.add_lt_add_right Hf' (x * 2 ^ (8 - i)) + · conv => + enter [1, 1] + rw [<- one_mul (2^(8-i))] + rw [<- add_mul] + have Z : (1 + x) ≤ 2^i := by linarith + have Z' : (1 + x) * (2^(8-i)) ≤ 2^i * (2^(8-i)) := by + exact Nat.mul_le_mul_right (2 ^ (8 - i)) Z + apply le_trans Z' + apply Eq.le + rw [<- pow_add] + have H256 : 256 = 2^8 := by simp + rw [H256] + clear H256 + congr 1 + apply add_sub_of_le + apply le_of_succ_le + trivial + ), + (by + unfold UInt8.ofNatCore + unfold UInt8.toNat + simp + apply (nat_div_eq_le_lt_iff (by simp)).mpr + apply And.intro + · exact Nat.le_add_left (x * 2 ^ (8 - i)) f' + · linarith )⟩ + dsimp [Function.RightInverse, Function.LeftInverse] + apply And.intro + · intro x' + rcases x' with ⟨ ⟨ x'', H2x'' ⟩, Hx'' ⟩ + unfold UInt8.ofNatCore + unfold UInt8.toNat + simp + congr + apply Nat.sub_add_cancel + rw [Hx''] + rw [UInt8.toNat] + apply (Nat.le_div_iff_mul_le (by simp)).mp + simp + · intro x' + rcases x' with ⟨ x'', Hx'' ⟩ + simp [UInt8.ofNatCore] + rw [UInt8.toNat] + simp + simp + · rw [max_eq_right (by linarith)] + rw [min_eq_left (by linarith)] at Hx + rw [tsum_eq_single (UInt8.ofNatCore x Hx) ?G1] + case G1 => + intro b' Hb' + simp + intro Hx' + exfalso + apply Hb' + rcases b' with ⟨ ⟨ b'', Hb'' ⟩ ⟩ + simp [UInt8.ofNatCore] + congr + rw [Hx'] + simp [UInt8.toNat] + simp + intro HK + exfalso + apply HK + rfl + + +/-- +Evaluation of ``probUniformByteUpperBits`` for zero-shifts outside of the support +-/ +def probUniformByteUpperBits_eval_zero {i x : ℕ} (Hx : x ≥ 2 ^ (min 8 i)) : + probUniformByteUpperBits i x = 0 := by + simp [probUniformByteUpperBits] + rw [Nat.sub_eq_max_sub] + simp [SLang.probBind, SLang.probPure, probUniformByte] + intro v H1 + exfalso + cases (Classical.em (i < 8)) + · -- i < 8 + rename_i Hi + rw [max_eq_left (by linarith)] at H1 + rw [min_eq_right (by linarith)] at Hx + simp_all + rw [Nat.shiftRight_eq_div_pow] at * + have H2 := UInt8.toNat_lt v + apply Nat.mul_le_of_le_div (2 ^ (8 - i)) (2 ^ i) v.toNat at Hx + rw [<- pow_add] at Hx + have X : (i + (8 - i)) = 8 := by + apply add_sub_of_le + linarith + rw [X] at Hx + clear X + linarith + · -- i >= 8 + rename_i Hi + rw [max_eq_right (by linarith)] at H1 + rw [min_eq_left (by linarith)] at Hx + have H2 := UInt8.toNat_lt v + simp_all + linarith + + +lemma UIint8_cast_lt_size (a : UInt8) : a.toNat < UInt8.size := by + rcases a with ⟨ ⟨ a', Ha' ⟩ ⟩ + rw [UInt8.toNat] + simp + apply Ha' + + +/-- +Evaluation of ``probUniformP2`` for inside the support +-/ +def probUniformP2_eval_support {i x : ℕ} (Hx : x < 2 ^ i): + probUniformP2 i x = (1 / 2 ^ i) := by + revert x + induction' i using Nat.strong_induction_on with i ih + rw [probUniformP2] + split + · intro x Hx' + rename_i h + rw [probUniformByteUpperBits_eval_support] + · rw [UInt8.size] + have X : 256 = 2^8 := by simp + rw [X] + clear X + rw [cast_pow] + apply (ENNReal.div_eq_div_iff _ _ _ _).mpr <;> try simp + rw [← pow_add] + congr 1 + rw [add_tsub_cancel_iff_le] + linarith + · rw [min_eq_right ?G1] + case G1 => linarith + assumption + · intro x Hx' + simp [probUniformByte] + + -- Simplify, rewrite to indicator function + conv => + enter [1, 1, a] + rw [<- ENNReal.tsum_mul_left] + enter [1, b] + rw [<- mul_one (probUniformP2 (i - 8) b)] + rw [<- mul_zero (probUniformP2 (i - 8) b)] + rw [<- mul_ite] + rw [<- mul_assoc] + + -- Similar to the Laplace proof: use Euclidean division to rewrite + -- to product of indicator functions + rcases @euclidean_division x UInt8.size (by simp) with ⟨ p, q, Hq, Hx ⟩ + have X (a : UInt8) (b : ℕ) D : + (@ite _ (q + UInt8.size * p = UInt8.size * b + a.toNat) D (1 : ENNReal) 0) = + (if p = b then (1 : ENNReal) else 0) * (if q = a.toNat then (1 : ENNReal) else 0) := by + split + · rename_i He + conv at He => + enter [2] + rw [add_comm] + have R := (euclidean_division_uniquness _ _ _ _ (by simp) Hq ?G3).mp He + case G3 => apply UIint8_cast_lt_size + rcases R with ⟨ R1 , R2 ⟩ + simp_all + · rename_i He + suffices (p ≠ b) ∨ (q ≠ a.toNat) by + rcases this with Ht | Ht + · rw [ite_eq_right_iff.mpr] + · simp + · intro Hk + exfalso + apply Ht Hk + · rw [mul_comm] + rw [ite_eq_right_iff.mpr] + · simp + · intro Hk + exfalso + apply Ht Hk + simp + apply (Decidable.not_and_iff_or_not (p = b) (q = a.toNat)).mp + intro HK + apply He + rw [And.comm] at HK + have _ := (euclidean_division_uniquness _ _ _ _ (by simp) Hq ?G3).mpr HK + case G3 => apply UIint8_cast_lt_size + linarith + conv => + enter [1, 1, a, 1, b] + rw [Hx] + rw [X a b] + clear X + + -- Separate the sums + conv => + enter [1, 1, a, 1, b] + repeat rw [mul_assoc] + conv => + enter [1, 1, a] + rw [ENNReal.tsum_mul_left] + rw [ENNReal.tsum_mul_left] + conv => + enter [1, 2, 1, a, 1, b] + rw [<- mul_assoc] + rw [mul_comm] + conv => + enter [1, 2, 1, a] + rw [ENNReal.tsum_mul_left] + conv => + enter [1, 2] + rw [ENNReal.tsum_mul_right] + simp + + -- Simplify the singleton sums + rw [tsum_eq_single p ?G1] + case G1 => + intro _ HK + simp + intro HK' + exfalso + exact HK (id (Eq.symm HK')) + have X : (UInt8.ofNatCore q Hq).toNat = q := by + rw [UInt8.ofNatCore, UInt8.toNat] + rw [tsum_eq_single (UInt8.ofNatCore q Hq) ?G1] + case G1 => + simp + intro b HK' HK'' + apply HK' + rw [UInt8.ofNatCore] + rcases b with ⟨ ⟨ b' , Hb' ⟩ ⟩ + congr + rw [HK''] + rw [UInt8.toNat] + rw [X] + clear X + simp + + -- Apply the IH + rw [ih] + · simp + rw [<- ENNReal.mul_inv ?G1 ?G2] + case G1 => simp + case G2 => simp + congr 1 + have H256 : (256 : ENNReal) = (256 : ℕ) := by simp + rw [H256] + have X : (256 : ℕ) = 2^8 := by simp + rw [X] + rw [cast_pow] + rw [cast_two] + rw [← pow_add] + congr 1 + apply add_sub_of_le + linarith + · simp + linarith + · rw [Hx] at Hx' + have Hx'' : UInt8.size * p < OfNat.ofNat 2 ^ i := by + apply Classical.byContradiction + intro HK + linarith + rw [UInt8.size] at Hx'' + have Y : 256 = 2^8 := by simp + rw [Y] at Hx'' + clear Y + have W := (Nat.lt_div_iff_mul_lt ?G1 _).mpr Hx'' + case G1 => + apply Nat.pow_dvd_pow (OfNat.ofNat 2) + linarith + apply (LT.lt.trans_eq W) + apply Nat.pow_div <;> linarith + +/-- +Evaluation of ``probUniformP2`` for zero-shifts outside of the support +-/ +def probUniformP2_eval_zero {i x : ℕ} (Hx : x ≥ 2 ^ i): + probUniformP2 i x = 0 := by + revert x + induction' i using Nat.strong_induction_on with i ih + intro x Hk + rw [probUniformP2] + split + · apply probUniformByteUpperBits_eval_zero + rw [min_eq_right] + · trivial + · linarith + · simp + intro i1 + right + intro i2 Hi + apply ih + · apply sub_lt + · linarith + · simp + · rw [Hi] at Hk + simp_all + suffices 2^ i ≤ UInt8.size * i2 by + rw [UInt8.size] at this + rw [← Nat.pow_div (by trivial) ?G1] + case G1 => simp + exact Nat.div_le_of_le_mul this + have H : (i1.toNat < UInt8.size) := by exact UIint8_cast_lt_size i1 + + -- Establish this bound by the uniqueness of Euclidean division + rcases @euclidean_division (2^i) (2^8) (by simp) with ⟨ p, q, Hq, H ⟩ + have Hple : (p ≤ i2) := by linarith + have Heuc' : q + 2 ^ 8 * p = 0 + 2 ^ 8 * (2 ^ (i - 8)) := by + rw [<- H] + rw [zero_add] + rw [<- pow_add] + congr 1 + symm + apply add_sub_of_le + trivial + have W := (euclidean_division_uniquness _ _ _ _ (by simp) Hq (by simp)).mp Heuc' + simp_all + end SLang diff --git a/SampCert/Foundations/UniformP2.lean b/SampCert/Foundations/UniformP2.lean index ed0563e1..ebe1424e 100644 --- a/SampCert/Foundations/UniformP2.lean +++ b/SampCert/Foundations/UniformP2.lean @@ -9,6 +9,7 @@ import Mathlib.Data.Nat.Log import SampCert.Util.Util import SampCert.Foundations.Monad import SampCert.Foundations.Auto +import SampCert.Foundations.UniformByte /-! # ``probUniformP2`` Properties @@ -17,105 +18,30 @@ This file contains lemmas about ``probUniformP2``, a ``SLang`` sampler for the uniform distribution on spaces whose size is a power of two. -/ - open Classical Nat PMF namespace SLang -@[simp] -lemma sum_indicator_finrange_gen (n : Nat) (x : Nat) : - (x < n → (∑' (i : Fin n), @ite ENNReal (x = ↑i) (propDecidable (x = ↑i)) 1 0) = (1 : ENNReal)) - ∧ (x >= n → (∑' (i : Fin n), @ite ENNReal (x = ↑i) (propDecidable (x = ↑i)) 1 0) = (0 : ENNReal)) := by - revert x - induction n - . intro x - simp - . rename_i n IH - intro x - constructor - . intro cond - have OR : x = n ∨ x < n := by exact Order.lt_succ_iff_eq_or_lt.mp cond - cases OR - . rename_i cond' - have IH' := IH x - cases IH' - rename_i left right - have cond'' : x ≥ n := by exact Nat.le_of_eq (id cond'.symm) - have right' := right cond'' - rw [tsum_fintype] at * - rw [Fin.sum_univ_castSucc] - simp [right'] - simp [cond'] - . rename_i cond' - have IH' := IH x - cases IH' - rename_i left right - have left' := left cond' - rw [tsum_fintype] at * - rw [Fin.sum_univ_castSucc] - simp [left'] - have neq : x ≠ n := by exact Nat.ne_of_lt cond' - simp [neq] - . intro cond - have succ_gt : x ≥ n := by exact lt_succ.mp (le.step cond) - have IH' := IH x - cases IH' - rename_i left right - have right' := right succ_gt - rw [tsum_fintype] - rw [Fin.sum_univ_castSucc] - simp - constructor - . simp at right' - intro x' - apply right' x' - . have neq : x ≠ n := by exact Nat.ne_of_gt cond - simp [neq] - - -/-- -Computes the sum of an indicator variable (indicating inside the support of ``Fin n``) over the space ``Fin n``. --/ -theorem sum_indicator_finrange (n : Nat) (x : Nat) (h : x < n) : - (∑' (i : Fin n), @ite ENNReal (x = ↑i) (propDecidable (x = ↑i)) 1 0) = (1 : ENNReal) := by - have H := sum_indicator_finrange_gen n x - cases H - rename_i left right - apply left - trivial - /-- Evaluates the ``probUniformP2`` distribution at a point inside of its support. -/ @[simp] -theorem probUniformP2_apply (n : PNat) (x : Nat) (h : x < 2 ^ (log 2 n)) : +theorem UniformPowerOfTwoSample_apply (n : PNat) (x : Nat) (h : x < 2 ^ (log 2 n)) : (UniformPowerOfTwoSample n) x = 1 / (2 ^ (log 2 n)) := by - simp only [UniformPowerOfTwoSample, Lean.Internal.coeM, Bind.bind, Pure.pure, CoeT.coe, - CoeHTCT.coe, CoeHTC.coe, CoeOTC.coe, CoeOut.coe, toSLang_apply, PMF.bind_apply, - uniformOfFintype_apply, Fintype.card_fin, cast_pow, cast_ofNat, PMF.pure_apply, one_div] - rw [ENNReal.tsum_mul_left] - rw [sum_indicator_finrange (2 ^ (log 2 n)) x] - . simp - . trivial + simp [UniformPowerOfTwoSample] + rw [probUniformP2_eval_support] + · simp + trivial /-- Evaluates the ``probUniformP2`` distribution at a point outside of its support -/ @[simp] -theorem probUniformP2_apply' (n : PNat) (x : Nat) (h : x ≥ 2 ^ (log 2 n)) : +theorem UniformPowerOfTwoSample_apply' (n : PNat) (x : Nat) (h : x ≥ 2 ^ (log 2 n)) : UniformPowerOfTwoSample n x = 0 := by simp [UniformPowerOfTwoSample] - intro i - cases i - rename_i i P - simp only - have A : i < 2 ^ log 2 ↑n ↔ ¬ i ≥ 2 ^ log 2 ↑n := by exact lt_iff_not_le - rw [A] at P - simp at P - by_contra CONTRA - subst CONTRA - replace A := A.1 P - contradiction + rw [probUniformP2_eval_zero] + trivial lemma if_simpl_up2 (n : PNat) (x x_1: Fin (2 ^ log 2 ↑n)) : (@ite ENNReal (x_1 = x) (propDecidable (x_1 = x)) 0 (@ite ENNReal ((@Fin.val (2 ^ log 2 ↑n) x) = (@Fin.val (2 ^ log 2 ↑n) x_1)) (propDecidable ((@Fin.val (2 ^ log 2 ↑n) x) = (@Fin.val (2 ^ log 2 ↑n) x_1))) 1 0)) = 0 := by @@ -135,32 +61,23 @@ lemma if_simpl_up2 (n : PNat) (x x_1: Fin (2 ^ log 2 ↑n)) : /-- The ``SLang`` term ``uniformPowerOfTwo`` is a proper distribution on ``ℕ``. -/ -theorem probUniformP2_normalizes (n : PNat) : +theorem UniformPowerOfTwoSample_normalizes (n : PNat) : ∑' i : ℕ, UniformPowerOfTwoSample n i = 1 := by + rw [UniformPowerOfTwoSample] rw [← @sum_add_tsum_nat_add' _ _ _ _ _ _ (2 ^ (log 2 n))] - . simp only [ge_iff_le, le_add_iff_nonneg_left, _root_.zero_le, probUniformP2_apply', - tsum_zero, add_zero] - simp only [UniformPowerOfTwoSample, Lean.Internal.coeM, Bind.bind, Pure.pure, CoeT.coe, - CoeHTCT.coe, CoeHTC.coe, CoeOTC.coe, CoeOut.coe, toSLang_apply, PMF.bind_apply, - uniformOfFintype_apply, Fintype.card_fin, cast_pow, cast_ofNat, PMF.pure_apply] - rw [Finset.sum_range] + · rw [Finset.sum_range] conv => - left - right - intro x - rw [ENNReal.tsum_mul_left] - rw [ENNReal.tsum_eq_add_tsum_ite x] - right - right - right - intro x_1 - rw [if_simpl_up2] + enter [1] + congr + · enter [2, a] + skip + rw [probUniformP2_eval_support (by exact a.isLt)] + · enter [1, a] + rw [probUniformP2_eval_zero (by exact Nat.le_add_left (2 ^ log 2 ↑n) a)] simp - rw [ENNReal.inv_pow] - rw [← mul_pow] - rw [two_mul] - rw [ENNReal.inv_two_add_inv_two] - rw [one_pow] - . exact ENNReal.summable + apply ENNReal.mul_inv_cancel + · simp + · simp + exact ENNReal.summable end SLang diff --git a/SampCert/SLang.lean b/SampCert/SLang.lean index 7f54602f..5566b5fb 100644 --- a/SampCert/SLang.lean +++ b/SampCert/SLang.lean @@ -83,13 +83,31 @@ Uniform distribution on a byte @[extern "prob_UniformByte"] def probUniformByte : SLang UInt8 := (fun _ => 1 / UInt8.size) +/-- +Upper i bits from a unifomly sampled byte +-/ +def probUniformByteUpperBits (i : ℕ) : SLang ℕ := do + let w <- probUniformByte + return w.toNat.shiftRight (8 - i) + +/-- +Uniform distribution on the set [0, 2^i) ⊆ ℕ +-/ +def probUniformP2 (i : ℕ) : SLang ℕ := + if (i < 8) + then probUniformByteUpperBits i + else do + let v <- probUniformByte + let w <- probUniformP2 (i - 8) + return UInt8.size * w + v.toNat + /-- ``SLang`` value for the uniform distribution over ``m`` elements, where the number``m`` is the largest power of two that is at most ``n``. -/ -@[extern "prob_UniformP2"] def UniformPowerOfTwoSample (n : ℕ+) : SLang ℕ := - toSLang (PMF.uniformOfFintype (Fin (2 ^ (log 2 n)))) + probUniformP2 (log 2 n) + /-- ``SLang`` functional which executes ``body`` only when ``cond`` is ``false``. diff --git a/SampCert/Samplers/Laplace/Properties.lean b/SampCert/Samplers/Laplace/Properties.lean index 1c0b9fb1..767a9409 100644 --- a/SampCert/Samplers/Laplace/Properties.lean +++ b/SampCert/Samplers/Laplace/Properties.lean @@ -3,6 +3,7 @@ Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. Released under Apache 2.0 license as described in the file LICENSE. Authors: Jean-Baptiste Tristan -/ +import SampCert.Util.Util import SampCert.Foundations.Basic import SampCert.Samplers.Uniform.Basic import SampCert.Samplers.Bernoulli.Basic @@ -1075,24 +1076,6 @@ lemma partial_geometric_series {p : ENNReal} (HP2 : p < 1) (B : ℕ) : rfl -lemma nat_div_eq_le_lt_iff {a b c : ℕ} (Hc : 0 < c) : a = b / c <-> (a * c ≤ b ∧ b < (a + 1) * c) := by - apply Iff.intro - · intro H - apply And.intro - · apply (Nat.le_div_iff_mul_le Hc).mp - exact Nat.le_of_eq H - · apply (Nat.div_lt_iff_lt_mul Hc).mp - apply Nat.lt_succ_iff.mpr - exact Nat.le_of_eq (id (Eq.symm H)) - · intro ⟨ H1, H2 ⟩ - apply LE.le.antisymm - · apply (Nat.le_div_iff_mul_le Hc).mpr - apply H1 - · apply Nat.lt_succ_iff.mp - simp - apply (Nat.div_lt_iff_lt_mul Hc).mpr - apply H2 - /-- Integer division of a geometric distribution is a geometric distribution -/ @@ -1287,53 +1270,6 @@ lemma geo_div_geo (k n : ℕ) (p : ENNReal) (Hp : p < 1) (Hn : 0 < n) : exact succ_mul k n -/-- -Specialize Euclidean division from ℤ to ℕ --/ -lemma euclidean_division (n : ℕ) {D : ℕ} (HD : 0 < D) : - ∃ q r : ℕ, (r < D) ∧ n = r + D * q := by - exists (n / D) - exists (n % D) - apply And.intro - · exact mod_lt n HD - · apply ((@Nat.cast_inj ℤ).mp) - simp - conv => - lhs - rw [<- EuclideanDomain.mod_add_div (n : ℤ) (D : ℤ)] - -/-- -Euclidiean division is unique --/ -lemma euclidean_division_uniquness (r1 r2 q1 q2 : ℕ) {D : ℕ} (HD : 0 < D) (Hr1 : r1 < D) (Hr2 : r2 < D) : - r1 + D * q1 = r2 + D * q2 <-> (r1 = r2 ∧ q1 = q2) := by - apply Iff.intro - · intro H - cases (Classical.em (r1 = r2)) - · aesop - cases (Classical.em (q1 = q2)) - · aesop - rename_i Hne1 Hne2 - exfalso - - have Contra1 (W X Y Z : ℕ) (HY : Y < D) (HK : W < X) : (Y + D * W < Z + D * X) := by - suffices (D * W < D * X) by - have A : (1 + W ≤ X) := by exact one_add_le_iff.mpr HK - have _ : (D * (1 + W) ≤ D * X) := by exact Nat.mul_le_mul_left D A - have _ : (D + D * W ≤ D * X) := by linarith - have _ : (Y + D * W < D * X) := by linarith - have _ : (Y + D * W < Z + D * X) := by linarith - assumption - exact Nat.mul_lt_mul_of_pos_left HK HD - - rcases (lt_trichotomy q1 q2) with HK' | ⟨ HK' | HK' ⟩ - · exact (LT.lt.ne (Contra1 q1 q2 r1 r2 Hr1 HK') H) - · exact Hne2 HK' - · apply (LT.lt.ne (Contra1 q2 q1 r2 r1 Hr2 HK') (Eq.symm H)) - - · intro ⟨ _, _ ⟩ - simp_all - /-- Equivalence between sampling loops -/ diff --git a/SampCert/Samplers/Uniform/Properties.lean b/SampCert/Samplers/Uniform/Properties.lean index 2938809b..c108628c 100644 --- a/SampCert/Samplers/Uniform/Properties.lean +++ b/SampCert/Samplers/Uniform/Properties.lean @@ -79,7 +79,7 @@ lemma rw_ite (n : PNat) (x : Nat) : (if x < n then (UniformPowerOfTwoSample (2 * n)) x else 0) = if x < n then 1 / 2 ^ log 2 ((2 : PNat) * n) else 0 := by split - rw [probUniformP2_apply] + rw [UniformPowerOfTwoSample_apply] simp only [PNat.mul_coe, one_div] apply double_large_enough trivial @@ -100,7 +100,7 @@ lemma uniformPowerOfTwoSample_autopilot (n : PNat) : = ∑' (i : ℕ), if i < ↑n then UniformPowerOfTwoSample (2 * n) i else 0 := by have X : (∑' (i : ℕ), if decide (↑n ≤ i) = true then UniformPowerOfTwoSample (2 * n) i else 0) + (∑' (i : ℕ), if decide (↑n ≤ i) = false then UniformPowerOfTwoSample (2 * n) i else 0) = 1 := by - have A := probUniformP2_normalizes (2 * n) + have A := UniformPowerOfTwoSample_normalizes (2 * n) have B := @tsum_add_tsum_compl ENNReal ℕ _ _ (fun i => UniformPowerOfTwoSample (2 * n) i) _ _ { i : ℕ | decide (↑n ≤ i) = true} ENNReal.summable ENNReal.summable rw [A] at B clear A @@ -114,7 +114,7 @@ lemma uniformPowerOfTwoSample_autopilot (n : PNat) : trivial apply ENNReal.sub_eq_of_eq_add_rev . have Y := tsum_split_less (fun i => ↑n ≤ i) (fun i => UniformPowerOfTwoSample (2 * n) i) - rw [probUniformP2_normalizes (2 * n)] at Y + rw [UniformPowerOfTwoSample_normalizes (2 * n)] at Y simp at Y clear X by_contra diff --git a/SampCert/Util/Util.lean b/SampCert/Util/Util.lean index e4ff921f..936e3630 100644 --- a/SampCert/Util/Util.lean +++ b/SampCert/Util/Util.lean @@ -252,3 +252,69 @@ theorem tsum_shift'_2 (f : ℕ → ENNReal) : right rw [sum_range_succ] rw [← IH] + +/-- +Specialize Euclidean division from ℤ to ℕ +-/ +lemma euclidean_division (n : ℕ) {D : ℕ} (HD : 0 < D) : + ∃ q r : ℕ, (r < D) ∧ n = r + D * q := by + exists (n / D) + exists (n % D) + apply And.intro + · exact mod_lt n HD + · apply ((@Nat.cast_inj ℤ).mp) + simp + conv => + lhs + rw [<- EuclideanDomain.mod_add_div (n : ℤ) (D : ℤ)] + +/-- +Euclidiean division is unique +-/ +lemma euclidean_division_uniquness (r1 r2 q1 q2 : ℕ) {D : ℕ} (HD : 0 < D) (Hr1 : r1 < D) (Hr2 : r2 < D) : + r1 + D * q1 = r2 + D * q2 <-> (r1 = r2 ∧ q1 = q2) := by + apply Iff.intro + · intro H + cases (Classical.em (r1 = r2)) + · aesop + cases (Classical.em (q1 = q2)) + · aesop + rename_i Hne1 Hne2 + exfalso + + have Contra1 (W X Y Z : ℕ) (HY : Y < D) (HK : W < X) : (Y + D * W < Z + D * X) := by + suffices (D * W < D * X) by + have A : (1 + W ≤ X) := by exact one_add_le_iff.mpr HK + have _ : (D * (1 + W) ≤ D * X) := by exact Nat.mul_le_mul_left D A + have _ : (D + D * W ≤ D * X) := by linarith + have _ : (Y + D * W < D * X) := by linarith + have _ : (Y + D * W < Z + D * X) := by linarith + assumption + exact Nat.mul_lt_mul_of_pos_left HK HD + + rcases (lt_trichotomy q1 q2) with HK' | ⟨ HK' | HK' ⟩ + · exact (LT.lt.ne (Contra1 q1 q2 r1 r2 Hr1 HK') H) + · exact Hne2 HK' + · apply (LT.lt.ne (Contra1 q2 q1 r2 r1 Hr2 HK') (Eq.symm H)) + + · intro ⟨ _, _ ⟩ + simp_all + + +lemma nat_div_eq_le_lt_iff {a b c : ℕ} (Hc : 0 < c) : a = b / c <-> (a * c ≤ b ∧ b < (a + 1) * c) := by + apply Iff.intro + · intro H + apply And.intro + · apply (Nat.le_div_iff_mul_le Hc).mp + exact Nat.le_of_eq H + · apply (Nat.div_lt_iff_lt_mul Hc).mp + apply Nat.lt_succ_iff.mpr + exact Nat.le_of_eq (id (Eq.symm H)) + · intro ⟨ H1, H2 ⟩ + apply LE.le.antisymm + · apply (Nat.le_div_iff_mul_le Hc).mpr + apply H1 + · apply Nat.lt_succ_iff.mp + simp + apply (Nat.div_lt_iff_lt_mul Hc).mpr + apply H2 diff --git a/ffi.cpp b/ffi.cpp index 2a9cebf5..8b3775b3 100644 --- a/ffi.cpp +++ b/ffi.cpp @@ -8,54 +8,15 @@ Authors: Jean-Baptiste Tristan #include #include -#ifdef __APPLE__ - std::random_device generator; -#else - std::mt19937_64 generator(time(NULL)); -#endif +static int urandom = -1; extern "C" lean_object * prob_UniformByte (lean_object * eta) { lean_dec(eta); unsigned char r; - int urandom = open("/dev/urandom", O_RDONLY | O_CLOEXEC); - if (urandom == -1) { - lean_internal_panic("prob_UniformByte: /dev/urandom cannot be opened"); - } read(urandom, &r,1); - close(urandom); return lean_box((size_t) r); } - -extern "C" lean_object * prob_UniformP2(lean_object * a, lean_object * eta) { - lean_dec(eta); - if (lean_is_scalar(a)) { - size_t n = lean_unbox(a); - if (n == 0) { - lean_internal_panic("prob_UniformP2: n == 0"); - } else { - int lz = std::__countl_zero(n); - int bitlength = (8*sizeof n) - lz - 1; - size_t bound = 1 << bitlength; - std::uniform_int_distribution distribution(0,bound-1); - size_t r = distribution(generator); - lean_dec(a); - return lean_box(r); - } - } else { - lean_object * res = lean_usize_to_nat(0); - do { - a = lean_nat_sub(a,lean_box(LEAN_MAX_SMALL_NAT)); - std::uniform_int_distribution distribution(0,LEAN_MAX_SMALL_NAT-1); - size_t rdm = distribution(generator); - lean_object * acc = lean_usize_to_nat(rdm); - res = lean_nat_add(res,acc); - } while(lean_nat_le(lean_box(LEAN_MAX_SMALL_NAT),a)); - lean_object * rem = prob_UniformP2(a,lean_box(0)); - return lean_nat_add(res,rem); - } -} - extern "C" lean_object * prob_Pure(lean_object * a, lean_object * eta) { lean_dec(eta); return a; @@ -85,6 +46,12 @@ extern "C" lean_object * prob_While(lean_object * condition, lean_object * body, } extern "C" lean_object * my_run(lean_object * a) { + if (urandom == -1) { + urandom = open("/dev/urandom", O_RDONLY | O_CLOEXEC); + if (urandom == -1) { + lean_internal_panic("prob_UniformByte: /dev/urandom cannot be opened"); + } + } lean_object * comp = lean_apply_1(a,lean_box(0)); lean_object * res = lean_io_result_mk_ok(comp); return res; diff --git a/lakefile.lean b/lakefile.lean index c4058faa..47a381fa 100644 --- a/lakefile.lean +++ b/lakefile.lean @@ -44,3 +44,9 @@ lean_exe check where root := `SampCertCheck extraDepTargets := #[`libleanffi] moreLinkArgs := #["-L.lake/build/lib", "-lleanffi"] + +lean_exe mk_all where + root := `mk_all + supportInterpreter := true + -- Executables which import `Lake` must set `-lLake`. + weakLinkArgs := #["-lLake"]