Skip to content

Commit

Permalink
Fix problem with required proto2 messages
Browse files Browse the repository at this point in the history
  • Loading branch information
andersfugmann committed Jan 28, 2024
1 parent e1c2e1a commit 3d800ed
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 39 deletions.
55 changes: 31 additions & 24 deletions src/ocaml_protoc_plugin/deserialize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ open S
type required = Required | Optional

type 'a reader = 'a -> Reader.t -> Field.field_type -> 'a
type 'a getter = 'a -> 'a
type ('a, 'b) getter = 'a -> 'b
type 'a field_spec = (int * 'a reader)
type 'a value = ('a field_spec list * required * 'a * 'a getter)
type _ value = Value: ('b field_spec list * required * 'b * ('b, 'a) getter) -> 'a value
type extensions = (int * Field.t) list

type (_, _) value_list =
| VNil : ('a, 'a) value_list
| VNil_ext : (extensions -> 'a, 'a) value_list
| VCons : ('a value) * ('b, 'c) value_list -> ('a -> 'b, 'c) value_list
| VCons : 'a value * ('b, 'c) value_list -> ('a -> 'b, 'c) value_list

type sentinel_field_spec = int * (Reader.t -> Field.field_type -> unit)
type 'a sentinel_getter = unit -> 'a
Expand Down Expand Up @@ -129,10 +129,17 @@ let read_field ~read:(expect, read_f) ~map v reader field_type =
error_wrong_field "Deserialize" field

let value: type a. a compound -> a value = function
| Basic (index, (Message (_, merge) as spec), None) ->
let map v1 v2 =
match v1 with
| None -> Some v2
| Some v1 -> Some (merge v1 v2)
in
let read = read_field ~read:(read_of_spec spec) ~map in
let getter = function Some v -> v | None -> failwith "Get called on unset required field" in
Value ([(index, read)], Required, None, getter)
| Basic (index, spec, default) ->
let map = match spec with
| Message (_, merge) -> merge
| _ -> keep_last
let map = keep_last
in
let read = read_field ~read:(read_of_spec spec) ~map in
let required = match default with
Expand All @@ -143,20 +150,20 @@ let value: type a. a compound -> a value = function
| None -> default_value spec
| Some default -> default
in
([(index, read)], required, default, id)
Value ([(index, read)], required, default, id)
| Basic_opt (index, spec) ->
let map = match spec with
| Message (_, merge) ->
let map v1 v2 =
match v1 with
| None -> Some v2
| Some prev -> Some (merge prev v2)
| Some v1 -> Some (merge v1 v2)
in
map
| _ -> fun _ v -> Some v (* Keep last for all other non-repeated types *)
in
let read = read_field ~read:(read_of_spec spec) ~map in
([(index, read)], Optional, None, id)
Value ([(index, read)], Optional, None, id)
| Repeated (index, spec, Packed) ->
let field_type, read_f = read_of_spec spec in
let rec read_packed_values read_f acc reader =
Expand All @@ -175,16 +182,16 @@ let value: type a. a compound -> a value = function
let field = Reader.read_field_content ft reader in
error_wrong_field "Deserialize" field
in
([(index, read)], Optional, [], List.rev)
Value ([(index, read)], Optional, [], List.rev)
| Repeated (index, spec, Not_packed) ->
let read = read_field ~read:(read_of_spec spec) ~map:(fun vs v -> v :: vs) in
([(index, read)], Optional, [], List.rev)
Value ([(index, read)], Optional, [], List.rev)
| Oneof oneofs ->
let make_reader: a oneof -> a field_spec = fun (Oneof_elem (index, spec, constr)) ->
let read = read_field ~read:(read_of_spec spec) ~map:(fun _ -> constr) in
(index, read)
in
(List.map ~f:make_reader oneofs, Optional, `not_set, id)
Value (List.map ~f:make_reader oneofs, Optional, `not_set, id)

module IntMap = Map.Make(struct type t = int let compare = Int.compare end)

Expand All @@ -197,7 +204,7 @@ let deserialize_full: type constr a. extension_ranges -> (constr, a) value_list
| VNil -> NNil
| VNil_ext -> NNil_ext
(* Consider optimizing when optional is true *)
| VCons ((fields, required, default, getter), rest) ->
| VCons (Value (fields, required, default, getter), rest) ->
let v = ref (default, required) in
let get () = match !v with
| _, Required -> error_required_field_missing ();
Expand Down Expand Up @@ -277,11 +284,11 @@ let deserialize: type constr a. (constr, a) compound_list -> constr -> Reader.t
in

