Skip to content

Commit

Permalink
WIP to_json
Browse files Browse the repository at this point in the history
  • Loading branch information
davebenvenuti committed Jan 7, 2025
1 parent b4506bd commit d3f0f88
Show file tree
Hide file tree
Showing 16 changed files with 4,658 additions and 2,005 deletions.
207 changes: 171 additions & 36 deletions lib/protoboeuf/codegen.rb
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# frozen_string_literal: true

# rubocop:disable Lint/LiteralInInterpolation

require "erb"
require "syntax_tree"
require_relative "codegen_type_helper"
Expand Down Expand Up @@ -155,6 +157,52 @@ def to_h
#{(oneofs + fields.map { |field| convert_field(field) }).join("\n")}
result
end
#{type_signature(params: { _options: "T::Hash" }, returns: "String")}
def to_json_without_debug(options = {})
require 'json'
obj = transform_for_json!(to_h)
JSON.generate(obj, options)
end
# sig: any
def to_json_with_debug(options = {})
require 'json'
obj = transform_for_json!(to_h)
#{"$stderr.puts \"to_json options: \#{options}\""}
JSON.generate(obj, options)
end
alias_method :to_json, :to_json_without_debug
# sig: any
private def transform_for_json!(obj)
case obj
when Hash
obj.each_with_object({}) do |(k, v), result|
result[json_field_name(k.to_s)] = transform_for_json!(v)
end
when Array
obj.map { |v| transform_for_json!(v) }
when String
# TODO: when field.type == :TYPE_BYTES
[obj].pack('m')
when Numeric
obj.to_s
else
obj
end
end
# By default the protobuf JSON printer should convert the field name to lowerCamelCase and use that as the JSON name.
# See: https://protobuf.dev/programming-guides/json/#json-options
private def json_field_name(name)
return name unless name.include?("_")
# Names like FIELD_NAME11 (all caps + underscores + numbers) should remain as-is
return name if name =~ /[A-Z\d_]+/
name.split(/_+/).each_with_index.map { |part, i| i.zero? ? part : part.downcase.capitalize }.join
end
RUBY
end

Expand All @@ -168,6 +216,31 @@ def convert_field(field)
end
end

# def convert_field_for_json(field)
# name = json_field_name(field.name.to_s)

# value_expr = if field.type == :TYPE_MESSAGE
# if repeated?(field)
# if false # map_field?(field)
# "%s.map(&:as_json)"
# else
# "%s.map(&:as_json)"
# end
# else
# "%s&.as_json"
# end
# elsif field.type == :TYPE_BYTES
# repeated?(field) ? "%s.map { |x| [x].pack('m') if x }" : "[%s].pack('m')"
# else
# "%s"
# end
# "result['#{name}'] = " + format(value_expr, iv_name(field))
# end

# def json_field_name(name)
# name.gsub(/_[a-z]/) { |m| m.delete_prefix("_").capitalize }
# end

