Skip to content

Commit

Permalink
Emit code for merging two messages and add tests to verify the implem…
Browse files Browse the repository at this point in the history
…entation
  • Loading branch information
andersfugmann committed Jan 28, 2024
1 parent 1a0f2a1 commit e1c2e1a
Show file tree
Hide file tree
Showing 20 changed files with 940 additions and 147 deletions.
22 changes: 18 additions & 4 deletions src/ocaml_protoc_plugin/deserialize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ let read_of_spec: type a. a spec -> Field.field_type * (Reader.t -> a) = functio
let v = Bytes.create length in
Bytes.blit_string ~src:data ~src_pos:offset ~dst:v ~dst_pos:0 ~len:length;
v
| Message from_proto -> Length_delimited, fun reader ->
| Message (from_proto, _merge) -> Length_delimited, fun reader ->
let Field.{ offset; length; data } = Reader.read_length_delimited reader in
from_proto (Reader.create ~offset ~length data)

Expand All @@ -102,7 +102,7 @@ let default_value: type a. a spec -> a = function
| Fixed64 -> Int64.zero
| SFixed32 -> Int32.zero
| SFixed64 -> Int64.zero
| Message of_proto -> of_proto (Reader.create "")
| Message (of_proto, _merge) -> of_proto (Reader.create "")
| String -> ""
| Bytes -> Bytes.empty
| Int32_int -> 0
Expand Down Expand Up @@ -130,7 +130,11 @@ let read_field ~read:(expect, read_f) ~map v reader field_type =

