Skip to content

Commit

Permalink
WIP POC WTF BBQ
Browse files Browse the repository at this point in the history
  • Loading branch information
davebenvenuti committed Jan 8, 2025
1 parent f537118 commit be1291c
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 122 deletions.
149 changes: 32 additions & 117 deletions lib/protoboeuf/codegen.rb
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
require "erb"
require "syntax_tree"
require_relative "codegen_type_helper"
require_relative "decorated_field"

module ProtoBoeuf
class CodeGen
Expand Down Expand Up @@ -71,7 +72,7 @@ def result(message, toplevel_enums, generate_types:, requires:, syntax:, options
def initialize(message, toplevel_enums, generate_types:, requires:, syntax:, options:)
@message = message
@optional_field_bit_lut = []
@fields = @message.field
@fields = @message.field.map { |field| ::Protoboeuf::DecoratedField.new(message:, field:, syntax:) }
@enum_field_types = toplevel_enums.merge(message.enum_type.group_by(&:name))
@requires = requires
@generate_types = generate_types
Expand All @@ -87,9 +88,9 @@ def initialize(message, toplevel_enums, generate_types:, requires:, syntax:, opt

optional_field_count = 0

message.field.each do |field|
if optional_field?(field)
if field.type == :TYPE_ENUM
@fields.each do |field|
if field.optional?
if field.enum?
@enum_fields << field
else
@optional_fields << field
Expand All @@ -98,7 +99,7 @@ def initialize(message, toplevel_enums, generate_types:, requires:, syntax:, opt
optional_field_count += 1
elsif field.has_oneof_index?
(@oneof_fields[field.oneof_index] ||= []) << field
elsif field.type == :TYPE_ENUM
elsif field.enum?
@enum_fields << field
else
@required_fields << field
Expand All @@ -115,11 +116,6 @@ def initialize(message, toplevel_enums, generate_types:, requires:, syntax:, opt
end
end

# Decides whether +field+ is handled as an optional field.
#
# A field qualifies when its `proto3_optional` flag is truthy, or when it is
# labelled :LABEL_OPTIONAL and `syntax` is anything other than "proto3".
#
# Returns the truthy `proto3_optional` value, or true/false.
def optional_field?(field)
  legacy_syntax = "proto3" != syntax
  explicit = field.proto3_optional
  return explicit if explicit

  field.label == :LABEL_OPTIONAL && legacy_syntax
end

# Renders the Ruby source for the message class: a `class <name>` header
# line, the output of `class_body`, and the closing `end` line.
def result
  header = "class #{message.name}\n"
  footer = "end\n"

  header + class_body + footer
end
Expand All @@ -141,7 +137,7 @@ def class_body

def conversion
fields = self.fields.reject do |field|
field.has_oneof_index? && !optional_field?(field)
field.has_oneof_index? && !field.optional?
end

oneofs = @oneof_selection_fields.map do |field|
Expand Down Expand Up @@ -178,7 +174,7 @@ def encode

def encode_subtype(field, value_expr = iv_name(field), tagged = true)
if field.label == :LABEL_REPEATED
if map_field?(field)
if field.map_field?
encode_map(field, value_expr, tagged)
else
encode_repeated(field, value_expr, tagged)
Expand All @@ -194,7 +190,7 @@ def encode_leaf_type(field, value_expr, tagged)

def encode_tag(field)
result = +""
tag = (field.number << 3) | CodeGen.wire_type(field)
tag = (field.number << 3) | field.wire_type
while tag != 0
byte = tag & 0x7F
tag >>= 7
Expand All @@ -209,7 +205,7 @@ def encode_tag(field)
def encode_length(field, len_expr)
result = +""

if CodeGen.wire_type(field) == LEN
if field.wire_type == ::Protoboeuf::DecoratedField::WIRE_TYPE_LEN
raise "length encoded fields must have a length expression" unless len_expr

if len_expr != "len"
Expand Down Expand Up @@ -249,16 +245,16 @@ def encode_bool(field, value_expr, tagged)
end

def encode_map(field, value_expr, tagged)
map_type = self.map_type(field)
map_type = field.map_type

<<~RUBY
map = #{value_expr}
if map.size > 0
old_buff = buff
map.each do |key, value|
buff = new_buffer = +''
#{encode_subtype(map_type.field[0], "key", true)}
#{encode_subtype(map_type.field[1], "value", true)}
#{encode_subtype(map_type.fields[0], "key", true)}
#{encode_subtype(map_type.fields[1], "value", true)}
buff = old_buff
#{encode_tag_and_length(field, true, "new_buffer.bytesize")}
old_buff.concat(new_buffer)
Expand All @@ -278,7 +274,7 @@ def encode_oneof(field, value_expr, tagged)
end

def encode_repeated(field, value_expr, tagged)
if CodeGen.packed?(field)
if field.packed?
<<~RUBY
list = #{value_expr}
if list.size > 0
Expand Down Expand Up @@ -650,9 +646,7 @@ def class_name(type)
end

def required_readers
fields = @required_fields.select do |field|
!field.type != :TYPE_ENUM
end
fields = @required_fields.reject(&:enum?)

"# required field readers\n" +
fields.map do |field|
Expand Down Expand Up @@ -764,7 +758,7 @@ def initialize_code
init_bitmask(message) +
initialize_oneofs +
fields.map { |field|
if field.has_oneof_index? && !optional_field?(field)
if field.has_oneof_index? && !field.optional?
initialize_oneof(field, message)
else
initialize_field(field)
Expand Down Expand Up @@ -793,7 +787,7 @@ def initialize_oneof(field, msg)
end

def initialize_field(field)
if optional_field?(field)
if field.optional?
initialize_optional_field(field)
elsif field.type == :TYPE_ENUM
initialize_enum_field(field)
Expand Down Expand Up @@ -1115,7 +1109,7 @@ def decode_from(buff, index, len)
<%= encode_varint("unknown_bytes") %>
case wire_type
when <%= VARINT %>
when <%= ::Protoboeuf::DecoratedField::WIRE_TYPE_VARINT %>
i = 0
while true
newbyte = buff.getbyte(index)
Expand All @@ -1126,18 +1120,18 @@ def decode_from(buff, index, len)
i += 1
break if i > 9
end
when <%= I64 %>
when <%= ::Protoboeuf::DecoratedField::WIRE_TYPE_I64 %>
unknown_bytes << buff.byteslice(index, 8)
index += 8
when <%= LEN %>
when <%= ::Protoboeuf::DecoratedField::WIRE_TYPE_LEN %>
value = <%= pull_varint %>
val = value
<%= encode_varint("unknown_bytes") %>
unknown_bytes << buff.byteslice(index, value)
index += value
when <%= I32 %>
when <%= ::Protoboeuf::DecoratedField::WIRE_TYPE_I32 %>
unknown_bytes << buff.byteslice(index, 4)
index += 4
else
Expand All @@ -1150,11 +1144,11 @@ def decode_from(buff, index, len)
found = false
<%- fields.each do |field| -%>
<%- if !field.has_oneof_index? || optional_field?(field) -%>
<%- if !field.has_oneof_index? || field.optional? -%>
if tag == <%= tag_for_field(field, field.number) %>
found = true
<%= decode_code(field) %>
<%= set_bitmask(field) if optional_field?(field) %>
<%= set_bitmask(field) if field.optional? %>
return self if index >= len
<%- if !reads_next_tag?(field) -%>
<%= pull_tag %>
Expand Down Expand Up @@ -1200,7 +1194,7 @@ def pull_tag

def default_for(field)
if field.label == :LABEL_REPEATED
if map_field?(field)
if field.map_field?
"{}"
else
"[]"
Expand All @@ -1225,25 +1219,9 @@ def default_for(field)
end
end

# Reports whether +field+ is a map field.
#
# Only repeated fields can qualify. The last dot-separated segment of the
# field's type name must match a nested type of `message` whose options
# have `map_entry` set.
def map_field?(field)
  return false if field.label != :LABEL_REPEATED

  entry_name = field.type_name.split(".").last
  message.nested_type.any? do |nested|
    nested.name == entry_name && nested.options&.map_entry
  end
end

# Looks up the nested map-entry type that backs +field+.
#
# Returns false (not an exception) when the field is not repeated.
# Otherwise returns the nested type of `message` whose name equals the last
# dot-separated segment of the field's type name and whose options have
# `map_entry` set.
#
# Raises ArgumentError ("Not a map field") when a repeated field has no
# such nested type.
def map_type(field)
  return false if field.label != :LABEL_REPEATED

  entry_name = field.type_name.split(".").last
  entry = message.nested_type.find do |nested|
    nested.name == entry_name && nested.options&.map_entry
  end
  raise ArgumentError, "Not a map field" unless entry

  entry
end

def initialize_signature
fields.flat_map do |f|
if f.has_oneof_index? || optional_field?(f)
if f.has_oneof_index? || f.optional?
"#{lvar_name(f)}: nil"
else
"#{lvar_name(f)}: #{default_for(f)}"
Expand All @@ -1252,7 +1230,7 @@ def initialize_signature
end

# Formats the wire tag for a field as a "%#02x" hex string: +idx+ shifted
# left by three bits, OR-ed with the field's wire type.
#
# NOTE(review): two `format` calls appear below and only the last
# expression is returned. They look like the before/after lines of a diff
# rendered together -- confirm which one is intended to remain.
def tag_for_field(field, idx)
format("%#02x", (idx << 3 | CodeGen.wire_type(field)))
format("%#02x", (idx << 3 | field.wire_type))
end

def decode_subtype(field, type, dest, operator)
Expand Down Expand Up @@ -1308,18 +1286,17 @@ def pull_fixed_int32(dest, operator)
end

def decode_map(field)
map_type = self.map_type(field)
map_type = field.map_type

<<~RUBY
## PULL_MAP
map = #{iv_name(field)}
while tag == #{tag_for_field(field, field.number)}
#{pull_uint64("value", "=")}
index += 1 # skip the tag, assume it's the key
return self if index >= len
#{decode_subtype(map_type.field[0], map_type.field[0].type, "key", "=")}
#{decode_subtype(map_type.fields[0], map_type.fields[0].type, "key", "=")}
index += 1 # skip the tag, assume it's the value
#{decode_subtype(map_type.field[1], map_type.field[1].type, "map[key]", "=")}
#{decode_subtype(map_type.fields[1], map_type.fields[1].type, "map[key]", "=")}
return self if index >= len
#{pull_tag}
end
Expand Down Expand Up @@ -1616,9 +1593,9 @@ def pull_boolean(dest, operator)

def decode_code(field)
if field.label == :LABEL_REPEATED
if map_field?(field)
if field.map_field?
decode_map(field)
elsif CodeGen.packed?(field)
elsif field.packed?
PACKED_REPEATED.result(binding)
else
decode_repeated(field)
Expand Down Expand Up @@ -1675,11 +1652,7 @@ def test_bitmask(field)
end

# True when +field+ is a map field, or is repeated and not packed.
def reads_next_tag?(field)
map_field?(field) || (repeated?(field) && !CodeGen.packed?(field))
end

# Compares the field's label against :LABEL_REPEATED.
#
# NOTE(review): the final expression below repeats the reads_next_tag?
# condition in terms of the field's own predicates and, being last, is what
# this method returns as written. It looks like a diff line rendered into
# the wrong definition -- confirm the intended placement.
def repeated?(field)
field.label == :LABEL_REPEATED
field.map_field? || (field.repeated? && !field.packed?)
end
end

Expand Down Expand Up @@ -1731,63 +1704,5 @@ def resolve_modules(file)
m.split("_").map(&:capitalize).join unless m.empty?
end
end

# Wire type identifiers; these are OR-ed into the low three bits of a
# field tag (field number << 3 | wire type).
VARINT = 0
I64 = 1
LEN = 2
I32 = 5

# Field types that `packed?` treats as packed when a repeated field carries
# no options.
PACKED_TYPES = Set.new(
  %i[
    TYPE_DOUBLE
    TYPE_FLOAT
    TYPE_INT32
    TYPE_INT64
    TYPE_UINT32
    TYPE_UINT64
    TYPE_SINT32
    TYPE_SINT64
    TYPE_FIXED32
    TYPE_FIXED64
    TYPE_SFIXED32
    TYPE_SFIXED64
    TYPE_BOOL
  ],
).freeze

class << self
  # Reports whether a repeated field uses the packed encoding.
  #
  # With no options on the field, the answer is PACKED_TYPES membership of
  # the field's type; otherwise the field's `packed` option is returned
  # as-is.
  #
  # Raises ArgumentError when the field is not repeated.
  def packed?(field)
    raise ArgumentError if field.label != :LABEL_REPEATED

    if field.options
      field.options.packed
    else
      PACKED_TYPES.include?(field.type)
    end
  end

  # Maps a field to its wire type constant (VARINT, I64, LEN or I32).
  #
  # Packed repeated fields are always LEN; everything else is decided by
  # the field's type. Raises RuntimeError for a type with no known mapping.
  def wire_type(field)
    return LEN if field.label == :LABEL_REPEATED && packed?(field)

    # FIXME (carried over from the original): a commented-out branch once
    # mapped type names matching /[A-Z]+\w+/ to LEN; it was flagged as
    # "doesn't seem right" and is still disabled.
    case field.type
    when :TYPE_ENUM, :TYPE_INT64, :TYPE_INT32, :TYPE_UINT64, :TYPE_BOOL,
         :TYPE_SINT32, :TYPE_SINT64, :TYPE_UINT32
      VARINT
    when :TYPE_STRING, :TYPE_BYTES, :TYPE_MESSAGE
      LEN
    when :TYPE_DOUBLE, :TYPE_FIXED64, :TYPE_SFIXED64
      I64
    when :TYPE_FLOAT, :TYPE_FIXED32, :TYPE_SFIXED32
      I32
    else
      raise "Unknown wire type for field #{field.type}"
    end
  end
end
end
end
10 changes: 5 additions & 5 deletions lib/protoboeuf/codegen_type_helper.rb
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ def convert_type(converted_type, optional: false, array: false)
end

def convert_field_type(field)
converted_type = if map_field?(field)
map_type = self.map_type(field)
"T::Hash[#{convert_field_type(map_type.field[0])}, #{convert_field_type(map_type.field[1])}]"
converted_type = if field.map_field?
map_type = field.map_type
"T::Hash[#{convert_field_type(map_type.fields[0])}, #{convert_field_type(map_type.fields[1])}]"
else
case field.type
when :TYPE_BOOL
Expand All @@ -86,8 +86,8 @@ def convert_field_type(field)

convert_type(
converted_type,
optional: field.label == :TYPE_OPTIONAL,
array: field.label == :LABEL_REPEATED && !map_field?(field),
optional: field.optional?,
array: field.repeated? && field.map_field?,
)
end

Expand Down
Loading

0 comments on commit be1291c

Please sign in to comment.