diff --git a/libs/lospecs/hlaig.ml b/libs/lospecs/hlaig.ml index bd8e56c0f..94c9473ef 100644 --- a/libs/lospecs/hlaig.ml +++ b/libs/lospecs/hlaig.ml @@ -59,7 +59,9 @@ end (* Assumes circuit inputs have already been appropriately renamed *) module MakeSMTInterface(SMT: SMTInstance) : SMTInterface = struct let circ_equiv (r1 : Aig.reg) (r2 : Aig.reg) (pcond : Aig.node) (inps: (int * int) list) : bool = - assert ((List.compare_length_with r1 0 > 0) && (List.compare_length_with r2 0 > 0)); + if not ((List.compare_length_with r1 0 > 0) && (List.compare_length_with r2 0 > 0)) then + (Format.eprintf "Sizes differ in circ_equiv"; false) + else let bvvars : SMT.bvterm Map.String.t ref = ref Map.String.empty in let rec bvterm_of_node : Aig.node -> SMT.bvterm = diff --git a/src/ecCircuits.ml b/src/ecCircuits.ml index 307e4876e..5cadf6d70 100644 --- a/src/ecCircuits.ml +++ b/src/ecCircuits.ml @@ -8,6 +8,7 @@ open EcAst open EcCoreFol open EcIdent open LDecl +open EcCoreGoal (* -------------------------------------------------------------------- *) module Map = Batteries.Map @@ -47,6 +48,8 @@ let size_of_asize (sz : asize) : int = let size_of_tpsize (sz : tpsize) : int = sz.wordsize * sz.npos +exception CircError of string + (* type deps = ((int * int) * int C.VarRange.t) list *) (* Inputs to circuit functions: Either bitstring of fixed size @@ -87,11 +90,15 @@ let is_bwainput = function let destr_bwinput = function | BWInput (idn, w) -> (idn, w) - | _ -> assert false + | _ -> raise (CircError "destr_bwinput") let destr_bwainput = function | BWAInput (idn, sz) -> (idn, sz) - | _ -> assert false + | _ -> raise (CircError "destr_bwainput") + +let destr_tpinput = function + | BWTInput (idn, sz) -> (idn, sz) + | _ -> raise (CircError "destr_tpinput") let bwinput_of_size (w : width) : cinput = let name = "bw_input" in @@ -101,6 +108,10 @@ let bwainput_of_size ~(nelements : width) ~(wordsize : width) : cinput = let name = "arr_input" in BWAInput (create name, { nelements; wordsize; }) +let bwtpinput_of_size ~(npos : width) ~(wordsize : width) : cinput = + let name = "arr_input" in + BWTInput (create name, { npos; wordsize; }) + (* # of total bits of input *) let size_of_cinput = function | BWInput (_, w) -> w @@ -133,15 +144,15 @@ let is_bwtuple = function let destr_bwcirc = function | BWCirc r -> r - | _ -> assert false + | _ -> raise (CircError "destr_bwcirc") let destr_bwarray = function | BWArray a -> a - | _ -> assert false + | _ -> raise (CircError "destr_bwarray") let destr_bwtuple = function | BWTuple tp -> tp - | _ -> assert false + | _ -> raise (CircError "destr_bwtuple") (* # of total bits of output *) let size_of_circ = function @@ -246,17 +257,20 @@ let match_arg (inp: cinput) (val_: circ) : bool = (* Fully applies a function to a list of constant arguments returning a constant value + THROWS: CircError on failure, should always be caught *) let apply (f: circuit) (args: circ list) : circ = let () = try assert (List.compare_lengths f.inps args = 0); assert (List.for_all2 match_arg f.inps args); - with Assert_failure _ as e -> - Format.eprintf "%s@." (Printexc.get_backtrace ()); - Format.eprintf "Error applying on %s@." (circuit_to_string f); - Format.eprintf "Arguments: @."; - List.iter (Format.eprintf "%s@.") (List.map circ_to_string args); - raise e + with Assert_failure _ as _e -> + let err = Format.asprintf + "Backtrace: %s@.\ + Error applying on %s@.\ + Arguments: @.%a@." (Printexc.get_backtrace ()) + (circuit_to_string f) + (fun fmt args -> List.iter (Format.fprintf fmt "%s@.") args) (List.map circ_to_string args) in + raise (CircError err) in let args = List.combine f.inps args in let map_ = fun (id, i) -> @@ -279,7 +293,16 @@ let apply (f: circuit) (args: circ list) : circ = | BWTInput (_, sz), BWTuple tp -> let it, iw = (i / sz.wordsize), (i mod sz.wordsize) in Option.bind (List.at_opt tp it) (fun l -> List.at_opt l iw) - | _ -> assert false + | _ -> + let err = Format.asprintf "Backtrace: %s@.\ + Error applying on %s@.\ + Arguments: @.%a@.\ + Mismatch between argument types.@." + (Printexc.get_backtrace ()) + (circuit_to_string f) + (fun fmt args -> List.iter (Format.fprintf fmt "%s@.") args) + (List.map circ_to_string (List.snd args)) + in raise (CircError err) ) in match f.circ with @@ -321,7 +344,6 @@ let dist_inputs (c: circuit list) : circuit list = | c::cs -> c::(doit cs (Set.of_list c.inps)) (* -------------------------------------------------------------------- *) -exception CircError of string let width_of_type (env: env) (t: ty) : int = match EcEnv.Circuit.lookup_array_and_bitstring env t with @@ -343,9 +365,15 @@ let shape_of_array_type (env: env) (t: ty) : (int * int) = | Tconstr (p, [et]) -> begin match EcEnv.Circuit.lookup_array_path env p with | Some {size; _} -> size, width_of_type env et - | None -> assert false + | None -> + let err = Format.asprintf "Failed to lookup shape of array type %a@." + (EcPrinting.pp_type (EcPrinting.PPEnv.ofenv env)) t in + raise (CircError err) end - | _ -> assert false + | _ -> + let err = Format.asprintf "Failed to lookup shape of array type %a@." + (EcPrinting.pp_type (EcPrinting.PPEnv.ofenv env)) t in + raise (CircError err) (* Given an EC type with the correct bindings returns a circuit input matching the shape of that type *) @@ -355,10 +383,9 @@ let cinput_of_type ?(idn: ident option) (env: env) (t: ty) : cinput = | Some idn -> idn | None -> create name in - match destr_array_type env t with + match EcEnv.Circuit.lookup_array_and_bitstring env t with | None -> BWInput (idn, width_of_type env t) - | Some (nelements, t) -> - let wordsize = width_of_type env t in + | Some ({size=nelements}, {size=wordsize}) -> BWAInput (idn, { nelements; wordsize }) (* given f(inps1), g(inps2) returns h(inps1,inps2) = f(a) @ g(b) @@ -368,13 +395,13 @@ let circuit_concat (c: circuit) (d: circuit) : circuit = match c.circ, d.circ with | BWCirc ccirc, BWCirc dcirc -> {circ=BWCirc(ccirc @ dcirc); inps=c.inps} - | _ -> assert false + | _ -> raise (CircError "concat") else let d = if inputs_indep [c;d] then d else fresh_inputs d in match c.circ, d.circ with | BWCirc ccirc, BWCirc dcirc -> {circ=BWCirc(ccirc @ dcirc); inps=c.inps @ d.inps} - | _ -> assert false + | _ -> raise (CircError "concat") (* Same as above but concatenates arrays of bitwords *) let circuit_array_concat (c: circuit) (d: circuit) : circuit = @@ -382,7 +409,7 @@ let circuit_array_concat (c: circuit) (d: circuit) : circuit = match c.circ, d.circ with | BWArray carr, BWArray darr -> {circ=BWArray(Array.concat [carr; darr]); inps=c.inps @ d.inps} - | _ -> assert false + | _ -> raise (CircError "array concat") let (++) : circuit -> circuit -> circuit = circuit_concat let (+@) : circuit -> circuit -> circuit = circuit_array_concat @@ -399,7 +426,8 @@ let circuit_array_aggregate (c: circuit list) : circuit = (* To be removed and replaced by a combination of other operations *) let circuit_bwarray_set ~(nelements : width) ~(wordsize : width) (i: int) : circuit = - assert (nelements > i); + (* Index guarantee should come from EC *) + (* assert (nelements > i); *) let arr_inp = BWAInput (create "arr_input", { nelements; wordsize; }) in let bw_inp = BWInput (create "bw_input", wordsize) in let arr_id = (ident_of_cinput arr_inp).id_tag in @@ -411,7 +439,7 @@ let circuit_bwarray_set ~(nelements : width) ~(wordsize : width) (i: int) : circ (* Same as above *) let circuit_bwarray_get ~(nelements : width) ~(wordsize : width) (i: int) : circuit = - assert (nelements > i); + (* assert (nelements > i); *) let arr_inp = BWAInput (create "arr_input", { nelements; wordsize; }) in let out = List.init wordsize (fun j -> C.input ((ident_of_cinput arr_inp).id_tag, j + wordsize*i)) in {circ=BWCirc (out); inps=[arr_inp]} @@ -420,18 +448,33 @@ let circuit_bwarray_get ~(nelements : width) ~(wordsize : width) (i: int) : circ (* Function composition for circuits *) (* Reduces to application if arguments are 0-ary *) +(* + THROWS: CircError on failure (from apply calls) +*) let compose (f: circuit) (args: circuit list) : circuit = (* assert (List.compare_lengths f.inps args = 0); *) (* Length comparison should be done in apply *) let args = dist_inputs args in - {circ=apply f (List.map (fun c -> c.circ) args); - inps=List.fold_right (@) (List.map (fun c -> c.inps) args) []} + try + {circ=apply f (List.map (fun c -> c.circ) args); + inps=List.fold_right (@) (List.map (fun c -> c.inps) args) []} + with CircError err -> + raise (CircError ("On compose call, apply failed with err:\n" ^ err) ) (* FIXME: convert computation to return BI.zint *) let compute ~(sign:bool) (f: circuit) (r: BI.zint list) : int = - assert (List.compare_lengths f.inps r = 0); + (* FIXME: can we remove the parenthesis around the try/with block? *) + (try + assert (List.compare_lengths f.inps r = 0) + with Assert_failure _ -> + let err = Format.asprintf + "Wrong # of arguments (%d provided, %d expected) for compute" + (List.length r) + (List.length f.inps) + in + raise (CircError err)); let vs = List.map2 (fun inp r -> let _, size = destr_bwinput inp in BWCirc(C.of_bigint_all ~size (BI.to_zt r)) @@ -441,8 +484,9 @@ let compute ~(sign:bool) (f: circuit) (r: BI.zint list) : int = let res = List.map (function | {C.gate = C.False; C.id = id} -> if id >= 0 then false else true - | _ -> assert false + | _ -> raise (CircError "Non-constant result in compute (not fully applied?)") ) res in + (* conversion functions need to be reworked FIXME *) if sign then C.sint_of_bools res else @@ -479,22 +523,26 @@ let circuit_flatten (c: circuit) : circuit = | BWCirc _ -> c | BWArray a -> {circ=BWCirc(Array.fold_right (@) a []); inps=c.inps} - | BWTuple _ -> assert false + | BWTuple _ -> raise (CircError "Cannot flatten tuple") (* Chunks a bitstring into an array of bitstrings, each of size w *) let circuit_bw_split (c: circuit) (w: int) : circuit = match c.circ with - | BWArray _ -> assert false - | BWTuple _ -> assert false + | BWArray _ -> raise (CircError "Cannot chunk array") + | BWTuple _ -> raise (CircError "Cannot chunk tuple") | BWCirc r -> let nk = List.length r in - assert (nk mod w = 0); - let rs = List.chunkify w r |> Array.of_list in - {circ=BWArray rs; inps = c.inps} + if (nk mod w = 0) then + let rs = List.chunkify w r |> Array.of_list in + {circ=BWArray rs; inps = c.inps} + else + let err = Format.asprintf "Size of circ (%d) not evenly divided by chunk size (%d)" nk w in + raise (CircError err) (* Zero-extends a bitstring *) let circuit_bw_zeroextend (c: circuit) (w: int) : circuit = - assert(size_of_circ c.circ <= w); + (* FIXME: default behaviour when size of extenion < cur size or EC catches that case? *) + (* assert(size_of_circ c.circ <= w); *) let r = destr_bwcirc c.circ in let zs = List.init (w - size_of_circ c.circ) (fun _ -> C.true_) in {c with circ = BWCirc(r @ zs)} @@ -509,14 +557,17 @@ let bus_of_cinputs (inps: cinput list) : circ list * cinput = let rec doit (r: C.reg) (cs: cinput list) : circ list = match r, cs with | [], [] -> [] - | [], _ -> assert false + | [], _ | _, [] -> assert false - | _, BWInput (_, w)::cs -> let r1, r2 = List.takedrop w r in + | r, BWInput (_, w)::cs -> let r1, r2 = List.takedrop w r in (BWCirc r1)::(doit r2 cs) - | _, BWAInput (_, sz)::cs -> let r1, r2 = List.takedrop (size_of_asize sz) r in + | r, BWAInput (_, sz)::cs -> let r1, r2 = List.takedrop (size_of_asize sz) r in let r1 = List.chunkify sz.wordsize r1 |> Array.of_list in (BWArray r1)::(doit r2 cs) - | _ -> assert false (* FIXME: This catches the tuple case, check if doesnt cause issues *) + | r, BWTInput (_, sz)::cs -> + let r1, r2 = List.takedrop (size_of_tpsize sz) r in + let r1 = List.chunkify sz.wordsize r1 in + (BWTuple r1)::(doit r2 cs) in doit r inps, BWInput (idn, bsize) @@ -525,12 +576,8 @@ let bus_of_cinputs (inps: cinput list) : circ list * cinput = let circuit_aggregate_inps (c: circuit) : circuit = match c.inps with | [] -> c - | _inps -> - (* Format.eprintf "Previous inputs: "; *) - (* List.iter (Format.eprintf "%s |") (List.map (cinput_to_string) inps); *) - (* Format.eprintf "@."; *) - let circs, inp = bus_of_cinputs c.inps in - (* Format.eprintf "Aggregating inputs to input: %s@." (cinput_to_string inp); *) + | inps -> + let circs, inp = bus_of_cinputs inps in {circ=apply c circs; inps=[inp]} (* @@ -539,7 +586,8 @@ let circuit_aggregate_inps (c: circuit) : circuit = i = start_index *) let circuit_array_sliceget ~(wordsize : width) (arr_sz : width) (out_sz: width) (i: int) : circuit = - assert (arr_sz >= out_sz + i); + (* FIXME: Should be caught on EC side *) + (* assert (arr_sz >= out_sz + i); *) let arr_inp = bwainput_of_size ~nelements:arr_sz ~wordsize in let arr_id = (ident_of_cinput arr_inp).id_tag in let out = Array.init out_sz (fun ja -> @@ -555,7 +603,8 @@ let circuit_array_sliceget ~(wordsize : width) (arr_sz : width) (out_sz: width) i = start_index *) let circuit_array_sliceset ~(wordsize : width) (arr_sz : width) (out_sz: width) (i: int) : circuit = - assert (arr_sz >= out_sz + i); + (* FIXME: Should be caught on EC side *) + (* assert (arr_sz >= out_sz + i); *) let arr_inp = bwainput_of_size ~nelements:arr_sz ~wordsize in let arr_id = (ident_of_cinput arr_inp).id_tag in let new_arr_inp = bwainput_of_size ~nelements:out_sz ~wordsize in @@ -570,7 +619,8 @@ let circuit_array_sliceset ~(wordsize : width) (arr_sz : width) (out_sz: width) (* To be removed when we have external op bindings *) let circuit_bwarray_slice_get (arr_sz: width) (el_sz: width) (acc_sz: int) (i: int) : circuit = - assert (arr_sz*el_sz >= i + acc_sz); + (* FIXME: Should be caught on EC side *) + (* assert (arr_sz*el_sz >= i + acc_sz); *) let arr_inp = bwainput_of_size ~nelements:arr_sz ~wordsize:el_sz in let arr_id = (ident_of_cinput arr_inp).id_tag in let out = List.init acc_sz (fun j -> C.input (arr_id, i+j)) in @@ -578,7 +628,8 @@ let circuit_bwarray_slice_get (arr_sz: width) (el_sz: width) (acc_sz: int) (i: i (* To be removed when we have external op bindings *) let circuit_bwarray_slice_set (arr_sz: width) (el_sz: width) (acc_sz: int) (i: int) : circuit = - assert (arr_sz*el_sz >= i + acc_sz); + (* FIXME: Should be caught on EC side *) + (* assert (arr_sz*el_sz >= i + acc_sz); *) let bw_inp = bwinput_of_size acc_sz in let bw_id = (ident_of_cinput bw_inp).id_tag in let arr_inp = bwainput_of_size ~nelements:arr_sz ~wordsize:el_sz in @@ -599,30 +650,19 @@ let circuit_tuple_proj (c: circuit) (i: int) : circuit = | BWTuple tp -> begin try {c with circ=BWCirc (List.at tp i)} with Invalid_argument e -> - Format.eprintf "Proj outside tuple size (should never happen)@."; - assert false + let err = Format.sprintf "Projection at component %d outside tuple size (%d)@." i (List.length tp) in + raise (CircError err) end - | _ -> assert false + | _ -> raise (CircError "Projection on non-tuple type") let circuit_ueq (c: circuit) (d: circuit) : circuit = match c.circ, d.circ with | BWCirc r1, BWCirc r2 -> {circ= BWCirc[C.bvueq r1 r2]; inps=c.inps @ d.inps} - | _ -> failwith "Implement other cases for circuit_ueq" - -(* Input for splitting function w.r.t. dependencies *) -let input_of_tdep (n: int) (bs: int Set.t) : _ * cinput = - let temp_symbol = "tdep_ident" in - let m = Set.cardinal bs in - let id = create temp_symbol in - let map_ = Set.to_seq bs |> List.of_seq in - let map_ = List.map (fun a -> (n, a)) map_ in - let map_ = List.combine map_ (List.init m (fun i -> C.input (id.id_tag, i))) in - let map_ = Map.of_seq (List.to_seq map_) in - map_, BWInput (id, m) - -let inputs_of_tdep (td: HL.tdeps) : _ * cinput list = - Map.foldi (fun n bs (map_, inps) -> let map_2, inp = input_of_tdep n bs in - (Map.union map_ map_2, inp::inps)) td (Map.empty, []) + | BWArray a1, BWArray a2 -> let elems = Array.map2 C.bvueq a1 a2 in + {circ= BWCirc[C.ands (Array.to_list elems)]; inps=c.inps @ d.inps} + | BWTuple t1, BWTuple t2 -> let elems = List.map2 C.bvueq t1 t2 in + {circ= BWCirc[C.ands elems]; inps=c.inps @ d.inps} + | _ -> raise (CircError "Mismatched types for ueq") (* f : BV1 -> BV2 @@ -630,35 +670,50 @@ let inputs_of_tdep (td: HL.tdeps) : _ * cinput list = returns: BV2 Array = mapping f over a *) let circuit_map (f: circuit) (a: circuit) : circuit = - let a, inps = destr_bwarray a.circ, a.inps in + let a, inps = try + destr_bwarray a.circ, a.inps + with CircError _ -> + raise (CircError "Argument to circuit map is not bwarray") + in let r = Array.map (fun r -> apply f [BWCirc r]) a in let r = Array.map (destr_bwcirc) r in {circ = BWArray r; inps} let circuit_split ?(perm: (int -> int) option) (f: circuit) (lane_in_w: int) (lane_out_w: int) : circuit list = - assert (List.length f.inps = 1); - let r = destr_bwcirc f.circ in - let inp_t, inp_w = List.hd f.inps |> destr_bwinput in - (*assert ((inp_w mod lane_in_w = 0) && (List.length r mod lane_out_w = 0));*) - let rs = List.chunkify (lane_out_w) r in - let rs = match perm with - | Some perm -> List.filteri_map (fun i _ -> let idx = (perm i) in - if idx < 0 || idx > (List.length rs) then None else - Some (List.nth rs (idx))) rs - | None -> rs - in - let rs = List.mapi (fun lane_idx lane_circ -> - let id = create ("split_" ^ (string_of_int lane_idx)) in - let map_ = (function - | (v, j) when v = inp_t.id_tag - && (0 <= (j - (lane_idx*lane_in_w)) && (j-(lane_in_w*lane_idx)) < lane_in_w) - -> Some (C.input (id.id_tag, j - (lane_idx*lane_in_w))) - | _ -> None - ) in - let circ = BWCirc(C.maps map_ lane_circ) in - {circ; inps=[BWInput(id, lane_in_w)]} - ) rs in - rs + (* FIXME: Allow bdep for multiple inputs? *) + if (List.length f.inps <> 1) + then raise (CircError "Multi input circuit split not supported") + else + + try + let r = destr_bwcirc f.circ in + let inp_t, inp_w = List.hd f.inps |> destr_bwinput in + (*assert ((inp_w mod lane_in_w = 0) && (List.length r mod lane_out_w = 0));*) + let rs = List.chunkify (lane_out_w) r in + let rs = match perm with + | Some perm -> List.filteri_map (fun i _ -> let idx = (perm i) in + if idx < 0 || idx > (List.length rs) then None else + Some (List.nth rs (idx))) rs + | None -> rs + in + let rs = List.mapi (fun lane_idx lane_circ -> + let id = create ("split_" ^ (string_of_int lane_idx)) in + let map_ = (function + | (v, j) when v = inp_t.id_tag + && (0 <= (j - (lane_idx*lane_in_w)) && (j-(lane_in_w*lane_idx)) < lane_in_w) + -> Some (C.input (id.id_tag, j - (lane_idx*lane_in_w))) + | _ -> None + ) in + let circ = BWCirc(C.maps map_ lane_circ) in + {circ; inps=[BWInput(id, lane_in_w)]} + ) rs in + rs + with + | CircError "destr_bwcirc" -> + raise (CircError "Cannot split array or tuple") + | CircError "destr_bwinput" -> + raise (CircError "Cannot split non bitstring input") + (* Partitions into blocks of type n -> m *) let circuit_mapreduce ?(perm: (int -> int) option) (c: circuit) (n:int) (m:int) : circuit list = @@ -673,15 +728,19 @@ let circuit_mapreduce ?(perm: (int -> int) option) (c: circuit) (n:int) (m:int) circuit_aggregate_inps c else c in - let r = destr_bwcirc c.circ in + let r = try destr_bwcirc c.circ + with CircError _ -> raise (CircError "Cannot mapreduce on non-bitstring return type") + in let deps = HL.deps r in let deps = HL.split_deps m deps in - Format.eprintf "%d@." (List.length deps); + (* Format.eprintf "%d@." (List.length deps); *) (* Format.eprintf "%a@." (fun fmt -> HL.pp_bdeps fmt) deps; *) - assert (HL.block_list_indep deps); - assert (List.for_all (HL.check_dep_width n) (List.snd deps)); + if not ((HL.block_list_indep deps) && (List.for_all (HL.check_dep_width n) (List.snd deps))) + then + raise (CircError "Failed mapreduce split (dependency split condition not true)") + else (* assert ((List.sum (List.map size_of_cinput c.inps)) mod n = 0);*) Format.eprintf "[W] Dependency analysis complete after %f seconds@." @@ -702,15 +761,20 @@ let circuit_mapreduce ?(perm: (int -> int) option) (c: circuit) (n:int) (m:int) -> {circ=BWCirc (C.uextend ~size:m r); inps=[BWInput (idn, n)]} | {circ=BWCirc r; inps=[]} -> {circ=BWCirc (C.uextend ~size:m r); inps=[const_inp]} - | c -> Format.eprintf "Failed for %s@." (circuit_to_string c) ; assert false) + | c -> let err = Format.sprintf "Failed for %s@." (circuit_to_string c) in + raise (CircError err)) cs (* Build a circuit function that takes an input n bits wide and permutes it in blocks of w bits by the permutation given by f Expects that w | n and that f|[n/w] is a bijection *) let circuit_permutation (n: int) (w: int) (f: int -> int) : circuit = - assert (n mod w = 0); - assert ( List.init (n/w) f |> Set.of_list |> Set.map f |> Set.cardinal = (n/w)); + if (n mod w <> 0) then + let err = Format.sprintf "In circuit permutation, block size (%d) does not divide circuit size (%d)@." w n in + raise (CircError err) + else + (* FIXME: Permutation check should come from EC *) + (* assert ( List.init (n/w) f |> Set.of_list |> Set.map f |> Set.cardinal = (n/w)); *) let inp = bwinput_of_size n in let inp_circ = circ_ident inp in let cblocks = destr_bwcirc inp_circ.circ in @@ -739,9 +803,9 @@ let circuit_from_spec_ (env: env) (p : path) : C.reg list -> C.reg = match EcEnv.Circuit.lookup_circuit_path env p with | Some circuit -> (fun regs -> C.circuit_of_specification regs circuit) - | None -> Format.eprintf "No operator for path: %s@." - (let a,b = EcPath.toqsymbol p in List.fold_right (fun a b -> a ^ "." ^ b) a b); - assert false + | None -> let err = Format.sprintf "No operator for path: %s@." + (let a,b = EcPath.toqsymbol p in List.fold_right (fun a b -> a ^ "." ^ b) a b) in + raise (CircError err) let circuit_from_spec (env: env) (p : path) : circuit = @@ -783,6 +847,7 @@ module BaseOps = struct let c1 = C.reg ~size ~name:id1.id_tag in let c2 = C.reg ~size ~name:id2.id_tag in {circ = BWCirc(C.add_dropc c1 c2); inps = [BWInput(id1, size); BWInput(id2, size)]} + | Some { kind = `Sub size } -> let id1 = EcIdent.create (temp_symbol) in let id2 = EcIdent.create (temp_symbol) in @@ -915,25 +980,25 @@ module BaseOps = struct { circ = BWCirc([C.sge c2 c1]); inps=[BWInput(id1, size); BWInput(id2, size)]} | Some { kind = `Extend (size, out_size, false) } -> - assert (size <= out_size); + (* assert (size <= out_size); *) let id1 = EcIdent.create (temp_symbol) in let c1 = C.reg ~size ~name:id1.id_tag in {circ = BWCirc(C.uextend ~size:out_size c1); inps = [BWInput (id1, size)]} | Some { kind = `Extend (size, out_size, true) } -> - assert (size <= out_size); + (* assert (size <= out_size); *) let id1 = EcIdent.create (temp_symbol) in let c1 = C.reg ~size ~name:id1.id_tag in {circ = BWCirc(C.sextend ~size:out_size c1); inps = [BWInput (id1, size)]} | Some { kind = `Truncate (size, out_sz) } -> - assert (size >= out_sz); + (* assert (size >= out_sz); *) let id1 = EcIdent.create (temp_symbol) in let c1 = C.reg ~size:out_sz ~name:id1.id_tag in { circ = BWCirc(c1); inps=[BWInput (id1, size)]} | Some { kind = `Concat (sz1, sz2, szo) } -> - assert (sz1 + sz2 = szo); + (* assert (sz1 + sz2 = szo); *) let id1 = EcIdent.create (temp_symbol) in let c1 = C.reg ~size:sz1 ~name:id1.id_tag in let id2 = EcIdent.create (temp_symbol) in @@ -941,13 +1006,13 @@ module BaseOps = struct { circ = BWCirc(c1 @ c2); inps=[BWInput (id1, sz1); BWInput (id2, sz2)]} | Some { kind = `A2B ((w, n), m)} -> - assert (n * w = m); + (* assert (n * w = m); *) let id1 = EcIdent.create temp_symbol in let c1 = C.reg ~size:m ~name:id1.id_tag in { circ = BWCirc(c1); inps = [BWAInput (id1, { nelements = n; wordsize = w })]} | Some { kind = `B2A (m, (w, n))} -> - assert (n * w = m); + (* assert (n * w = m); *) let id1 = EcIdent.create temp_symbol in let c1 = C.reg ~size:m ~name:id1.id_tag in let c1 = List.chunkify w c1 |> Array.of_list in @@ -969,22 +1034,18 @@ module ArrayOps = struct end let circ_equiv ?(strict=false) (f: circuit) (g: circuit) (pcond: circuit option) : bool = - let f, g = - if strict then (assert(circ_shape_equal f.circ g.circ); f, g) + let fg = + if strict then begin + if circ_shape_equal f.circ g.circ then Some (f, g) + else None + end else if size_of_circ f.circ < size_of_circ g.circ then - circuit_bw_zeroextend f (size_of_circ g.circ), g else - f, circuit_bw_zeroextend g (size_of_circ f.circ) + Some (circuit_bw_zeroextend f (size_of_circ g.circ), g) else + Some (f, circuit_bw_zeroextend g (size_of_circ f.circ)) in - (* if (List.is_empty f.inps) || (List.is_empty g.inps) then *) - (* if (List.is_empty f.inps) && (List.is_empty g.inps) then *) - (* match f.circ, g.circ with *) - (* | BWCirc r1, BWCirc r2 -> r1 = r2 *) - (* | BWArray a1,BWArray a2 -> a1 = a2 *) - (* | BWTuple t1, BWTuple t2 -> t1 = t2 *) - (* | _ -> false *) - (* else *) - (* false *) - (* else *) + if Option.is_none fg then false + else + let f, g = Option.get fg in (* FIXME: more general input unification procedure *) let pcond = match pcond with | Some pcond -> pcond @@ -998,9 +1059,14 @@ let circ_equiv ?(strict=false) (f: circuit) (g: circuit) (pcond: circuit option) begin (List.for_all2 (==) fcirc gcirc) || let module B = (val HL.makeBWZinterface ()) in - B.circ_equiv fcirc gcirc (List.hd pccirc) - (List.map (fun inp -> let a, b = destr_bwinput inp in - (a.id_tag, b)) f.inps) + begin + try + B.circ_equiv fcirc gcirc (List.hd pccirc) + (List.map (fun inp -> let a, b = destr_bwinput inp in + (a.id_tag, b)) f.inps) + with CircError "destr_bwinput" -> + raise (CircError "Non-bitstring input in equiv call") + end (* Assuming no array inputs for now *) end | _ -> assert false @@ -1061,23 +1127,33 @@ let circuit_of_form let int_of_form (f: form) : zint = match f.f_node with | Fint i -> i - | _ -> destr_int @@ EcCallbyValue.norm_cbv EcReduction.full_red hyps f + | _ -> begin + try destr_int @@ EcCallbyValue.norm_cbv EcReduction.full_red hyps f + with DestrError "int" -> + let err = Format.asprintf "Failed to reduce form | %a | to integer" + (EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) f in + raise (CircError err) + end in match f_.f_node with - (* hardcoding size for now FIXME *) - | Fint z -> assert false - (* env, {circ = BWCirc(C.of_bigint ~size:256 (to_zt z)); inps = []} *) - (* failwith "Add logic to deal with ints (maybe force conversion?)" *) - (* hlenv, C.of_bigint ~size:256 (EcAst.BI.to_zt z) *) + | Fint z -> raise (CircError "Translation encountered unexpected integer value") + + (* Assumes no quantifier bindings/new inputs within if *) | Fif (c_f, t_f, f_f) -> let hyps, c_c = doit cache hyps c_f in let hyps, t_c = doit cache hyps t_f in let hyps, f_c = doit cache hyps f_f in - let () = assert (List.length (destr_bwcirc c_c.circ) = 1) in - let () = assert (List.is_empty c_c.inps) in - let () = assert (List.is_empty t_c.inps) in - let () = assert (List.is_empty f_c.inps) in + + if (try (List.length (destr_bwcirc c_c.circ) <> 1) + with CircError "destr_bwcirc" -> + raise (CircError "Condition circuit should output a bitstring of size 1") + ) then raise (CircError "Condition circuit output size too big") + + else + (* let () = assert (List.is_empty c_c.inps) in *) + (* let () = assert (List.is_empty t_c.inps) in *) + (* let () = assert (List.is_empty f_c.inps) in *) let c_c = List.hd (destr_bwcirc c_c.circ) in begin match t_c.circ, f_c.circ with @@ -1089,16 +1165,16 @@ let circuit_of_form | BWArray t_cs, BWArray f_cs when (Array.length t_cs = Array.length f_cs) -> hyps, { circ = BWArray (Array.map2 (C.ite c_c) t_cs f_cs); - inps = []; (* FIXME: check if we want to allow bindings inside ifs *) + inps = []; } | BWTuple t_tp, BWTuple f_tp when (List.compare_lengths t_tp f_tp = 0) -> hyps, { circ = BWTuple (List.map2 (C.ite c_c) t_tp f_tp); inps = []; } - | _ -> assert false + | _ -> raise (CircError "Type mismatch between conditional arms") + (* EC should prevent this as equal EC types ==> equal circuit types *) end - (* Assumes no quantifier bindings/new inputs within if *) | Flocal idn -> begin match Map.find_opt idn cache with | Some (inp, circ) -> @@ -1139,7 +1215,8 @@ let circuit_of_form | Some `False -> hyps, {circ = BWCirc([C.false_]); inps=[]} | _ -> - Format.eprintf "%s@." (EcPath.tostring pth); failwith "Unsupported op kind" + let err = Format.sprintf "Unsupported op kind%s@." (EcPath.tostring pth) in + raise (CircError err) end in op_cache := Mp.add pth circ !op_cache; @@ -1150,13 +1227,13 @@ let circuit_of_form let (f, fs) = EcCoreFol.destr_app f_ in let hyps, res = (* Assuming correct types coming from EC *) - (* FIXME: add typechecking here ? *) + (* FIXME: Add some extra info about errors when something here throws *) match EcEnv.Circuit.reverse_operator env @@ (EcCoreFol.destr_op f |> fst) with | `Array ({ size }, `Get) :: _ -> let hyps, res = match fs with | [arr; i] -> let i = int_of_form i in - let (_, t) = destr_array_type env arr.f_ty |> Option.get in + let (_, t) = Option.get_exn (destr_array_type env arr.f_ty) (CircError "Array get type error") in let w = width_of_type env t in let hyps, arr = doit cache hyps arr in hyps, compose (circuit_bwarray_get ~nelements:size ~wordsize:w (BI.to_int i)) [arr] @@ -1173,50 +1250,72 @@ let circuit_of_form | _ -> raise (CircError "set") in hyps, res | `Array ({ size }, `OfList) :: _-> - let _, { nelements = n; wordsize = w } = destr_bwainput @@ cinput_of_type env f_.f_ty in - assert (n = size); - (* FIXME: have an actual way to get sizes without creating new idents *) - let wtn, vs = match fs with - | [wtn; vs] -> wtn, vs - | _ -> assert false (* should only be two arguments to of_list *) + let n, w = + match EcEnv.Circuit.lookup_array_and_bitstring env f_.f_ty with + | Some ({size=asize}, {size=bwsize}) -> asize, bwsize + | None -> raise (CircError "Array of_list type error (wrong binding?)") + in + let dfl, vs = match fs with + | [dfl; vs] -> dfl, vs + | _ -> assert false + (* This should be caught by the EC typecheck/bindings so never actually happens *) + in + let vs = try EcCoreFol.destr_list vs + with DestrError _ -> raise (CircError "Failed to destructure list argument to array of_list") in - let vs = EcCoreFol.destr_list vs in let hyps, vs = List.fold_left_map (doit cache) hyps vs in - begin match EcCoreFol.is_witness wtn with + begin match EcCoreFol.is_witness dfl with | false -> - let hyps, wtn = doit cache hyps wtn in - assert(List.is_empty wtn.inps && List.for_all (fun c -> List.is_empty c.inps) vs); - let vs = List.map (fun c -> destr_bwcirc c.circ) vs in - let wtn = destr_bwcirc wtn.circ in - let r = Array.init n (fun i -> List.nth_opt vs i |> Option.default wtn) in - hyps, {circ = BWArray r; inps = []} + let hyps, dfl = doit cache hyps dfl in + if not (List.is_empty dfl.inps && List.for_all (fun c -> List.is_empty c.inps) vs) then + raise (CircError "Non-constant circuits in of_list not supported") + else + begin try + let vs = List.map (fun c -> destr_bwcirc c.circ) vs in + let dfl = destr_bwcirc dfl.circ in + let r = Array.init n (fun i -> List.nth_opt vs i |> Option.default dfl) in + hyps, {circ = BWArray r; inps = []} + with CircError "destr_bwcirc" _ -> + raise (CircError "BWCirc destruct error in array of_list ") + end | true -> - assert (List.compare_length_with vs n = 0); - assert (List.for_all (fun c -> List.is_empty c.inps) vs); - let vs = List.map (fun c -> destr_bwcirc c.circ) vs in - let r = Array.of_list vs in - hyps, {circ=BWArray r; inps=[]} + if not (List.compare_length_with vs n = 0) then + raise (CircError "Insufficient list length for of_list with default = witness") + else + if not (List.for_all (fun c -> List.is_empty c.inps) vs) then + raise (CircError "Non-constant circuits in of_list not supported") + else + begin try + let vs = List.map (fun c -> destr_bwcirc c.circ) vs in + let r = Array.of_list vs in + hyps, {circ=BWArray r; inps=[]} + with CircError _ -> + raise (CircError "BWCirc destruct error in array of_list ") + end end | `Bitstring ({ size }, `OfInt) :: _ -> let i = match fs with | f :: _ -> int_of_form f - | _ -> assert false + | _ -> assert false (* Should never happen, caught in EC typecheck/bindings *) in hyps, { circ = BWCirc (C.of_bigint_all ~size (to_zt i)); inps = [] } | `BvOperator ({ kind = `Extract (size, out_sz) }) :: _ -> - assert (size >= out_sz); + (* assert (size >= out_sz); *) + (* Should never happen, caught in EC typecheck/bindings *) let c1, b = match fs with | [c; f] -> c, int_of_form f - | _ -> assert false + | _ -> assert false (* Should never happen, caught in EC typecheck/bindings *) in let hyps, c1 = doit cache hyps c1 in - let c = destr_bwcirc c1.circ in + let c = try destr_bwcirc c1.circ + with CircError _ -> raise "BWCirc destr error at bvextract" + in let c = List.take out_sz (List.drop (to_int b) c) in hyps, { circ = BWCirc(c); inps=c1.inps } | `BvOperator ({kind = `Init (size)}) :: _ -> let f = match fs with | [f] -> f - | _ -> assert false + | _ -> assert false (* Should never happen, caught in EC typecheck/bindings *) in let fs = List.init size (fun i -> fapply_safe f [f_int (of_int i)]) in (* List.iter (Format.eprintf "|%a@." (EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env))) fs; *) @@ -1225,35 +1324,36 @@ let circuit_of_form | `BvOperator ({kind = `Get (size)}) :: _ -> let bv, i = match fs with | [bv; i] -> bv, int_of_form i |> to_int - | _ -> assert false + | _ -> assert false (* Should never happen, caught in EC typecheck/bindings *) in - assert (i < size); + (* assert (i < size); *) + (* Should never happen, caught in EC typecheck/bindings *) let hyps, bv = doit cache hyps bv in - let bv_base = destr_bwcirc bv.circ in + let bv_base = try destr_bwcirc bv.circ + with CircError _ -> raise (CircError "BWCirc destr error at bvget") + in hyps, {bv with circ = BWCirc([List.nth bv_base i])} | `BvOperator ({kind = `AInit (arr_sz, bw_sz)}) :: _ -> let f = match fs with | [f] -> f - | _ -> assert false + | _ -> assert false (* Should never happen, caught in EC typecheck/bindings *) in let fs = List.init arr_sz (fun i -> fapply_safe f [f_int (of_int i)]) in (* List.iter (Format.eprintf "|%a@." (EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env))) fs; *) let hyps, fs = List.fold_left_map (doit cache) hyps fs in - assert (List.for_all (fun c -> List.is_empty c.inps) fs); + if not (List.for_all (fun c -> List.is_empty c.inps) fs) then + raise (CircError "Circut Input problem at array init") + else + begin try hyps, {circ = BWArray(Array.of_list (List.map (fun c -> destr_bwcirc c.circ) fs)); inps=[]} + with CircError _ -> raise (CircError "Array elements in init should be bitstrings") + end - (* begin *) - (* match f.f_node with *) - (* | Fapp _ -> Format.eprintf "Its an Fapp@."; assert false *) - (* | Fquant (Llambda, _, _) -> Format.eprintf "Its an Flambda@."; assert false *) - (* | Fop _ -> Format.eprintf "Its an Fop@."; assert false *) - (* | _ -> Format.eprintf "Its something else @."; assert false *) - (* end *) | `BvOperator ({kind = `Map (sz1, sz2, asz)}) :: _ -> let f, a = match fs with | [f; a] -> f, a - | _ -> assert false + | _ -> assert false (* Should never happen, caught in EC typecheck/bindings *) in let hyps, f = doit cache hyps f in let hyps, a = doit cache hyps a in @@ -1262,7 +1362,7 @@ let circuit_of_form | `BvOperator ({kind = `ASliceGet ((arr_sz, sz1), sz2)}) :: _ -> let arr, i = match fs with | [arr; i] -> arr, int_of_form i - | _ -> assert false + | _ -> assert false (* Should never happen, caught in EC typecheck/bindings *) in let op = circuit_bwarray_slice_get arr_sz sz1 sz2 (to_int i) in let hyps, arr = doit cache hyps arr in @@ -1271,7 +1371,7 @@ let circuit_of_form | `BvOperator ({kind = `ASliceSet ((arr_sz, sz1), sz2)}) :: _ -> let arr, i, bv = match fs with | [arr; i; bv] -> arr, int_of_form i, bv - | _ -> assert false + | _ -> assert false (* Should never happen, caught in EC typecheck/bindings *) in let op = circuit_bwarray_slice_set arr_sz sz1 sz2 (to_int i) in let hyps, arr = doit cache hyps arr in @@ -1284,11 +1384,15 @@ let circuit_of_form let hyps, c2 = doit cache hyps f2 in begin match c1.circ, c2.circ with | BWCirc r1, BWCirc r2 -> - assert (List.compare_lengths r1 r2 = 0); + (* assert (List.compare_lengths r1 r2 = 0); *) + (* Should never happen, caught in EC typecheck/bindings *) hyps, {circ = BWCirc([C.bvueq r1 r2]); inps=c1.inps @ c2.inps} (* FIXME: check inps here *) | BWArray a1, BWArray a2 -> - assert (Array.length a1 = Array.length a2); - assert (Array.for_all2 (fun a b -> (List.compare_lengths a b) = 0) a1 a2); + (* assert (Array.for_all2 (fun a b -> (List.compare_lengths a b) = 0) a1 a2); *) + (* Should never happen, caught in EC typecheck/bindings *) + if not (Array.length a1 = Array.length a2) then + raise (CircError "Comparison between arrays of different size") + else let rs = Array.map2 C.bvueq a1 a2 in hyps, {circ = BWCirc([C.ands (Array.to_list rs)]); inps = c1.inps @ c2.inps} | _ -> assert false @@ -1322,18 +1426,13 @@ let circuit_of_form begin match qnt with | Llambda -> hyps, {circ with inps=binds @ circ.inps} (* FIXME: check input order *) | Lforall - | Lexists -> assert false + | Lexists -> raise (CircError "Universal/Existential quantification not supported ") (* TODO: figure out how to handle quantifiers *) end | Fproj (f, i) -> let hyps, ftp = doit cache hyps f in hyps, circuit_tuple_proj ftp i - (* begin match f.f_node with *) - (* | Ftuple tp -> *) - (* doit cache hyps (tp |> List.drop (i-1) |> List.hd) *) - (* | _ -> failwith "Don't handle projections on non-tuples" *) - (* end *) - | Fmatch (f, fs, ty) -> assert false + | Fmatch (f, fs, ty) -> raise (CircError "Match not supported") | Flet (lpat, v, f) -> begin match lpat with | LSymbol (idn, ty) -> @@ -1346,7 +1445,7 @@ let circuit_of_form let hyps, tp = doit cache hyps v in let comps = if is_bwtuple tp.circ then circuits_of_circuit tp - else raise (CircError "tuple let") + else raise (CircError "tuple let type error") in (* Assuming types match coming from EC *) @@ -1356,22 +1455,21 @@ let circuit_of_form in doit cache hyps f - | LRecord (pth, osymbs) -> assert false + | LRecord (pth, osymbs) -> raise (CircError "record types not supported") end | Fpvar (pv, mem) -> let v = match pv with | PVloc v -> v - | _ -> failwith "No global vars yet" + | _ -> raise (CircError "Global vars not supported") + (* FIXME: Should globals be supported? *) in let res = match Map.find_opt v pstate with | Some circ -> circ | None -> raise (CircError (Format.sprintf "Uninitialized program variable %s" v)) - (* | None -> let circ = circ_ident (cinput_of_type ~idn:(create "uninit") env f_.f_ty) in *) - (* {circ with inps=[]} *) - (* EXPERIMENTAL: allowing unitialized values *) - (* failwith ("No value for var " ^ v) *) + (* FIXME: Do we add support for initialized PVs? With a check at the end that + the result does not depend on their value *) in hyps, res - | Fglob (id, mem) -> assert false + | Fglob (id, mem) -> raise (CircError "glob not supported") | Ftuple comps -> let hyps, comps = List.fold_left_map (fun hyps comp -> doit cache hyps comp) hyps comps @@ -1379,7 +1477,7 @@ let circuit_of_form let inps = List.fold_right (@) (List.map (fun c -> c.inps) comps) [] in let comps = List.map (fun c -> destr_bwcirc c.circ) comps in hyps, {circ= BWTuple comps; inps} - | _ -> failwith "Not yet implemented" + | _ -> raise (CircError "Unsupported form kind in translation") in @@ -1391,7 +1489,7 @@ let circuit_of_path (hyps: hyps) (p: path) : circuit = let f = EcEnv.Op.by_path p (toenv hyps) in let f = match f.op_kind with | OB_oper (Some (OP_Plain f)) -> f - | _ -> failwith "Invalid operator type" + | _ -> raise (CircError "Invalid operator type") in circuit_of_form hyps f @@ -1437,29 +1535,16 @@ let process_instr (hyps: hyps) (mem: memory) ?(cache: cache = Map.empty) (pstate let pstate = List.fold_left2 (fun pstate (pv, _ty) c -> let v = match pv with | PVloc v -> v - | _ -> assert false + | _ -> raise (CircError "Global variables not supported") in Map.add v c pstate ) pstate vs comps in pstate - (* begin match e.e_node with *) - (* | Etuple (es) -> List.fold_left2 (fun pstate (v, t) e -> *) - (* let v = match v with | PVloc v -> v | _ -> assert false in *) - (* Map.add v (form_of_expr mem e |> circuit_of_form ~pstate ~cache hyps) pstate) pstate vs es *) - (* | _ -> let c = (form_of_expr mem e |> circuit_of_form ~pstate ~cache hyps) in *) - (* assert (is_bwtuple c.circ); *) - (* let circs = circuits_of_circuit c in *) - (* assert (List.compare_lengths circs vs = 0); *) - (* let pstate = List.fold_left2 (fun pstate pv c -> *) - (* match pv with *) - (* | PVloc v -> Map.add v c pstate *) - (* | _ -> assert false *) - (* ) pstate (List.fst vs) circs in *) - (* pstate *) - - (* end *) - | _ -> failwith "Case not implemented yet" + | _ -> + let err = Format.asprintf "Instruction not supported: %a@." + (EcPrinting.pp_instr (EcPrinting.PPEnv.ofenv env)) inst in + raise (CircError err) with | e -> let bt = Printexc.get_backtrace () in @@ -1498,7 +1583,7 @@ let instrs_equiv let vs = EcPV.PV.elements pv |> fst in let vs = List.map (function | (PVloc v, ty) -> (v, ty) - | _ -> assert false + | _ -> raise (CircError "global variables not supported") ) vs in List.for_all (fun (var, ty) -> let circ1 = Map.find_opt var pstate1 in