let value: type a. a compound -> a value = function
| Basic (index, spec, default) ->
let read = read_field ~read:(read_of_spec spec) ~map:keep_last in
let map = match spec with
| Message (_, merge) -> merge
| _ -> keep_last
in
let read = read_field ~read:(read_of_spec spec) ~map in
let required = match default with
| Some _ -> Optional
| None -> Required
Expand All @@ -141,7 +145,17 @@ let value: type a. a compound -> a value = function
in
([(index, read)], required, default, id)
| Basic_opt (index, spec) ->
let read = read_field ~read:(read_of_spec spec) ~map:(fun _ v -> Some v) in
let map = match spec with
| Message (_, merge) ->
let map v1 v2 =
match v1 with
| None -> Some v2
| Some prev -> Some (merge prev v2)
in
map
| _ -> fun _ v -> Some v (* Keep last for all other non-repeated types *)
in
let read = read_field ~read:(read_of_spec spec) ~map in
([(index, read)], Optional, None, id)
| Repeated (index, spec, Packed) ->
let field_type, read_f = read_of_spec spec in
Expand Down
21 changes: 21 additions & 0 deletions src/ocaml_protoc_plugin/merge.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
(** Merge a two values. Need to match on the spec to merge messages recursivly *)
let merge: type t. t Spec.Deserialize.compound -> t -> t -> t = fun spec t t' -> match spec with
| Spec.Deserialize.Basic (_field, Message (_, merge), _) -> merge t t'
| Spec.Deserialize.Basic (_field, _spec, Some default) when t' = default -> t
| Spec.Deserialize.Basic (_field, _spec, _) -> t'
| Spec.Deserialize.Basic_opt (_field, Message (_, merge)) ->
begin
match t, t' with
| None, None -> None
| Some t, None -> Some t
| None, Some t -> Some t
| Some t, Some t' -> Some (merge t t')
end
| Spec.Deserialize.Basic_opt (_field, _spec) -> begin
match t' with
| Some _ -> t'
| None -> t
end
| Spec.Deserialize.Repeated (_field, _, _) -> t @ t'
(* | Spec.Deserialize.Oneof _ when t' = `not_set -> t *)
| Spec.Deserialize.Oneof _ -> failwith "Implementation is part of generated code"
1 change: 1 addition & 0 deletions src/ocaml_protoc_plugin/ocaml_protoc_plugin.ml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module Serialize = Serialize
module Deserialize = Deserialize
module Spec = Spec
module Runtime = Runtime
module Field = Field
(**/**)

module Reader = Reader
Expand Down
1 change: 1 addition & 0 deletions src/ocaml_protoc_plugin/runtime.ml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ module Runtime' = struct
module Extensions = Extensions
module Reader = Reader
module Writer = Writer
module Merge = Merge
end
6 changes: 5 additions & 1 deletion src/ocaml_protoc_plugin/spec.ml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ module Make(T : T) = struct
type packed = Packed | Not_packed
type extension_ranges = (int * int) list
type extensions = (int * Field.t) list
type 'a merge = 'a -> 'a -> 'a

type _ spec =
| Double : float spec
Expand Down Expand Up @@ -40,7 +41,10 @@ module Make(T : T) = struct
| String : string spec
| Bytes : bytes spec
| Enum : ('a, int -> 'a, 'a -> int) T.dir -> 'a spec
| Message : ('a, Reader.t -> 'a, Writer.t -> 'a -> Writer.t) T.dir -> 'a spec
| Message : ('a, ((Reader.t -> 'a) * 'a merge), Writer.t -> 'a -> Writer.t) T.dir -> 'a spec

(* Existential types *)
type espec = Espec: _ spec -> espec

type _ oneof =
| Oneof_elem : int * 'b spec * ('a, ('b -> 'a), 'b) T.dir -> 'a oneof
Expand Down
9 changes: 7 additions & 2 deletions src/plugin/code.ml
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,13 @@ let emit t indent fmt =
| n -> String.sub ~pos:0 ~len:(String.length s - n) s
in
let prepend s =
String.split_on_char ~sep:'\n' s
|> List.iter ~f:(fun s -> t.code <- (trim_end ~char:' ' (t.indent ^ s)) :: t.code)
match String.split_on_char ~sep:'\n' s with
| line :: lines ->
t.code <- (trim_end ~char:' ' (t.indent ^ line)) :: t.code;
incr t;
List.iter lines ~f:(fun line -> t.code <- (trim_end ~char:' ' (t.indent ^ line)) :: t.code);
decr t;
| [] -> ()
in
let emit s =
match indent with
Expand Down
8 changes: 5 additions & 3 deletions src/plugin/emit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -211,14 +211,16 @@ let rec emit_message ~params ~syntax scope
| Some _name ->
let is_map_entry = is_map_entry options in
let is_cyclic = Scope.is_cyclic scope in
let Types.{ type'; constructor; apply; deserialize_spec; serialize_spec; default_constructor_sig; default_constructor_impl } =
let Types.{ type'; constructor; apply; deserialize_spec; serialize_spec;
default_constructor_sig; default_constructor_impl; merge_impl } =
Types.make ~params ~syntax ~is_cyclic ~is_map_entry ~extension_ranges ~scope ~fields oneof_decls
in
ignore (default_constructor_sig, default_constructor_impl);
ignore (merge_impl);

Code.emit signature `None "val name': unit -> string";
Code.emit signature `None "type t = %s %s" type' params.annot;
Code.emit signature `None "val make: %s" default_constructor_sig;
Code.emit signature `None "val merge: t -> t -> t";
Code.emit signature `None "val to_proto': Runtime'.Writer.t -> t -> Runtime'.Writer.t";
Code.emit signature `None "val to_proto: t -> Runtime'.Writer.t";
Code.emit signature `None "val from_proto: Runtime'.Reader.t -> (t, [> Runtime'.Result.error]) result";
Expand All @@ -227,6 +229,7 @@ let rec emit_message ~params ~syntax scope
Code.emit implementation `None "let name' () = \"%s\"" (Scope.get_current_scope scope);
Code.emit implementation `None "type t = %s%s" type' params.annot;
Code.emit implementation `None "let make %s" default_constructor_impl;
Code.emit implementation `None "let merge = (%s)" merge_impl;

Code.emit implementation `Begin "let to_proto' =";
Code.emit implementation `None "let spec = %s in" serialize_spec;
Expand All @@ -240,7 +243,6 @@ let rec emit_message ~params ~syntax scope
Code.emit implementation `None "let constructor = %s in" constructor;
Code.emit implementation `None "let spec = %s in" deserialize_spec;
Code.emit implementation `None "Runtime'.Deserialize.deserialize spec constructor";
(* TODO: No need to have a function here. We could drop deserialize thing here *)
Code.emit implementation `End "let from_proto writer = Runtime'.Result.catch (fun () -> from_proto_exn writer)";
| None -> ()
in
Expand Down
Loading

0 comments on commit e1c2e1a

Please sign in to comment.