def encode
# FIXME: we should probably sort fields by field number
type_signature(params: { buff: "String" }, returns: "String", newline: true) +
Expand Down Expand Up @@ -1074,14 +1147,31 @@ def oneof_field_readers
value = <%= pull_varint %>
# If value is even, then it's positive
<%= dest %> <%= operator %> (if value.even?
value = (if value.even?
value >> 1
else
-((value + 1) >> 1)
end)
<%= dest %> <%= operator %> value <%= " % #{TYPE_BOUNDS[:TYPE_SINT32]}" if truncate %>
## END PULL SINT32
ERB

PULL_SINT64 = ERB.new(<<~ERB, trim_mode: "-")
## PULL SINT64
value = <%= pull_varint %>
# If value is even, then it's positive
value = (if value.even?
value >> 1
else
-((value + 1) >> 1)
end)
<%= dest %> <%= operator %> value <%= "% #{TYPE_BOUNDS[:TYPE_SINT64]}" if truncate %>
## END PULL SINT64
ERB

DECODE_METHOD = ERB.new(<<~ERB, trim_mode: "-")
def decode_from(buff, index, len)
<%= init_bitmask(message) %>
Expand Down Expand Up @@ -1162,7 +1252,7 @@ def decode_from(buff, index, len)
list = <%= iv_name(field) %>
while true
break if index >= goal
<%= decode_subtype(field, field.type, "list", "<<") %>
<%= decode_subtype(field, field.type, "list", "<<", true) %>
end
ERB

Expand Down Expand Up @@ -1235,19 +1325,19 @@ def tag_for_field(field, idx)
format("%#02x", (idx << 3 | CodeGen.wire_type(field)))
end

def decode_subtype(field, type, dest, operator)
def decode_subtype(field, type, dest, operator, truncate = false)
if field.type == :TYPE_ENUM
pull_int64(dest, operator)
else
case type
when :TYPE_STRING then pull_string(dest, operator)
when :TYPE_BYTES then pull_bytes(dest, operator)
when :TYPE_UINT64 then pull_uint64(dest, operator)
when :TYPE_INT64 then pull_int64(dest, operator)
when :TYPE_INT32 then pull_int32(dest, operator)
when :TYPE_UINT32 then pull_uint32(dest, operator)
when :TYPE_SINT32 then pull_sint32(dest, operator)
when :TYPE_SINT64 then pull_sint64(dest, operator)
when :TYPE_UINT64 then pull_uint64(dest, operator, truncate)
when :TYPE_INT64 then pull_int64(dest, operator, truncate)
when :TYPE_INT32 then pull_int32(dest, operator, truncate)
when :TYPE_UINT32 then pull_uint32(dest, operator, truncate)
when :TYPE_SINT32 then pull_sint32(dest, operator, truncate)
when :TYPE_SINT64 then pull_sint64(dest, operator, truncate)
when :TYPE_BOOL then pull_boolean(dest, operator)
when :TYPE_DOUBLE then pull_double(dest, operator)
when :TYPE_FIXED64 then pull_fixed_int64(dest, operator)
Expand Down Expand Up @@ -1299,7 +1389,7 @@ def decode_map(field)
return self if index >= len
#{decode_subtype(map_type.field[0], map_type.field[0].type, "key", "=")}
index += 1 # skip the tag, assume it's the value
#{decode_subtype(map_type.field[1], map_type.field[1].type, "map[key]", "=")}
#{decode_subtype(map_type.field[1], map_type.field[1].type, "map[key]", "=", true)}
return self if index >= len
#{pull_tag}
end
Expand All @@ -1311,7 +1401,7 @@ def decode_repeated(field)
## DECODE REPEATED
list = #{iv_name(field)}
while true
#{decode_subtype(field, field.type, "list", "<<")}
#{decode_subtype(field, field.type, "list", "<<", true)}
return self if index >= len
#{pull_tag}
break unless tag == #{tag_for_field(field, field.number)}
Expand Down Expand Up @@ -1533,32 +1623,50 @@ def pull_message(type, dest, operator)
RUBY
end

def pull_int64(dest, operator)
<<~RUBY
## PULL_INT64
#{dest} #{operator} #{pull_varint(sign: :i64)}
## END PULL_INT64
RUBY
def pull_int64(dest, operator, truncate = false)
if truncate
<<~RUBY
## PULL_INT64
#{dest} #{operator} ((#{pull_varint(sign: :i64)}) & #{mask_for_truncation(64)})
## END PULL_INT64
RUBY
else
<<~RUBY
## PULL_INT64
#{dest} #{operator} #{pull_varint(sign: :i64)}
## END PULL_INT64
RUBY
end
end

def pull_int32(dest, operator)
<<~RUBY
## PULL_INT32
#{dest} #{operator} #{pull_varint(sign: :i32)}
## END PULL_INT32
RUBY
def pull_int32(dest, operator, truncate = false)
if truncate
<<~RUBY
## PULL_INT32
#{dest} #{operator} ((#{pull_varint(sign: :i32)}) & #{mask_for_truncation(32)})
## END PULL_INT32
RUBY
else
<<~RUBY
## PULL_INT32
#{dest} #{operator} #{pull_varint(sign: :i32)}
## END PULL_INT32
RUBY
end
end

def pull_sint32(dest, operator)
def pull_sint32(dest, operator, truncate = false)
PULL_SINT32.result(binding)
end

def pull_sint64(dest, operator, truncate = false)
PULL_SINT64.result(binding)
end

def pull_varint(sign: false)
PULL_VARINT.result(binding)
end

alias_method :pull_sint64, :pull_sint32

def pull_string(dest, operator)
<<~RUBY
## PULL_STRING
Expand All @@ -1575,15 +1683,37 @@ def pull_bytes(dest, operator)
RUBY
end

def pull_uint64(dest, operator)
<<~RUBY
## PULL_UINT64
#{dest} #{operator} #{pull_varint}
## END PULL_UINT64
RUBY
def pull_uint64(dest, operator, truncate = false)
if truncate
<<~RUBY
## PULL_UINT64
#{dest} #{operator} ((#{pull_varint}) & #{mask_for_truncation(64)})
## END PULL_UINT64
RUBY
else
<<~RUBY
## PULL_UINT64
#{dest} #{operator} #{pull_varint}
## END PULL_UINT64
RUBY
end
end

alias_method :pull_uint32, :pull_uint64
def pull_uint32(dest, operator, truncate = false)
if truncate
<<~RUBY
## PULL_UINT32
#{dest} #{operator} ((#{pull_varint}) & #{mask_for_truncation(32)})
## END PULL_UINT32
RUBY
else
<<~RUBY
## PULL_UINT32
#{dest} #{operator} #{pull_varint}
## END PULL_UINT32
RUBY
end
end

def pull_boolean(dest, operator)
<<~RUBY
Expand All @@ -1604,7 +1734,7 @@ def decode_code(field)
decode_repeated(field)
end
else
decode_subtype(field, field.type, iv_name(field), "=")
decode_subtype(field, field.type, iv_name(field), "=", true)
end
end

Expand Down Expand Up @@ -1661,6 +1791,10 @@ def reads_next_tag?(field)
def repeated?(field)
field.label == :LABEL_REPEATED
end

private def mask_for_truncation(bits)
"0x#{((1 << bits) - 1).to_s(16)}"
end
end

attr_reader :generate_types
Expand Down Expand Up @@ -1693,9 +1827,10 @@ def to_ruby(this_file = nil, options = {})

begin
return SyntaxTree.format(head + body + tail)
rescue
rescue => err
$stderr.puts head + body + tail
raise
File.open("shit_#{Time.now.to_i}.rb", "w") { |f| f.puts head + body + tail }
raise err
end
end
end
Expand Down
49 changes: 49 additions & 0 deletions lib/protoboeuf/protobuf/any.rb
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,55 @@ def to_h
result["value".to_sym] = @value
result
end

def to_json_without_debug(options = {})
require "json"
obj = transform_for_json!(to_h)
JSON.generate(obj, options)
end

# sig: any
def to_json_with_debug(options = {})
require "json"
obj = transform_for_json!(to_h)
$stderr.puts "to_json options: #{options}"
JSON.generate(obj, options)
end

alias_method :to_json, :to_json_without_debug

# sig: any
private def transform_for_json!(obj)
case obj
when Hash
obj.each_with_object({}) do |(k, v), result|
result[json_field_name(k.to_s)] = transform_for_json!(v)
end
when Array
obj.map { |v| transform_for_json!(v) }
when String
# TODO: when field.type == :TYPE_BYTES
[obj].pack("m")
when Numeric
obj.to_s
else
obj
end
end

# By default the protobuf JSON printer should convert the field name to lowerCamelCase and use that as the JSON name.
# See: https://protobuf.dev/programming-guides/json/#json-options
private def json_field_name(name)
return name unless name.include?("_")
# Names like FIELD_NAME11 (all caps + underscores + numbers) should remain as-is
return name if name =~ /[A-Zd_]+/

name
.split(/_+/)
.each_with_index
.map { |part, i| i.zero? ? part : part.downcase.capitalize }
.join
end
end
end
end
Loading

0 comments on commit d3f0f88

Please sign in to comment.