let rec read_values: type constr a. extension_ranges -> Field.field_type -> int -> Reader.t -> constr -> extensions -> (constr, a) value_list -> a = fun extension_ranges tpe idx reader constr extensions ->
let rec read_repeated tpe index read_f default get reader =
let rec read_repeated tpe index read_f default reader =
let default = read_f default reader tpe in
let (tpe, idx) = next_field reader in
match idx = index with
| true -> read_repeated tpe index read_f default get reader
| true -> read_repeated tpe index read_f default reader
| false -> default, tpe, idx
in
function
Expand All @@ -290,34 +297,34 @@ let deserialize: type constr a. (constr, a) compound_list -> constr -> Reader.t
| VNil_ext when idx = Int.max_int ->
constr (List.rev extensions)
(* All fields read successfully. Apply extensions and return result. *)
| VCons (([index, read_f], _required, default, get), vs) when index = idx ->
| VCons (Value ([index, read_f], _required, default, get), vs) when index = idx ->
(* Read all values, and apply constructor once all fields have been read.
This pattern is the most likely to be matched for all values, and is added
as an optimization to avoid reconstructing the value list for each recursion.
*)
let default, tpe, idx = read_repeated tpe index read_f default get reader in
let default, tpe, idx = read_repeated tpe index read_f default reader in
let constr = (constr (get default)) in
read_values extension_ranges tpe idx reader constr extensions vs
| VCons (((index, read_f) :: fields, _required, default, get), vs) when index = idx ->
| VCons (Value ((index, read_f) :: fields, _required, default, get), vs) when index = idx ->
(* Read all values for the given field *)
let default, tpe, idx = read_repeated tpe index read_f default get reader in
read_values extension_ranges tpe idx reader constr extensions (VCons ((fields, Optional, default, get), vs))
let default, tpe, idx = read_repeated tpe index read_f default reader in
read_values extension_ranges tpe idx reader constr extensions (VCons (Value (fields, Optional, default, get), vs))
| vs when in_extension_ranges extension_ranges idx ->
(* Extensions may be sent inline. Store all valid extensions, before starting to apply constructors *)
let extensions = (idx, Reader.read_field_content tpe reader) :: extensions in
let (tpe, idx) = next_field reader in
read_values extension_ranges tpe idx reader constr extensions vs
| VCons (([], Required, _default, _get), _vs) ->
| VCons (Value ([], Required, _default, _get), _vs) ->
(* If there are no more fields to be read we will never find the value.
If all values are read, then raise, else revert to full deserialization *)
begin match (idx = Int.max_int) with
| true -> error_required_field_missing ()
| false -> raise Restart_full
end
| VCons ((_ :: fields, optional, default, get), vs) ->
| VCons (Value (_ :: fields, optional, default, get), vs) ->
(* Drop the field, as we dont expect to find it. *)
read_values extension_ranges tpe idx reader constr extensions (VCons ((fields, optional, default, get), vs))
| VCons (([], Optional, default, get), vs) ->
read_values extension_ranges tpe idx reader constr extensions (VCons (Value (fields, optional, default, get), vs))
| VCons (Value ([], Optional, default, get), vs) ->
(* Apply destructor. This case is only relevant for oneof fields *)
read_values extension_ranges tpe idx reader (constr (get default)) extensions vs
| VNil | VNil_ext ->
Expand Down
27 changes: 12 additions & 15 deletions src/plugin/types.ml
Original file line number Diff line number Diff line change
Expand Up @@ -725,40 +725,37 @@ let make ~params ~syntax ~is_cyclic ~is_map_entry ~extension_ranges ~scope ~fiel
*)

let name = Scope.get_name scope name in
sprintf "let %s = match ((t1%s%s), (t2%s%s)) with" name sep name sep name ::
sprintf "match ((t1%s%s), (t2%s%s)) with"sep name sep name ::
List.map ~f:(fun (ctr, type') ->
let spec = sprintf "basic (0, %s, None)" type' in
sprintf " | (%s v1, %s v2) -> %s (Runtime'.Merge.merge Runtime'.Deserialize.C.( %s ) v1 v2)" ctr ctr ctr spec
) ctrs
|> append " | (v1, `not_set) -> v1"
|> append " | (_, v2) -> v2"
|> append "in"
|> String.concat ~sep:"\n"
|> fun value -> name, value

| { name; deserialize_spec; _ } ->
let name = Scope.get_name scope name in
sprintf "let %s = Runtime'.Merge.merge Runtime'.Deserialize.C.( %s ) t1%s%s t2%s%s in\n"
name deserialize_spec sep name sep name
name, sprintf "Runtime'.Merge.merge Runtime'.Deserialize.C.( %s ) t1%s%s t2%s%s"
deserialize_spec sep name sep name
)
|> append ~cond:has_extensions (sprintf "let extensions' = List.append t1%sextensions' t2%sextensions' in" sep sep)
|> String.concat ~sep:"\n"
|> append ~cond:has_extensions ("extensions'", sprintf "List.append t1%sextensions' t2%sextensions'" sep sep)
in
let constr =
let names =
List.map ts ~f:(fun c -> Scope.get_name scope c.name)
|> append ~cond:has_extensions "extensions'"
in
match as_tuple with
| true ->
names
List.map ~f:snd merge_values
|> String.concat ~sep:","
|> sprintf "(%s)"
| false ->
names
|> String.concat ~sep:"; "
|> sprintf "{ %s }"
List.map merge_values ~f:(fun (name, value) ->
Printf.sprintf "%s = (%s);" name value
)
|> String.concat ~sep:"\n"
|> sprintf "{\n%s\n }"
in
sprintf "fun %s -> \n%s\n%s" args merge_values constr
sprintf "fun %s -> %s" args constr
in

(* The type contains optional elements. We should not have those *)
Expand Down
13 changes: 13 additions & 0 deletions test/proto2.proto
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,16 @@ message Oneof_default {
int64 j = 2 [default = 7];
};
}

message NameClash {
message M1 { required int64 t = 1; };
message M2 { required int64 t = 1; };
message M3 { required int64 t = 1; };
message M4 { required int64 t = 1; };
message M5 { required int64 t = 1; };
required M1 t = 1;
required M2 T = 2;
required M3 _t = 3;
required M3 _T = 4;
required M4 T_ = 5;
}

0 comments on commit 3d800ed

Please sign in to comment.