diff --git a/README.md b/README.md index 0bc39ac..eb3df97 100644 --- a/README.md +++ b/README.md @@ -6,4 +6,10 @@ Chateau RPC protocol library and generator ```bash go install github.com/oringik/crypto-chateau/cmd/chateau-gen@latest +``` + +## Example generation + +```bash +chateau-gen -language=go -chateau_file=examples/reverse/contract/reverse.chateau -codegen_output=examples/reverse/codegen ``` \ No newline at end of file diff --git a/examples/reverse/client/client.go b/examples/reverse/client/client.go index 248be3a..e2549fa 100644 --- a/examples/reverse/client/client.go +++ b/examples/reverse/client/client.go @@ -3,6 +3,8 @@ package main import ( "context" "fmt" + "reflect" + endpoints "github.com/oringik/crypto-chateau/examples/reverse/codegen" ) @@ -14,10 +16,61 @@ func main() { resp, err := client.ReverseMagicString(context.Background(), &endpoints.ReverseMagicStringRequest{ MagicString: "privet kotik", + MagicInt8: 10, + MagicInt16: 20, + MagicInt32: 30, + MagicInt64: 40, + MagicUInt8: 50, + MagicUInt16: 60, + MagicUInt32: 70, + MagicUInt64: 80, + MagicBool: true, + MagicBytes: []byte{1, 2, 3, 4, 5}, + MagicObject: endpoints.ReverseCommonObject{ + Key: [16]byte{100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115}, + Value: [32]string{"hello", "world"}, + }, + MagicObjectArray: []endpoints.ReverseCommonObject{ + { + Key: [16]byte{100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115}, + Value: [32]string{"sub", "object"}, + }, + }, }) if err != nil { panic(err) } fmt.Println(resp.ReversedMagicString) + + excepted := &endpoints.ReverseMagicStringResponse{ + ReversedMagicString: "kitok tevirp", + MagicInt8: 110, + MagicInt16: 120, + MagicInt32: 130, + MagicInt64: 140, + MagicUInt8: 150, + MagicUInt16: 160, + MagicUInt32: 170, + MagicUInt64: 180, + MagicBool: false, + MagicBytes: []byte{2}, + MagicObject: endpoints.ReverseCommonObject{ + Key: [16]byte{100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115}, + Value: [32]string{"hello", "world"}, + }, + MagicObjectArray: []endpoints.ReverseCommonObject{ + { + Key: [16]byte{100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115}, + Value: [32]string{"sub", "object"}, + }, + }, + } + + ok := reflect.DeepEqual(resp, excepted) + + if !ok { + fmt.Printf("expected:\t%+v\ngot:\t\t%+v\n", excepted, resp) + panic("not equal") + } } diff --git a/examples/reverse/codegen/gen_definitions.go b/examples/reverse/codegen/gen_definitions.go index 58f1d47..78c72bd 100644 --- a/examples/reverse/codegen/gen_definitions.go +++ b/examples/reverse/codegen/gen_definitions.go @@ -1,87 +1,542 @@ -// CODEGEN VERSION: v1.0 +// Code generated by crypto-chateau 1.0.0 DO NOT EDIT. package endpoints -import "errors" -import "context" -import "strconv" -import "github.com/oringik/crypto-chateau/gen/conv" -import "github.com/oringik/crypto-chateau/peer" -import "github.com/oringik/crypto-chateau/message" -import "github.com/oringik/crypto-chateau/server" -import "go.uber.org/zap" -import "github.com/oringik/crypto-chateau/transport" -import "net" +import ( + "context" + "net" + "strconv" + + "github.com/pkg/errors" + "go.uber.org/zap" + + "github.com/oringik/crypto-chateau/gen/conv" + "github.com/oringik/crypto-chateau/gen/hash" + "github.com/oringik/crypto-chateau/message" + "github.com/oringik/crypto-chateau/peer" + "github.com/oringik/crypto-chateau/server" + "github.com/oringik/crypto-chateau/transport" +) var tagsByHandlerName = map[string]map[string]string{ "ReverseMagicString": {"key": "val", "ajsdajsd": "asdasd"}, } type Reverse interface { - ReverseMagicString(ctx context.Context, peer *peer.Peer, req *ReverseMagicStringRequest) error + ReverseMagicString(ctx context.Context, req *ReverseMagicStringRequest) (*ReverseMagicStringResponse, error) +} + +var handlerHashMap = map[string]map[string]hash.HandlerHash{ + "Reverse": { + "ReverseMagicString": hash.HandlerHash{0x52, 0x65, 0x76, 065}, + }, } -func ReverseMagicStringSqueeze(fnc func(context.Context, *peer.Peer, *ReverseMagicStringRequest) error) server.StreamFunc { - return func(ctx context.Context, peer *peer.Peer, msg message.Message) error { +func ReverseMagicStringSqueeze(fnc func(context.Context, *ReverseMagicStringRequest) (*ReverseMagicStringResponse, error)) server.HandlerFunc { + return func(ctx context.Context, msg message.Message) (message.Message, error) { if _, ok := msg.(*ReverseMagicStringRequest); ok { - return fnc(ctx, peer, msg.(*ReverseMagicStringRequest)) + return fnc(ctx, msg.(*ReverseMagicStringRequest)) } else { - return errors.New("unknown message type: expected ReverseMagicStringRequest") + return nil, errors.New("unknown message type: expected ReverseMagicStringRequest") } } } +type ReverseCommonObject struct { + Key [16]byte + Value [32]string +} + +var _ message.Message = (*ReverseCommonObject)(nil) + +func (o *ReverseCommonObject) Marshal() []byte { + var ( + arrBuf []byte + b = make([]byte, 0, 32) + ) + + size := conv.ConvertSizeToBytes(0) + b = append(b, size...) + arrBuf = make([]byte, 0, 128) + for _, elKey := range o.Key { + arrBuf = append(arrBuf, conv.ConvertByteToBytes(elKey)...) + } + b = append(b, conv.ConvertSizeToBytes(len(arrBuf))...) + b = append(b, arrBuf...) + arrBuf = make([]byte, 0, 128) + for _, elValue := range o.Value { + arrBuf = append(arrBuf, conv.ConvertSizeToBytes(len([]byte(elValue)))...) + arrBuf = append(arrBuf, conv.ConvertStringToBytes(elValue)...) + } + b = append(b, conv.ConvertSizeToBytes(len(arrBuf))...) + b = append(b, arrBuf...) + + size = conv.ConvertSizeToBytes(len(b) - len(size)) + for i := 0; i < len(size); i++ { + b[i] = size[i] + } + + return b +} + +func (o *ReverseCommonObject) Unmarshal(b *conv.BinaryIterator) error { + binaryCtx := struct { + err error + size, arrSize, pos int + buf, arrBuf *conv.BinaryIterator + }{} + + binaryCtx.err = nil + binaryCtx.size, binaryCtx.err = b.NextSize() + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read Key size") + } + binaryCtx.arrBuf, binaryCtx.err = b.Slice(binaryCtx.size) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read Key") + } + binaryCtx.pos = 0 + for binaryCtx.arrBuf.HasNext() { + var elKey byte + + binaryCtx.buf, binaryCtx.err = binaryCtx.arrBuf.Slice(1) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read Key") + } + elKey = conv.ConvertBytesToByte(binaryCtx.buf) + + o.Key[binaryCtx.pos] = elKey + binaryCtx.pos++ + } + binaryCtx.size, binaryCtx.err = b.NextSize() + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read Value size") + } + binaryCtx.arrBuf, binaryCtx.err = b.Slice(binaryCtx.size) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read Value") + } + binaryCtx.pos = 0 + for binaryCtx.arrBuf.HasNext() { + var elValue string + + binaryCtx.size, binaryCtx.err = binaryCtx.arrBuf.NextSize() + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read Value size") + } + binaryCtx.buf, binaryCtx.err = binaryCtx.arrBuf.Slice(binaryCtx.size) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read Value") + } + elValue = conv.ConvertBytesToString(binaryCtx.buf) + + o.Value[binaryCtx.pos] = elValue + binaryCtx.pos++ + } + + return nil +} + type ReverseMagicStringRequest struct { - MagicString string + MagicString string + MagicInt8 int8 + MagicInt16 int16 + MagicInt32 int32 + MagicInt64 int64 + MagicUInt8 uint8 + MagicUInt16 uint16 + MagicUInt32 uint32 + MagicUInt64 uint64 + MagicBool bool + MagicBytes []byte + MagicObject ReverseCommonObject + MagicObjectArray []ReverseCommonObject } +var _ message.Message = (*ReverseMagicStringRequest)(nil) + func (o *ReverseMagicStringRequest) Marshal() []byte { - var buf []byte - buf = append(buf, '{') - var resultMagicString []byte - resultMagicString = append(resultMagicString, []byte("MagicString:")...) - resultMagicString = append(resultMagicString, conv.ConvertStringToBytes(o.MagicString)...) - buf = append(buf, resultMagicString...) - buf = append(buf, '}') - return buf -} - -func (o *ReverseMagicStringRequest) Unmarshal(params map[string][]byte) error { - o.MagicString = conv.ConvertBytesToString(params["MagicString"]) + var ( + arrBuf []byte + b = make([]byte, 0, 208) + ) + + size := conv.ConvertSizeToBytes(0) + b = append(b, size...) + b = append(b, conv.ConvertSizeToBytes(len([]byte(o.MagicString)))...) + b = append(b, conv.ConvertStringToBytes(o.MagicString)...) + b = append(b, conv.ConvertInt8ToBytes(o.MagicInt8)...) + b = append(b, conv.ConvertInt16ToBytes(o.MagicInt16)...) + b = append(b, conv.ConvertInt32ToBytes(o.MagicInt32)...) + b = append(b, conv.ConvertInt64ToBytes(o.MagicInt64)...) + b = append(b, conv.ConvertUint8ToBytes(o.MagicUInt8)...) + b = append(b, conv.ConvertUint16ToBytes(o.MagicUInt16)...) + b = append(b, conv.ConvertUint32ToBytes(o.MagicUInt32)...) + b = append(b, conv.ConvertUint64ToBytes(o.MagicUInt64)...) + b = append(b, conv.ConvertBoolToBytes(o.MagicBool)...) + arrBuf = make([]byte, 0, 128) + for _, elMagicBytes := range o.MagicBytes { + arrBuf = append(arrBuf, conv.ConvertByteToBytes(elMagicBytes)...) + } + b = append(b, conv.ConvertSizeToBytes(len(arrBuf))...) + b = append(b, arrBuf...) + b = append(b, o.MagicObject.Marshal()...) + arrBuf = make([]byte, 0, 128) + for _, elMagicObjectArray := range o.MagicObjectArray { + arrBuf = append(arrBuf, elMagicObjectArray.Marshal()...) + } + b = append(b, conv.ConvertSizeToBytes(len(arrBuf))...) + b = append(b, arrBuf...) + + size = conv.ConvertSizeToBytes(len(b) - len(size)) + for i := 0; i < len(size); i++ { + b[i] = size[i] + } + + return b +} + +func (o *ReverseMagicStringRequest) Unmarshal(b *conv.BinaryIterator) error { + binaryCtx := struct { + err error + size, arrSize, pos int + buf, arrBuf *conv.BinaryIterator + }{} + + binaryCtx.err = nil + + binaryCtx.size, binaryCtx.err = b.NextSize() + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read MagicString size") + } + binaryCtx.buf, binaryCtx.err = b.Slice(binaryCtx.size) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read MagicString") + } + o.MagicString = conv.ConvertBytesToString(binaryCtx.buf) + + binaryCtx.buf, binaryCtx.err = b.Slice(1) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read MagicInt8") + } + o.MagicInt8 = conv.ConvertBytesToInt8(binaryCtx.buf) + + binaryCtx.buf, binaryCtx.err = b.Slice(2) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read MagicInt16") + } + o.MagicInt16 = conv.ConvertBytesToInt16(binaryCtx.buf) + + binaryCtx.buf, binaryCtx.err = b.Slice(4) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read MagicInt32") + } + o.MagicInt32 = conv.ConvertBytesToInt32(binaryCtx.buf) + + binaryCtx.buf, binaryCtx.err = b.Slice(8) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read MagicInt64") + } + o.MagicInt64 = conv.ConvertBytesToInt64(binaryCtx.buf) + + binaryCtx.buf, binaryCtx.err = b.Slice(1) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read MagicUInt8") + } + o.MagicUInt8 = conv.ConvertBytesToUint8(binaryCtx.buf) + + binaryCtx.buf, binaryCtx.err = b.Slice(2) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read MagicUInt16") + } + o.MagicUInt16 = conv.ConvertBytesToUint16(binaryCtx.buf) + + binaryCtx.buf, binaryCtx.err = b.Slice(4) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read MagicUInt32") + } + o.MagicUInt32 = conv.ConvertBytesToUint32(binaryCtx.buf) + + binaryCtx.buf, binaryCtx.err = b.Slice(8) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read MagicUInt64") + } + o.MagicUInt64 = conv.ConvertBytesToUint64(binaryCtx.buf) + + binaryCtx.buf, binaryCtx.err = b.Slice(1) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read MagicBool") + } + o.MagicBool = conv.ConvertBytesToBool(binaryCtx.buf) + + binaryCtx.size, binaryCtx.err = b.NextSize() + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read MagicBytes size") + } + binaryCtx.arrBuf, binaryCtx.err = b.Slice(binaryCtx.size) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read MagicBytes") + } + for binaryCtx.arrBuf.HasNext() { + var elMagicBytes byte + + binaryCtx.buf, binaryCtx.err = binaryCtx.arrBuf.Slice(1) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read MagicBytes") + } + elMagicBytes = conv.ConvertBytesToByte(binaryCtx.buf) + + o.MagicBytes = append(o.MagicBytes, elMagicBytes) + + } + + binaryCtx.size, binaryCtx.err = b.NextSize() + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read MagicObject size") + } + binaryCtx.buf, binaryCtx.err = b.Slice(binaryCtx.size) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read MagicObject") + } + if binaryCtx.err = o.MagicObject.Unmarshal(binaryCtx.buf); binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to unmarshal MagicObject") + } + + binaryCtx.size, binaryCtx.err = b.NextSize() + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read MagicObjectArray size") + } + binaryCtx.arrBuf, binaryCtx.err = b.Slice(binaryCtx.size) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read MagicObjectArray") + } + for binaryCtx.arrBuf.HasNext() { + var elMagicObjectArray ReverseCommonObject + + binaryCtx.size, binaryCtx.err = binaryCtx.arrBuf.NextSize() + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read MagicObjectArray size") + } + binaryCtx.buf, binaryCtx.err = binaryCtx.arrBuf.Slice(binaryCtx.size) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read MagicObjectArray") + } + if binaryCtx.err = elMagicObjectArray.Unmarshal(binaryCtx.buf); binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to unmarshal MagicObjectArray") + } + + o.MagicObjectArray = append(o.MagicObjectArray, elMagicObjectArray) + + } + return nil } type ReverseMagicStringResponse struct { ReversedMagicString string + MagicInt8 int8 + MagicInt16 int16 + MagicInt32 int32 + MagicInt64 int64 + MagicUInt8 uint8 + MagicUInt16 uint16 + MagicUInt32 uint32 + MagicUInt64 uint64 + MagicBool bool + MagicBytes []byte + MagicObject ReverseCommonObject + MagicObjectArray []ReverseCommonObject } +var _ message.Message = (*ReverseMagicStringResponse)(nil) + func (o *ReverseMagicStringResponse) Marshal() []byte { - var buf []byte - buf = append(buf, '{') - var resultReversedMagicString []byte - resultReversedMagicString = append(resultReversedMagicString, []byte("ReversedMagicString:")...) - resultReversedMagicString = append(resultReversedMagicString, conv.ConvertStringToBytes(o.ReversedMagicString)...) - buf = append(buf, resultReversedMagicString...) - buf = append(buf, '}') - return buf -} - -func (o *ReverseMagicStringResponse) Unmarshal(params map[string][]byte) error { - o.ReversedMagicString = conv.ConvertBytesToString(params["ReversedMagicString"]) - return nil + var ( + arrBuf []byte + b = make([]byte, 0, 208) + ) + + size := conv.ConvertSizeToBytes(0) + b = append(b, size...) + b = append(b, conv.ConvertSizeToBytes(len([]byte(o.ReversedMagicString)))...) + b = append(b, conv.ConvertStringToBytes(o.ReversedMagicString)...) + b = append(b, conv.ConvertInt8ToBytes(o.MagicInt8)...) + b = append(b, conv.ConvertInt16ToBytes(o.MagicInt16)...) + b = append(b, conv.ConvertInt32ToBytes(o.MagicInt32)...) + b = append(b, conv.ConvertInt64ToBytes(o.MagicInt64)...) + b = append(b, conv.ConvertUint8ToBytes(o.MagicUInt8)...) + b = append(b, conv.ConvertUint16ToBytes(o.MagicUInt16)...) + b = append(b, conv.ConvertUint32ToBytes(o.MagicUInt32)...) + b = append(b, conv.ConvertUint64ToBytes(o.MagicUInt64)...) + b = append(b, conv.ConvertBoolToBytes(o.MagicBool)...) + arrBuf = make([]byte, 0, 128) + for _, elMagicBytes := range o.MagicBytes { + arrBuf = append(arrBuf, conv.ConvertByteToBytes(elMagicBytes)...) + } + b = append(b, conv.ConvertSizeToBytes(len(arrBuf))...) + b = append(b, arrBuf...) + b = append(b, o.MagicObject.Marshal()...) + arrBuf = make([]byte, 0, 128) + for _, elMagicObjectArray := range o.MagicObjectArray { + arrBuf = append(arrBuf, elMagicObjectArray.Marshal()...) + } + b = append(b, conv.ConvertSizeToBytes(len(arrBuf))...) + b = append(b, arrBuf...) + + size = conv.ConvertSizeToBytes(len(b) - len(size)) + for i := 0; i < len(size); i++ { + b[i] = size[i] + } + + return b } -func GetHandlers(reverse Reverse) map[string]*server.Handler { - handlers := make(map[string]*server.Handler) +func (o *ReverseMagicStringResponse) Unmarshal(b *conv.BinaryIterator) error { + binaryCtx := struct { + err error + size, arrSize, pos int + buf, arrBuf *conv.BinaryIterator + }{} + + binaryCtx.err = nil + + binaryCtx.size, binaryCtx.err = b.NextSize() + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read ReversedMagicString size") + } + binaryCtx.buf, binaryCtx.err = b.Slice(binaryCtx.size) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read ReversedMagicString") + } + o.ReversedMagicString = conv.ConvertBytesToString(binaryCtx.buf) + + binaryCtx.buf, binaryCtx.err = b.Slice(1) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read MagicInt8") + } + o.MagicInt8 = conv.ConvertBytesToInt8(binaryCtx.buf) + + binaryCtx.buf, binaryCtx.err = b.Slice(2) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read MagicInt16") + } + o.MagicInt16 = conv.ConvertBytesToInt16(binaryCtx.buf) + + binaryCtx.buf, binaryCtx.err = b.Slice(4) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read MagicInt32") + } + o.MagicInt32 = conv.ConvertBytesToInt32(binaryCtx.buf) + + binaryCtx.buf, binaryCtx.err = b.Slice(8) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read MagicInt64") + } + o.MagicInt64 = conv.ConvertBytesToInt64(binaryCtx.buf) + + binaryCtx.buf, binaryCtx.err = b.Slice(1) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read MagicUInt8") + } + o.MagicUInt8 = conv.ConvertBytesToUint8(binaryCtx.buf) + + binaryCtx.buf, binaryCtx.err = b.Slice(2) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read MagicUInt16") + } + o.MagicUInt16 = conv.ConvertBytesToUint16(binaryCtx.buf) + + binaryCtx.buf, binaryCtx.err = b.Slice(4) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read MagicUInt32") + } + o.MagicUInt32 = conv.ConvertBytesToUint32(binaryCtx.buf) + + binaryCtx.buf, binaryCtx.err = b.Slice(8) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read MagicUInt64") + } + o.MagicUInt64 = conv.ConvertBytesToUint64(binaryCtx.buf) + + binaryCtx.buf, binaryCtx.err = b.Slice(1) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read MagicBool") + } + o.MagicBool = conv.ConvertBytesToBool(binaryCtx.buf) + + binaryCtx.size, binaryCtx.err = b.NextSize() + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read MagicBytes size") + } + binaryCtx.arrBuf, binaryCtx.err = b.Slice(binaryCtx.size) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read MagicBytes") + } + for binaryCtx.arrBuf.HasNext() { + var elMagicBytes byte + + binaryCtx.buf, binaryCtx.err = binaryCtx.arrBuf.Slice(1) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read MagicBytes") + } + elMagicBytes = conv.ConvertBytesToByte(binaryCtx.buf) + + o.MagicBytes = append(o.MagicBytes, elMagicBytes) + + } + + binaryCtx.size, binaryCtx.err = b.NextSize() + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read MagicObject size") + } + binaryCtx.buf, binaryCtx.err = b.Slice(binaryCtx.size) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read MagicObject") + } + if binaryCtx.err = o.MagicObject.Unmarshal(binaryCtx.buf); binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to unmarshal MagicObject") + } + + binaryCtx.size, binaryCtx.err = b.NextSize() + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read MagicObjectArray size") + } + binaryCtx.arrBuf, binaryCtx.err = b.Slice(binaryCtx.size) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read MagicObjectArray") + } + for binaryCtx.arrBuf.HasNext() { + var elMagicObjectArray ReverseCommonObject + + binaryCtx.size, binaryCtx.err = binaryCtx.arrBuf.NextSize() + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read MagicObjectArray size") + } + binaryCtx.buf, binaryCtx.err = binaryCtx.arrBuf.Slice(binaryCtx.size) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read MagicObjectArray") + } + if binaryCtx.err = elMagicObjectArray.Unmarshal(binaryCtx.buf); binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to unmarshal MagicObjectArray") + } + + o.MagicObjectArray = append(o.MagicObjectArray, elMagicObjectArray) + + } + + return nil +} +func GetHandlers(reverse Reverse) map[hash.HandlerHash]*server.Handler { + handlers := make(map[hash.HandlerHash]*server.Handler) - var callFuncReverseMagicString server.StreamFunc + var callFuncReverseMagicString server.HandlerFunc if reverse != nil { callFuncReverseMagicString = ReverseMagicStringSqueeze(reverse.ReverseMagicString) } - handlers["ReverseMagicString"] = &server.Handler{ - CallFuncStream: callFuncReverseMagicString, - HandlerType: server.StreamT, + handlers[hash.HandlerHash{0x52, 0x65, 0x76, 065}] = &server.Handler{ + CallFuncHandler: callFuncReverseMagicString, + HandlerType: server.HandlerT, RequestMsgType: &ReverseMagicStringRequest{}, ResponseMsgType: &ReverseMagicStringResponse{}, Tags: tagsByHandlerName["ReverseMagicString"], @@ -90,11 +545,11 @@ func GetHandlers(reverse Reverse) map[string]*server.Handler { return handlers } -func GetEmptyHandlers() map[string]*server.Handler { - handlers := make(map[string]*server.Handler) +func GetEmptyHandlers() map[hash.HandlerHash]*server.Handler { + handlers := make(map[hash.HandlerHash]*server.Handler) - handlers["ReverseMagicString"] = &server.Handler{ - HandlerType: server.StreamT, + handlers[hash.HandlerHash{0x52, 0x65, 0x76, 065}] = &server.Handler{ + HandlerType: server.HandlerT, RequestMsgType: &ReverseMagicStringRequest{}, ResponseMsgType: &ReverseMagicStringResponse{}, } @@ -111,8 +566,12 @@ func NewServer(cfg *server.Config, logger *zap.Logger, reverse Reverse) *server. func CallClientMethod(ctx context.Context, host string, port int, serviceName string, methodName string, req message.Message) (message.Message, error) { if serviceName == "Reverse" { if methodName == "ReverseMagicString" { + client, err := NewClientReverse(host, port) + if err != nil { + return nil, err + } + return client.ReverseMagicString(ctx, req.(*ReverseMagicStringRequest)) } - } return nil, errors.New("unknown service or method") @@ -136,11 +595,49 @@ func NewClientReverse(host string, port int) (*ClientReverse, error) { return client, nil } -func (c *ClientReverse) ReverseMagicString(ctx context.Context, req *ReverseMagicStringRequest) (*peer.Peer, error) { - err := c.peer.WriteResponse("ReverseMagicString", req) +func (c *ClientReverse) ReverseMagicString(ctx context.Context, req *ReverseMagicStringRequest) (*ReverseMagicStringResponse, error) { + err := c.peer.WriteResponse(hash.HandlerHash{0x52, 0x65, 0x76, 065}, req) + + msg := make([]byte, 0, 1024) + + for { + buf := make([]byte, 1024) + n, err := c.peer.Read(buf) + if err != nil { + return nil, err + } + + if n == 0 { + break + } + + if n < len(buf) { + buf = buf[:n] + msg = append(msg, buf...) + break + } + + msg = append(msg, buf...) + } + + _, _, offset, err := conv.GetHandler(msg) + if err != nil { + return nil, err + } + respMsg := &ReverseMagicStringResponse{} + + //TODO: check if error is present + + // check if message has a size + if len(msg) < offset+conv.ObjectBytesPrefixLength { + return nil, errors.New("not enough for size and message") + } + + err = respMsg.Unmarshal(conv.NewBinaryIterator(msg[offset+conv.ObjectBytesPrefixLength:])) if err != nil { return nil, err } - return c.peer, nil + + return respMsg, nil } diff --git a/examples/reverse/contract/reverse.chateau b/examples/reverse/contract/reverse.chateau index 18a2ba4..0e7ecd6 100644 --- a/examples/reverse/contract/reverse.chateau +++ b/examples/reverse/contract/reverse.chateau @@ -1,13 +1,42 @@ package endpoints service Reverse { - Stream ReverseMagicString(ReverseMagicStringRequest req) -> (ReverseMagicStringResponse) {key: val, ajsdajsd: asdasd} + Handler ReverseMagicString(ReverseMagicStringRequest req) -> (ReverseMagicStringResponse) {key: val, ajsdajsd: asdasd} +} + +object ReverseCommonObject { + [16]byte Key + [32]string Value } object ReverseMagicStringRequest { string magicString + int8 magicInt8 + int16 magicInt16 + int32 magicInt32 + int64 magicInt64 + uint8 magicUInt8 + uint16 magicUInt16 + uint32 magicUInt32 + uint64 magicUInt64 + bool magicBool + []byte magicBytes + ReverseCommonObject magicObject + []ReverseCommonObject magicObjectArray } object ReverseMagicStringResponse { string reversedMagicString + int8 magicInt8 + int16 magicInt16 + int32 magicInt32 + int64 magicInt64 + uint8 magicUInt8 + uint16 magicUInt16 + uint32 magicUInt32 + uint64 magicUInt64 + bool magicBool + []byte magicBytes + ReverseCommonObject magicObject + []ReverseCommonObject magicObjectArray } diff --git a/examples/reverse/server/server.go b/examples/reverse/server/server.go index 7c1e304..1dc5c33 100644 --- a/examples/reverse/server/server.go +++ b/examples/reverse/server/server.go @@ -2,10 +2,12 @@ package main import ( "context" + "time" + + zap2 "go.uber.org/zap" + endpoints "github.com/oringik/crypto-chateau/examples/reverse/codegen" server2 "github.com/oringik/crypto-chateau/server" - zap2 "go.uber.org/zap" - "time" ) func main() { @@ -48,5 +50,17 @@ func (r *ReverseEndpoint) ReverseMagicString(ctx context.Context, req *endpoints return &endpoints.ReverseMagicStringResponse{ ReversedMagicString: reversedMsg, + MagicInt8: req.MagicInt8 + 100, + MagicInt16: req.MagicInt16 + 100, + MagicInt32: req.MagicInt32 + 100, + MagicInt64: req.MagicInt64 + 100, + MagicUInt8: req.MagicUInt8 + 100, + MagicUInt16: req.MagicUInt16 + 100, + MagicUInt32: req.MagicUInt32 + 100, + MagicUInt64: req.MagicUInt64 + 100, + MagicBool: !req.MagicBool, + MagicBytes: req.MagicBytes[1:2], + MagicObject: req.MagicObject, + MagicObjectArray: req.MagicObjectArray, }, nil } diff --git a/gen/ast/ast.go b/gen/ast/ast.go index d9841d6..f1738db 100644 --- a/gen/ast/ast.go +++ b/gen/ast/ast.go @@ -1,64 +1,72 @@ package ast import ( - lexem2 "github.com/oringik/crypto-chateau/gen/lexem" "strconv" "strings" + + "github.com/oringik/crypto-chateau/gen/hash" + lexem2 "github.com/oringik/crypto-chateau/gen/lexem" ) type Type int const ( - Uint32 Type = iota - Uint64 + Uint64 Type = iota + Uint32 Uint16 Uint8 - Int8 - Int32 Int64 + Int32 + Int16 + Int8 Byte Bool String Object ) -var lexerTypeToAstType = map[string]Type{ - "byte": Byte, - "uint32": Uint32, +var LexerTypeToAstType = map[string]Type{ "uint64": Uint64, - "uint8": Uint8, + "uint32": Uint32, "uint16": Uint16, - "string": String, - "bool": Bool, - "int8": Int8, - "int32": Int32, + "uint8": Uint8, "int64": Int64, + "int32": Int32, + "int16": Int16, + "int8": Int8, + "byte": Byte, + "bool": Bool, + "string": String, + "object": Object, } var AstTypeToGoType = map[Type]string{ - Uint32: "uint32", Uint64: "uint64", + Uint32: "uint32", Uint16: "uint16", Uint8: "uint8", + Int64: "int64", + Int32: "int32", + Int16: "int16", + Int8: "int8", Byte: "byte", - String: "string", Bool: "bool", - Int8: "int8", - Int32: "int32", - Int64: "int64", + String: "string", + Object: "object", } var AstTypeToDartType = map[Type]string{ - Uint32: "int", Uint64: "int", + Uint32: "int", Uint16: "int", Uint8: "int", + Int64: "int", + Int32: "int", + Int16: "int", + Int8: "int", Byte: "int", - String: "String", Bool: "bool", - Int8: "int", - Int32: "int", - Int64: "int", + String: "String", } type MethodType string @@ -86,6 +94,7 @@ type Service struct { type Method struct { Name string + Hash hash.HandlerHash Params []*Param Returns []*Return MethodType MethodType @@ -244,14 +253,14 @@ func astField() *Field { var astType Type isArr, arrSize := getArrExistAndSize(lexem.Value) if isArr { - astTypeLocal, ok := lexerTypeToAstType[lexem.Value[2+getCountDigits(arrSize):]] + astTypeLocal, ok := LexerTypeToAstType[lexem.Value[2+getCountDigits(arrSize):]] if !ok { panic("unexpected type") } astType = astTypeLocal } else { - astTypeLocal, ok := lexerTypeToAstType[lexem.Value] + astTypeLocal, ok := LexerTypeToAstType[lexem.Value] if !ok { panic("unexpected type " + lexem.Value) } @@ -292,9 +301,10 @@ func astService() *Service { getNextLexem() var methods []*Method for lexem.Type != lexem2.CloseBraceL { - astMethod := astMethod() + method := astMethod() + method.Hash = hash.GetHandlerHash(service.Name, method.Name) - methods = append(methods, astMethod) + methods = append(methods, method) } service.Methods = methods @@ -449,14 +459,14 @@ func astParam() *Param { var astType Type isArr, arrSize := getArrExistAndSize(lexem.Value) if isArr { - astTypeLocal, ok := lexerTypeToAstType[lexem.Value[2+getCountDigits(arrSize):]] + astTypeLocal, ok := LexerTypeToAstType[lexem.Value[2+getCountDigits(arrSize):]] if !ok { panic("unexpected type") } astType = astTypeLocal } else { - astTypeLocal, ok := lexerTypeToAstType[lexem.Value] + astTypeLocal, ok := LexerTypeToAstType[lexem.Value] if !ok { panic("unexpected type") } @@ -501,14 +511,14 @@ func astReturn() *Return { var astType Type isArr, arrSize := getArrExistAndSize(lexem.Value) if isArr { - astTypeLocal, ok := lexerTypeToAstType[lexem.Value[2+getCountDigits(arrSize):]] + astTypeLocal, ok := LexerTypeToAstType[lexem.Value[2+getCountDigits(arrSize):]] if !ok { panic("unexpected type") } astType = astTypeLocal } else { - astTypeLocal, ok := lexerTypeToAstType[lexem.Value] + astTypeLocal, ok := LexerTypeToAstType[lexem.Value] if !ok { panic("unexpected type") } diff --git a/gen/conv/binary_iterator.go b/gen/conv/binary_iterator.go new file mode 100644 index 0000000..90097b2 --- /dev/null +++ b/gen/conv/binary_iterator.go @@ -0,0 +1,49 @@ +package conv + +import ( + "encoding/binary" + "errors" +) + +var ( + ErrNotEnoughBytes = errors.New("not enough bytes") +) + +type BinaryIterator struct { + Bytes []byte + Index int +} + +func NewBinaryIterator(b []byte) *BinaryIterator { + return &BinaryIterator{ + Bytes: b, + Index: 0, + } +} + +func (b *BinaryIterator) NextSize() (int, error) { + if b.Index+4 > len(b.Bytes) { + return 0, ErrNotEnoughBytes + } + result := binary.BigEndian.Uint32(b.Bytes[b.Index : b.Index+4]) + b.Index += 4 + + return int(result), nil +} + +func (b *BinaryIterator) Slice(n int) (*BinaryIterator, error) { + if b.Index+n > len(b.Bytes) { + return nil, ErrNotEnoughBytes + } + result := &BinaryIterator{ + Bytes: b.Bytes[b.Index : b.Index+n], + Index: 0, + } + b.Index += n + + return result, nil +} + +func (b *BinaryIterator) HasNext() bool { + return b.Index < len(b.Bytes) +} diff --git a/gen/conv/bool.go b/gen/conv/bool.go new file mode 100644 index 0000000..f185e93 --- /dev/null +++ b/gen/conv/bool.go @@ -0,0 +1,13 @@ +package conv + +func ConvertBytesToBool(b *BinaryIterator) bool { + return b.Bytes[0] == 0x01 +} + +func ConvertBoolToBytes(b bool) []byte { + if b { + return []byte{0x01} + } + + return []byte{0x00} +} diff --git a/gen/conv/conv.go b/gen/conv/conv.go index 08f2df0..9f23280 100644 --- a/gen/conv/conv.go +++ b/gen/conv/conv.go @@ -1,340 +1,19 @@ package conv import ( - "bytes" - "encoding/binary" "errors" - "fmt" - "github.com/oringik/crypto-chateau/gen/ast" - "github.com/oringik/crypto-chateau/message" -) - -func ConvFunctionMarhsalByType(t ast.Type) string { - if t == ast.Object { - return "ConvertObjectToBytes" - } - - if t == ast.Uint8 { - return "ConvertUint8ToBytes" - } - - if t == ast.Uint32 { - return "ConvertUint32ToBytes" - } - - if t == ast.Uint64 { - return "ConvertUint64ToBytes" - } - - if t == ast.String { - return "ConvertStringToBytes" - } - - if t == ast.Bool { - return "ConvertBoolToBytes" - } - - if t == ast.Byte { - return "ConvertByteToBytes" - } - - if t == ast.Uint16 { - return "ConvertUint16ToBytes" - } - - if t == ast.Int8 { - return "ConvertInt8ToBytes" - } - - if t == ast.Int32 { - return "ConvertInt32ToBytes" - } - - if t == ast.Int64 { - return "ConvertInt64ToBytes" - } - - return "" -} - -func ConvFunctionUnmarshalByType(t ast.Type) string { - if t == ast.Object { - return "ConvertBytesToObject" - } - - if t == ast.Uint8 { - return "ConvertBytesToUint8" - } - - if t == ast.Uint32 { - return "ConvertBytesToUint32" - } - - if t == ast.Uint64 { - return "ConvertBytesToUint64" - } - - if t == ast.String { - return "ConvertBytesToString" - } - - if t == ast.Bool { - return "ConvertBytesToBool" - } - - if t == ast.Byte { - return "ConvertBytesToByte" - } - - if t == ast.Uint16 { - return "ConvertBytesToUint16" - } - - if t == ast.Int8 { - return "ConvertBytesToInt8" - } - - if t == ast.Int32 { - return "ConvertBytesToInt32" - } - - if t == ast.Int64 { - return "ConvertBytesToInt64" - } - - return "" -} - -func ConvertInt8ToBytes(num int8) []byte { - return []byte{byte(num)} -} - -func ConvertBytesToInt8(b []byte) int8 { - return int8(b[0]) -} - -func ConvertInt32ToBytes(num int32) []byte { - buf := make([]byte, 4) - binary.BigEndian.PutUint32(buf, uint32(num)) - - return buf -} - -func ConvertBytesToInt32(b []byte) int32 { - return int32(binary.BigEndian.Uint32(b)) -} - -func ConvertInt64ToBytes(num int64) []byte { - buf := make([]byte, 8) - binary.BigEndian.PutUint64(buf, uint64(num)) - - return buf -} - -func ConvertBytesToInt64(b []byte) int64 { - return int64(binary.BigEndian.Uint64(b)) -} - -func ConvertUint16ToBytes(num uint16) []byte { - buf := make([]byte, 2) - binary.BigEndian.PutUint16(buf, num) - - return buf -} -func ConvertBytesToUint16(b []byte) uint16 { - return binary.BigEndian.Uint16(b) -} - -func ConvertByteToBytes(b byte) []byte { - return []byte{b} -} - -func ConvertBytesToObject(msg message.Message, b []byte) { - _, params, err := GetParams(b) - if err != nil { - fmt.Println(err) - } - err = msg.Unmarshal(params) - if err != nil { - fmt.Println(err) - } -} - -func ConvertBytesToUint8(b []byte) uint8 { - return b[0] -} - -func ConvertBytesToUint32(b []byte) uint32 { - return binary.BigEndian.Uint32(b) -} - -func ConvertBytesToUint64(b []byte) uint64 { - return binary.BigEndian.Uint64(b) -} - -func ConvertBytesToString(b []byte) string { - return string(b) -} - -func ConvertBoolToString(b []byte) bool { - if b[0] == '1' { - return true - } - - return false -} - -func ConvertUint8ToBytes(num uint8) []byte { - return []byte{num} -} - -func ConvertBytesToByte(b []byte) byte { - return b[0] -} - -func ConvertUint32ToBytes(num uint32) []byte { - buf := make([]byte, 4) - binary.BigEndian.PutUint32(buf, num) - - return buf -} - -func ConvertUint64ToBytes(num uint64) []byte { - buf := make([]byte, 8) - binary.BigEndian.PutUint64(buf, num) - - return buf -} - -func ConvertStringToBytes(str string) []byte { - return []byte(str) -} - -func ConvertBoolToBytes(b bool) []byte { - if b { - return []byte{'1'} - } - - return []byte{'0'} -} - -func ConvertObjectToBytes(msg message.Message) []byte { - return msg.Marshal() -} - -func GetHandlerName(p []byte) ([]byte, int, error) { - buf := make([]byte, 0, 50) - for i, b := range p { - if b == '#' { - return buf, i + 1, nil - } - - buf = append(buf, b) - } - - return nil, 0, errors.New("incorrect message format: handler name not found") -} - -func GetArray(p []byte) (int, [][]byte, error) { - if len(p) == 0 { - return 0, nil, errors.New("array is zero length") - } - - if p[0] != '[' { - return 0, nil, errors.New("expected open brace") - } - - openSquareBracketCount := 1 - var closeSquareBracketCount int - - i := 1 - for openSquareBracketCount != closeSquareBracketCount && i < len(p) { - if p[i] == '[' { - openSquareBracketCount++ - } - - if p[i] == ']' { - closeSquareBracketCount++ - } - - i++ - } - - if openSquareBracketCount != closeSquareBracketCount { - return 0, nil, errors.New("expected end of array") - } + "github.com/oringik/crypto-chateau/gen/hash" +) - values := bytes.Split(p[1:i], []byte(",")) - for i, value := range values { - values[i] = bytes.TrimSpace(value) +func GetHandler(p []byte) (protocol []byte, handlerKey hash.HandlerHash, payloadOffset int, err error) { + if len(p) < 6 { + return nil, hash.HandlerHash{}, 0, errors.New("invalid payload: too short") } - return i, values, nil -} - -func GetParams(p []byte) (int, map[string][]byte, error) { - params := make(map[string][]byte) - paramBuf := make([]byte, 0, len(p)) - valueBuf := make([]byte, 0, len(p)) - - paramBufLast := -1 - valueBufLast := -1 - - var paramFilled bool - var stringParsing bool - - var openBraceCount int - var closeBraceCount int - - var openSquareBracketCount int - var closeSquareBracketCount int - - var isArrParsing bool - - for i, b := range p { - if (b == ',' && paramBufLast != len(paramBuf)-1 && valueBufLast != len(valueBuf)-1 && openBraceCount == closeBraceCount+1 && (!isArrParsing || openSquareBracketCount == closeSquareBracketCount)) || (b == '}' && openBraceCount == closeBraceCount+1) { - if b == '}' && i != len(p)-1 { - valueBuf = append(valueBuf, b) - closeBraceCount++ - } - if paramBufLast == len(paramBuf)-1 || valueBufLast == len(valueBuf)-1 { - return 0, nil, errors.New("incorrect message format: null value") - } - - params[string(paramBuf[paramBufLast+1:])] = valueBuf[valueBufLast+1:] - paramBufLast = len(paramBuf) - 1 - valueBufLast = len(valueBuf) - 1 - - paramFilled = false - isArrParsing = false - } else if b == '[' { - valueBuf = append(valueBuf, '[') - isArrParsing = true - openSquareBracketCount++ - } else if b == ']' { - valueBuf = append(valueBuf, ']') - closeSquareBracketCount++ - } else if b == '{' { - if paramFilled { - valueBuf = append(valueBuf, b) - } - openBraceCount++ - } else if b == '}' { - valueBuf = append(valueBuf, b) - closeBraceCount++ - } else if b == ':' && stringParsing == false && !paramFilled { - paramFilled = true - } else if b == '"' { - stringParsing = !stringParsing - } else { - if !paramFilled { - paramBuf = append(paramBuf, b) - } else { - valueBuf = append(valueBuf, b) - } - } - } + protocol = p[:1] + handlerBytes := p[1:5] + handlerKey = hash.HandlerHash{handlerBytes[0], handlerBytes[1], handlerBytes[2], handlerBytes[3]} - return len(p), params, nil + return protocol, handlerKey, 5, nil } diff --git a/gen/conv/conv_test.go b/gen/conv/conv_test.go deleted file mode 100644 index 9db145e..0000000 --- a/gen/conv/conv_test.go +++ /dev/null @@ -1,34 +0,0 @@ -package conv - -import ( - "github.com/stretchr/testify/assert" - "testing" -) - -func Test_ConvGetObjectArray(t *testing.T) { - _, params, err := GetParams([]byte(`[{a: "bjkjkjk"}, {b: 37878}]`)) - assert.NoError(t, err) - toCompare := []map[string][]byte{ - {"a": []byte("bjkjkjk")}, - {"b": []byte("37878")}, - } - assert.Equal(t, len(params), len(toCompare)) - for i, _ := range params { - assert.Equal(t, params[i], toCompare[i], 0) - } -} - -func Test_ConvGetArray(t *testing.T) { - _, params, err := GetArray([]byte(`[1, 2, 3, 4]`)) - assert.NoError(t, err) - toCompare := [][]byte{ - []byte("1"), - []byte("2"), - []byte("3"), - []byte("4"), - } - assert.Equal(t, len(params), len(toCompare)) - for i, _ := range params { - assert.Equal(t, params[i], toCompare[i], 0) - } -} diff --git a/gen/conv/int16.go b/gen/conv/int16.go new file mode 100644 index 0000000..ac34c14 --- /dev/null +++ b/gen/conv/int16.go @@ -0,0 +1,25 @@ +package conv + +import "encoding/binary" + +func ConvertBytesToUint16(b *BinaryIterator) uint16 { + return binary.BigEndian.Uint16(b.Bytes) +} + +func ConvertUint16ToBytes(num uint16) []byte { + buf := make([]byte, 2) + binary.BigEndian.PutUint16(buf, num) + + return buf +} + +func ConvertBytesToInt16(b *BinaryIterator) int16 { + return int16(binary.BigEndian.Uint16(b.Bytes)) +} + +func ConvertInt16ToBytes(num int16) []byte { + buf := make([]byte, 2) + binary.BigEndian.PutUint16(buf, uint16(num)) + + return buf +} diff --git a/gen/conv/int32.go b/gen/conv/int32.go new file mode 100644 index 0000000..58c531d --- /dev/null +++ b/gen/conv/int32.go @@ -0,0 +1,25 @@ +package conv + +import "encoding/binary" + +func ConvertBytesToUint32(b *BinaryIterator) uint32 { + return binary.BigEndian.Uint32(b.Bytes) +} + +func ConvertUint32ToBytes(num uint32) []byte { + buf := make([]byte, 4) + binary.BigEndian.PutUint32(buf, num) + + return buf +} + +func ConvertBytesToInt32(b *BinaryIterator) int32 { + return int32(binary.BigEndian.Uint32(b.Bytes)) +} + +func ConvertInt32ToBytes(num int32) []byte { + buf := make([]byte, 4) + binary.BigEndian.PutUint32(buf, uint32(num)) + + return buf +} diff --git a/gen/conv/int64.go b/gen/conv/int64.go new file mode 100644 index 0000000..a88fffb --- /dev/null +++ b/gen/conv/int64.go @@ -0,0 +1,25 @@ +package conv + +import "encoding/binary" + +func ConvertBytesToUint64(b *BinaryIterator) uint64 { + return binary.BigEndian.Uint64(b.Bytes) +} + +func ConvertUint64ToBytes(num uint64) []byte { + buf := make([]byte, 8) + binary.BigEndian.PutUint64(buf, num) + + return buf +} + +func ConvertBytesToInt64(b *BinaryIterator) int64 { + return int64(binary.BigEndian.Uint64(b.Bytes)) +} + +func ConvertInt64ToBytes(num int64) []byte { + buf := make([]byte, 8) + binary.BigEndian.PutUint64(buf, uint64(num)) + + return buf +} diff --git a/gen/conv/int8.go b/gen/conv/int8.go new file mode 100644 index 0000000..9f9bc6d --- /dev/null +++ b/gen/conv/int8.go @@ -0,0 +1,25 @@ +package conv + +func ConvertBytesToUint8(b *BinaryIterator) uint8 { + return b.Bytes[0] +} + +func ConvertUint8ToBytes(num uint8) []byte { + return []byte{num} +} + +func ConvertBytesToInt8(b *BinaryIterator) int8 { + return int8(b.Bytes[0]) +} + +func ConvertInt8ToBytes(num int8) []byte { + return []byte{byte(num)} +} + +func ConvertBytesToByte(b *BinaryIterator) byte { + return b.Bytes[0] +} + +func ConvertByteToBytes(num byte) []byte { + return []byte{num} +} diff --git a/gen/conv/size.go b/gen/conv/size.go new file mode 100644 index 0000000..92202c7 --- /dev/null +++ b/gen/conv/size.go @@ -0,0 +1,14 @@ +package conv + +import "encoding/binary" + +var ( + ObjectBytesPrefixLength = len(ConvertSizeToBytes(0)) +) + +func ConvertSizeToBytes(num int) []byte { + buf := make([]byte, 4) + binary.BigEndian.PutUint32(buf, uint32(num)) + + return buf +} diff --git a/gen/conv/string.go b/gen/conv/string.go new file mode 100644 index 0000000..598049e --- /dev/null +++ b/gen/conv/string.go @@ -0,0 +1,9 @@ +package conv + +func ConvertBytesToString(b *BinaryIterator) string { + return string(b.Bytes) +} + +func ConvertStringToBytes(str string) []byte { + return []byte(str) +} diff --git a/gen/conv/string_test.go b/gen/conv/string_test.go new file mode 100644 index 0000000..942a725 --- /dev/null +++ b/gen/conv/string_test.go @@ -0,0 +1,12 @@ +package conv + +import "testing" + +func TestString(t *testing.T) { + str := "Hello World!" + b := ConvertStringToBytes(str) + str2 := ConvertBytesToString(&BinaryIterator{Bytes: b}) + if str != str2 { + t.Errorf("Expected %s, got %s", str, str2) + } +} diff --git a/gen/gen/gen_dart.go b/gen/gen/gen_dart.go index a8282e4..aa1ebc4 100644 --- a/gen/gen/gen_dart.go +++ b/gen/gen/gen_dart.go @@ -2,11 +2,11 @@ package gen import ( "fmt" - ast2 "github.com/oringik/crypto-chateau/gen/ast" - "github.com/oringik/crypto-chateau/gen/conv" "strconv" "strings" "unicode" + + ast2 "github.com/oringik/crypto-chateau/gen/ast" ) var resultDart string @@ -89,7 +89,8 @@ func fillObjectsDart() { resultDart += "\t\tList buf = List.empty(growable: true);\n" resultDart += "\t\t" + `buf.addAll('{'.codeUnits);` + "\n" for i, field := range object.Fields { - convFunction := conv.ConvFunctionMarhsalByType(field.Type.Type) + //convFunction := conv.ConvFunctionMarhsalByType(field.Type.Type) + convFunction := "TODO_CHANGE" resultDart += fmt.Sprintf("\t\tList resultDart%s = List.empty(growable: true);\n", field.Name) if field.Type.IsArray { resultDart += fmt.Sprintf("\t\t"+`resultDart%s.addAll('['.codeUnits);`, field.Name) + "\n" @@ -118,7 +119,8 @@ func fillObjectsDart() { resultDart += "\tUnmarshal(Map params) {\n" for _, field := range object.Fields { - convFunction := conv.ConvFunctionUnmarshalByType(field.Type.Type) + //convFunction := conv.ConvFunctionUnmarshalByType(field.Type.Type) + convFunction := "TODO_CHANGE" if field.Type.Type == ast2.Object { if field.Type.IsArray { resultDart += fmt.Sprintf("\t\t\t"+`var arr = GetArray(params["%s"]!)[1];`+"\n", strings.Title(field.Name)) diff --git a/gen/gen/gen_go.go b/gen/gen/gen_go.go index 2691d97..fce0ec3 100644 --- a/gen/gen/gen_go.go +++ b/gen/gen/gen_go.go @@ -2,16 +2,17 @@ package gen import ( "fmt" + "go/format" "strconv" - "strings" "unicode" + "github.com/pkg/errors" + ast2 "github.com/oringik/crypto-chateau/gen/ast" - "github.com/oringik/crypto-chateau/gen/conv" + "github.com/oringik/crypto-chateau/gen/templates" + "github.com/oringik/crypto-chateau/version" ) -const CODEGEN_VERSION string = "v1.0" - var result string var ast *ast2.Ast @@ -23,8 +24,12 @@ func GenerateDefinitions(astLocal *ast2.Ast) string { fillImports() fillTagsByHandlerName() fillServices() + fillHandlerHashMap() fillSqueezes() - fillObjects() + err := fillObjects() + if err != nil { + panic("objects generation failed:" + err.Error()) + } fillGetHandlers() fillEmptyGetHandlers() fillNewServer() @@ -32,20 +37,34 @@ func GenerateDefinitions(astLocal *ast2.Ast) string { fillClients() + resultRaw, err := format.Source([]byte(result)) + if err != nil { + panic("formatting failed:" + err.Error()) + } + + result = string(resultRaw) + return result } func fillImports() { - result += "import \"errors\"\n" - result += "import \"context\"\n" - result += "import \"strconv\"\n" - result += "import \"github.com/oringik/crypto-chateau/gen/conv\"\n" - result += "import \"github.com/oringik/crypto-chateau/peer\"\n" - result += "import \"github.com/oringik/crypto-chateau/message\"\n" - result += "import \"github.com/oringik/crypto-chateau/server\"\n" - result += "import \"go.uber.org/zap\"\n" - result += "import \"github.com/oringik/crypto-chateau/transport\"\n" - result += "import \"net\"\n\n" + result += ` + import ( + "context" + "net" + "strconv" + + "github.com/pkg/errors" + "go.uber.org/zap" + + "github.com/oringik/crypto-chateau/gen/conv" + "github.com/oringik/crypto-chateau/gen/hash" + "github.com/oringik/crypto-chateau/message" + "github.com/oringik/crypto-chateau/peer" + "github.com/oringik/crypto-chateau/server" + "github.com/oringik/crypto-chateau/transport" + ) +` } @@ -54,7 +73,7 @@ func fillPackage() { } func fillVersion() { - result += "// CODEGEN VERSION: " + CODEGEN_VERSION + "\n\n" + result += "// Code generated by crypto-chateau " + version.CodegenVersion + " DO NOT EDIT.\n\n" } func fillClients() { @@ -124,7 +143,7 @@ func fillClients() { } result += "error)" result += "{\n" - result += "\terr := c.peer.WriteResponse(\"" + method.Name + "\"," + method.Params[0].Name + ")\n\n" + result += "\terr := c.peer.WriteResponse(" + method.Hash.Code() + ", " + method.Params[0].Name + ")\n\n" if method.MethodType != ast2.Stream { result += fmt.Sprintf(`msg := make([]byte, 0, 1024) @@ -149,23 +168,21 @@ func fillClients() { msg = append(msg, buf...) } - _, n, err := conv.GetHandlerName(msg) + _, _, offset, err := conv.GetHandler(msg) if err != nil { return nil, err } - if n >= len(msg) { - return nil, errors.New("incorrect message") - } + respMsg := &%s{} - _, responseMsgParams, err := conv.GetParams(msg[n:]) - if err != nil { - return nil, err - } + //TODO: check if error is present - respMsg := &%s{} + // check if message has a size + if len(msg) < offset+conv.ObjectBytesPrefixLength { + return nil, errors.New("not enough for size and message") + } - err = respMsg.Unmarshal(responseMsgParams) + err = respMsg.Unmarshal(conv.NewBinaryIterator(msg[offset+conv.ObjectBytesPrefixLength:])) if err != nil { return nil, err } @@ -261,6 +278,18 @@ func fillServices() { } } +func fillHandlerHashMap() { + result += "var handlerHashMap = map[string]map[string]hash.HandlerHash{\n" + for _, service := range ast.Chateau.Services { + for _, method := range service.Methods { + result += "\t\"" + service.Name + "\":{\n" + result += "\t\t\"" + method.Name + "\":" + method.Hash.Code() + ",\n" + result += "\t},\n" + } + } + result += "}\n\n" +} + func fillSqueezes() { for _, service := range ast.Chateau.Services { for _, method := range service.Methods { @@ -289,110 +318,23 @@ func fillSqueezes() { } } -func fillObjects() { - for _, object := range ast.Chateau.ObjectDefinitions { - result += "type " + object.Name + " struct {\n" - for _, field := range object.Fields { - field.Name = strings.Title(field.Name) - field.Type.ObjectName = strings.Title(field.Type.ObjectName) - - result += "\t" + field.Name + " " - - if field.Type.IsArray { - result += "[" - if field.Type.ArrSize != 0 { - result += strconv.Itoa(field.Type.ArrSize) - } - result += "]" - } - - if field.Type.Type == ast2.Object { - result += "*" + field.Type.ObjectName + "\n" - } else { - result += ast2.AstTypeToGoType[field.Type.Type] + "\n" - } - } - result += "}\n\n" +func fillObjects() error { + ot, err := templates.NewObjectTemplate() + if err != nil { + return errors.Wrap(err, "failed to create object template") + } - // marshal - - result += "func (o *" + object.Name + ") Marshal() []byte {\n" - result += "\tvar buf []byte\n" - result += "\t" + `buf = append(buf, '{')` + "\n" - for i, field := range object.Fields { - convFunction := conv.ConvFunctionMarhsalByType(field.Type.Type) - result += fmt.Sprintf("\tvar result%s []byte\n", field.Name) - result += fmt.Sprintf("\t"+`result%s = append(result%s, []byte("%s:")...)`, field.Name, field.Name, field.Name) + "\n" - if field.Type.IsArray { - result += fmt.Sprintf("\t"+`result%s = append(result%s, '[')`, field.Name, field.Name) + "\n" - result += "\tfor i, val := range o." + field.Name + " {\n" - result += fmt.Sprintf("\t\tresult%s = append(result%s, conv.%s(val)...)\n", field.Name, field.Name, convFunction) - result += "\t\tif i != len(o." + field.Name + ") - 1 {\n" - result += "\t\t\t" + fmt.Sprintf(`result%s = append(result%s, ',')`, field.Name, field.Name) + "\n" - result += "\t\t}\n" - result += "\t}\n" - result += fmt.Sprintf("\t"+`result%s = append(result%s, ']')`, field.Name, field.Name) + "\n\n" - } else { - result += fmt.Sprintf("\tresult%s = append(result%s, conv.%s(o.%s)...)\n", field.Name, field.Name, convFunction, field.Name) - } - result += fmt.Sprintf("\tbuf = append(buf, result%s...)\n", field.Name) - if i != len(object.Fields)-1 { - result += "\tbuf = append(buf, ',')\n" - } + var genObject string + for _, object := range ast.Chateau.ObjectDefinitions { + genObject, err = ot.Gen(object) + if err != nil { + return errors.Wrap(err, "failed to generate object") } - result += "\t" + `buf = append(buf, '}')` + "\n" - result += "\treturn buf\n }\n\n" - - // unmarshal - - alreadyWasArraySymb := ":" - - result += "func (o *" + object.Name + ") Unmarshal(params map[string][]byte) error {\n" - for _, field := range object.Fields { - convFunction := conv.ConvFunctionUnmarshalByType(field.Type.Type) - if field.Type.Type == ast2.Object { - if field.Type.IsArray { - result += fmt.Sprintf("\t"+`_, arr, err %s= conv.GetArray(params["%s"])`+"\n", alreadyWasArraySymb, field.Name) - result += "\tif err != nil {\n\t\treturn err\n\t}\n" - result += "\tfor _, objBytes := range arr {\n" - result += "\t\tvar curObj *" + field.Type.ObjectName + "\n" - result += fmt.Sprintf("\t\t"+`conv.%s(curObj,objBytes)`+"\n", convFunction) - result += fmt.Sprintf("\t\to.%s = append(o.%s, curObj)\n", field.Name, field.Name) - result += "\t}\n" - - alreadyWasArraySymb = "" - } else { - result += fmt.Sprintf("\to.%s = &%s{}\n", field.Name, field.Type.ObjectName) - result += fmt.Sprintf("\t"+`conv.%s(o.%s,params["%s"])`+"\n", convFunction, field.Name, field.Name) - } - } else { - if field.Type.IsArray { - result += fmt.Sprintf("\t"+`_, arr, err %s= conv.GetArray(params["%s"])`+"\n", alreadyWasArraySymb, field.Name) - result += "\tif err != nil {\n\t\treturn err\n\t}\n" - var iOrMiss string - if field.Type.ArrSize != 0 { - iOrMiss = "i" - } else { - iOrMiss = "_" - } - result += fmt.Sprintf("\tfor %s, valByte := range arr {\n", iOrMiss) - if field.Type.ArrSize != 0 { - result += fmt.Sprintf("\t\to.%s[i] = conv.%s(valByte)\n", field.Name, convFunction) - } else { - result += "\t\tvar curVal " + ast2.AstTypeToGoType[field.Type.Type] + "\n" - result += fmt.Sprintf("\t\t"+`curVal = conv.%s(valByte)`+"\n", convFunction) - result += fmt.Sprintf("\t\to.%s = append(o.%s, curVal)\n", field.Name, field.Name) - } - result += "\t}\n" - alreadyWasArraySymb = "" - } else { - result += fmt.Sprintf("\t"+`o.%s = conv.%s(params["%s"])`+"\n", field.Name, convFunction, field.Name) - } - } - } - result += "\treturn nil\n}\n\n" + result += genObject } + + return nil } func fillGetHandlers() { @@ -408,8 +350,8 @@ func fillGetHandlers() { endpointArgs += "," } } - result += fmt.Sprintf("func GetHandlers(%s) map[string]*server.Handler {\n", endpointArgs) - result += "\thandlers := make(map[string]*server.Handler)\n\n" + result += fmt.Sprintf("func GetHandlers(%s) map[hash.HandlerHash]*server.Handler {\n", endpointArgs) + result += "\thandlers := make(map[hash.HandlerHash]*server.Handler)\n\n" for _, service := range ast.Chateau.Services { for _, method := range service.Methods { var methodType string @@ -427,13 +369,13 @@ func fillGetHandlers() { result += "\tif " + serviceNameLower + " != nil {\n" result += "\t\tcallFunc" + method.Name + "= " + method.Name + "Squeeze(" + serviceNameLower + "." + method.Name + ")\n" result += "\t}\n\n" - result += fmt.Sprintf("\t"+`handlers["%s"] = &server.Handler{ + result += fmt.Sprintf("\t"+`handlers[%s] = &server.Handler{ CallFunc%s: callFunc%s, HandlerType: %s, RequestMsgType: &%s{}, ResponseMsgType: &%s{}, Tags: tagsByHandlerName["%s"], - }`+"\n\n", method.Name, string(method.MethodType), method.Name, methodType, method.Params[0].Type.ObjectName, method.Returns[0].Type.ObjectName, method.Name) + }`+"\n\n", method.Hash.Code(), string(method.MethodType), method.Name, methodType, method.Params[0].Type.ObjectName, method.Returns[0].Type.ObjectName, method.Name) } } result += "\treturn handlers\n" @@ -453,8 +395,8 @@ func fillEmptyGetHandlers() { endpointArgs += "," } } - result += "func GetEmptyHandlers() map[string]*server.Handler {\n" - result += "\thandlers := make(map[string]*server.Handler)\n\n" + result += "func GetEmptyHandlers() map[hash.HandlerHash]*server.Handler {\n" + result += "\thandlers := make(map[hash.HandlerHash]*server.Handler)\n\n" for _, service := range ast.Chateau.Services { for _, method := range service.Methods { var methodType string @@ -468,11 +410,11 @@ func fillEmptyGetHandlers() { if len(service.Name) > 1 { serviceNameLower += service.Name[1:] } - result += fmt.Sprintf("\t"+`handlers["%s"] = &server.Handler{ + result += fmt.Sprintf("\t"+`handlers[%s] = &server.Handler{ HandlerType: %s, RequestMsgType: &%s{}, ResponseMsgType: &%s{}, - }`+"\n\n", method.Name, methodType, method.Params[0].Type.ObjectName, method.Returns[0].Type.ObjectName) + }`+"\n\n", method.Hash.Code(), methodType, method.Params[0].Type.ObjectName, method.Returns[0].Type.ObjectName) } } result += "\treturn handlers\n" diff --git a/gen/hash/handler.go b/gen/hash/handler.go new file mode 100644 index 0000000..3a1eab8 --- /dev/null +++ b/gen/hash/handler.go @@ -0,0 +1,19 @@ +package hash + +import ( + "crypto/sha256" + "fmt" +) + +type HandlerHash [4]byte + +func (h HandlerHash) Code() string { + return fmt.Sprintf("hash.HandlerHash{0x%X, 0x%X, 0x%X, 0%X}", h[0], h[1], h[2], h[3]) +} + +// GetHandlerHash returns first 4 bytes of sha256 hash of serviceName/handlerName +func GetHandlerHash(serviceName string, handlerName string) [4]byte { + hash := sha256.New().Sum([]byte(serviceName + "/" + handlerName)) + + return [4]byte{hash[0], hash[1], hash[2], hash[3]} +} diff --git a/gen/lexem/lexem.go b/gen/lexem/lexem.go index 7b8c0e2..4c0554b 100644 --- a/gen/lexem/lexem.go +++ b/gen/lexem/lexem.go @@ -20,7 +20,7 @@ const ( ObjectL ) -var typeIdentifiers = []string{"byte", "int8", "int32", "int64", "uint32", "uint64", "uint8", "uint16", "string", "bool", "object"} +var typeIdentifiers = []string{"byte", "int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64", "bool", "string", "object"} var pairTypes = map[string]LexemType{"service": ServiceL, ":": ColonL, "(": OpenParenL, ")": CloseParenL, ",": CommaL, "->": ReturnArrowL, "{": OpenBraceL, "}": CloseBraceL, "Handler": MethodL, "Stream": MethodL, "package": PackageL, "object": ObjectL} type Lexem struct { diff --git a/gen/templates/embed.go b/gen/templates/embed.go new file mode 100644 index 0000000..e6f4d84 --- /dev/null +++ b/gen/templates/embed.go @@ -0,0 +1,6 @@ +package templates + +import "embed" + +//go:embed object.go.tpl +var embFS embed.FS diff --git a/gen/templates/object.go b/gen/templates/object.go new file mode 100644 index 0000000..46c489a --- /dev/null +++ b/gen/templates/object.go @@ -0,0 +1,107 @@ +package templates + +import ( + "bytes" + "strconv" + "strings" + "text/template" + + "github.com/iancoleman/strcase" + "github.com/pkg/errors" + + "github.com/oringik/crypto-chateau/gen/ast" +) + +type ObjectTemplate struct { + tpl *template.Template +} + +func NewObjectTemplate() (*ObjectTemplate, error) { + tpl := template.New("object") + tpl = tpl.Funcs(objectTemplateFunc) + tpl, err := tpl.ParseFS(embFS, "object.go.tpl") + if err != nil { + return nil, errors.Wrap(err, "failed to parse object template") + } + + return &ObjectTemplate{ + tpl: tpl, + }, nil +} + +func (t *ObjectTemplate) Gen(definition *ast.ObjectDefinition) (string, error) { + b := bytes.NewBuffer(nil) + + err := t.tpl.ExecuteTemplate(b, "object.go.tpl", definition) + if err != nil { + return "", errors.Wrap(err, "failed to execute object template") + } + + return b.String(), nil +} + +var objectTemplateFunc = template.FuncMap{ + "mul": func(a, b int) int { return a * b }, + "dict": func(values ...interface{}) (map[string]interface{}, error) { + if len(values)%2 != 0 { + return nil, errors.New("invalid dict call") + } + dict := make(map[string]interface{}, len(values)/2) + for i := 0; i < len(values); i += 2 { + key, ok := values[i].(string) + if !ok { + return nil, errors.New("dict keys must be strings") + } + dict[key] = values[i+1] + } + return dict, nil + }, + "eqType": func(a ast.Type, b string) bool { + return strings.EqualFold(ast.AstTypeToGoType[a], b) + }, + "GoType": GoType, + "ToCamel": strcase.ToCamel, +} + +func GoType(t *ast.TypeLink, noArray ...bool) (string, error) { + var textType string + + switch t.Type { + case ast.Uint64: + textType = "uint64" + case ast.Uint32: + textType = "uint32" + case ast.Uint16: + textType = "uint16" + case ast.Uint8: + textType = "uint8" + case ast.Int64: + textType = "int64" + case ast.Int32: + textType = "int32" + case ast.Int16: + textType = "int16" + case ast.Int8: + textType = "int8" + case ast.Byte: + textType = "byte" + case ast.Bool: + textType = "bool" + case ast.String: + textType = "string" + case ast.Object: + textType = strcase.ToCamel(t.ObjectName) + default: + return "", errors.New("unknown type: " + strconv.Itoa(int(t.Type))) + } + + if !t.IsArray || (len(noArray) > 0 && noArray[0]) { + return textType, nil + } + + if t.ArrSize == 0 { + return "[]" + textType, nil + } + + return "[" + strconv.Itoa(t.ArrSize) + "]" + textType, nil +} diff --git a/gen/templates/object.go.tpl b/gen/templates/object.go.tpl new file mode 100644 index 0000000..f87c104 --- /dev/null +++ b/gen/templates/object.go.tpl @@ -0,0 +1,199 @@ +{{- /*gotype: github.com/oringik/crypto-chateau/gen/ast.ObjectDefinition*/ -}} + +type {{.Name | ToCamel}} struct { +{{- range .Fields}} + {{ .Name | ToCamel }} {{ .Type | GoType }} +{{- end}} +} + +var _ message.Message = (*{{.Name | ToCamel}})(nil) + +{{define "marshal"}} +{{- if eqType .Type.Type "uint64"}} + {{ .BufName }} = append({{ .BufName }},conv.ConvertUint64ToBytes({{ .InputVar }})...) +{{- else if eqType .Type.Type "uint32"}} + {{ .BufName }} = append({{ .BufName }},conv.ConvertUint32ToBytes({{ .InputVar }})...) +{{- else if eqType .Type.Type "uint16"}} + {{ .BufName }} = append({{ .BufName }},conv.ConvertUint16ToBytes({{ .InputVar }})...) +{{- else if eqType .Type.Type "uint8"}} + {{ .BufName }} = append({{ .BufName }},conv.ConvertUint8ToBytes({{ .InputVar }})...) +{{- else if eqType .Type.Type "int64"}} + {{ .BufName }} = append({{ .BufName }},conv.ConvertInt64ToBytes({{ .InputVar }})...) +{{- else if eqType .Type.Type "int32"}} + {{ .BufName }} = append({{ .BufName }},conv.ConvertInt32ToBytes({{ .InputVar }})...) +{{- else if eqType .Type.Type "int16"}} + {{ .BufName }} = append({{ .BufName }},conv.ConvertInt16ToBytes({{ .InputVar }})...) +{{- else if eqType .Type.Type "int8"}} + {{ .BufName }} = append({{ .BufName }},conv.ConvertInt8ToBytes({{ .InputVar }})...) +{{- else if eqType .Type.Type "byte"}} + {{ .BufName }} = append({{ .BufName }},conv.ConvertByteToBytes({{ .InputVar }})...) +{{- else if eqType .Type.Type "bool"}} + {{ .BufName }} = append({{ .BufName }},conv.ConvertBoolToBytes({{ .InputVar }})...) +{{- else if eqType .Type.Type "string"}} + {{ .BufName }} = append({{ .BufName }},conv.ConvertSizeToBytes(len([]byte({{ .InputVar }})))...) + {{ .BufName }} = append({{ .BufName }},conv.ConvertStringToBytes({{ .InputVar }})...) +{{- else }} + {{ .BufName }} = append({{ .BufName }},{{ .InputVar }}.Marshal()...) +{{- end}} +{{- end}} + +func (o *{{.Name | ToCamel}}) Marshal() []byte { + var ( + arrBuf []byte + {{- /* TODO: precalculate size based on static fields */}} + b = make([]byte, 0, {{ mul (len .Fields) 16 }}) + ) + + size := conv.ConvertSizeToBytes(0) + b = append(b, size...) + + {{- range .Fields}} + {{- if not .Type.IsArray}} + {{- template "marshal" dict "Type" .Type "Name" .Name "BufName" "b" "InputVar" (printf "o.%s" (.Name | ToCamel))}} + {{- else}} + arrBuf = make([]byte, 0, 128) + {{- $inputVar := printf "el%s" (.Name | ToCamel) }} + for _, {{$inputVar}} := range o.{{.Name | ToCamel}} { + {{- template "marshal" dict "Type" .Type "Name" .Name "BufName" "arrBuf" "InputVar" $inputVar }} + } + {{- /*TODO: check if buf size exceess max payload bytes size: max(uint32) */}} + b = append(b, conv.ConvertSizeToBytes(len(arrBuf))...) + b = append(b, arrBuf...) + {{- end}} + {{- end}} + + size = conv.ConvertSizeToBytes(len(b)-len(size)) + for i := 0; i < len(size); i++ { + b[i] = size[i] + } + + return b +} + +{{define "unmarshal"}} +{{if eqType .Type.Type "uint64"}} + binaryCtx.buf, binaryCtx.err = {{ .BufName }}.Slice(8) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read {{ .Name | ToCamel }}") + } + {{ .OutputVar }} = conv.ConvertBytesToUint64(binaryCtx.buf) +{{ else if eqType .Type.Type "uint32"}} + binaryCtx.buf, binaryCtx.err = {{ .BufName }}.Slice(4) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read {{ .Name | ToCamel }}") + } + {{ .OutputVar }} = conv.ConvertBytesToUint32(binaryCtx.buf) +{{ else if eqType .Type.Type "uint16"}} + binaryCtx.buf, binaryCtx.err = {{ .BufName }}.Slice(2) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read {{ .Name | ToCamel }}") + } + {{ .OutputVar }} = conv.ConvertBytesToUint16(binaryCtx.buf) +{{ else if eqType .Type.Type "uint8"}} + binaryCtx.buf, binaryCtx.err = {{ .BufName }}.Slice(1) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read {{ .Name | ToCamel }}") + } + {{ .OutputVar }} = conv.ConvertBytesToUint8(binaryCtx.buf) +{{- else if eqType .Type.Type "int64"}} + binaryCtx.buf, binaryCtx.err = {{ .BufName }}.Slice(8) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read {{ .Name | ToCamel }}") + } + {{ .OutputVar }} = conv.ConvertBytesToInt64(binaryCtx.buf) +{{ else if eqType .Type.Type "int32"}} + binaryCtx.buf, binaryCtx.err = {{ .BufName }}.Slice(4) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read {{ .Name | ToCamel }}") + } + {{ .OutputVar }} = conv.ConvertBytesToInt32(binaryCtx.buf) +{{ else if eqType .Type.Type "int16"}} + binaryCtx.buf, binaryCtx.err = {{ .BufName }}.Slice(2) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read {{ .Name | ToCamel }}") + } + {{ .OutputVar }} = conv.ConvertBytesToInt16(binaryCtx.buf) +{{ else if eqType .Type.Type "int8"}} + binaryCtx.buf, binaryCtx.err = {{ .BufName }}.Slice(1) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read {{ .Name | ToCamel }}") + } + {{ .OutputVar }} = conv.ConvertBytesToInt8(binaryCtx.buf) +{{ else if eqType .Type.Type "byte"}} + binaryCtx.buf, binaryCtx.err = {{ .BufName }}.Slice(1) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read {{ .Name | ToCamel }}") + } + {{ .OutputVar }} = conv.ConvertBytesToByte(binaryCtx.buf) +{{ else if eqType .Type.Type "bool"}} + binaryCtx.buf, binaryCtx.err = {{ .BufName }}.Slice(1) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read {{ .Name | ToCamel }}") + } + {{ .OutputVar }} = conv.ConvertBytesToBool(binaryCtx.buf) +{{ else if eqType .Type.Type "string"}} + binaryCtx.size, binaryCtx.err = {{ .BufName }}.NextSize() + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read {{ .Name | ToCamel }} size") + } + binaryCtx.buf, binaryCtx.err = {{ .BufName }}.Slice(binaryCtx.size) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read {{ .Name | ToCamel }}") + } + {{ .OutputVar }} = conv.ConvertBytesToString(binaryCtx.buf) +{{ else if eqType .Type.Type "object" }} + binaryCtx.size, binaryCtx.err = {{ .BufName }}.NextSize() + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read {{ .Name | ToCamel }} size") + } + binaryCtx.buf, binaryCtx.err = {{ .BufName }}.Slice(binaryCtx.size) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read {{ .Name | ToCamel }}") + } + if binaryCtx.err = {{ .OutputVar }}.Unmarshal(binaryCtx.buf); binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to unmarshal {{ .Name | ToCamel }}") + } +{{ end }} +{{ end}} + +func (o *{{.Name | ToCamel}}) Unmarshal(b *conv.BinaryIterator) error { + binaryCtx := struct { + err error + size, arrSize, pos int + buf, arrBuf *conv.BinaryIterator + }{} + + binaryCtx.err = nil + + {{- range .Fields}} + {{- if not .Type.IsArray}} + {{ $outputVar := printf "o.%s" (.Name | ToCamel) }} + {{- template "unmarshal" dict "Type" .Type "Name" .Name "BufName" "b" "OutputVar" $outputVar}} + {{- else}} + binaryCtx.size, binaryCtx.err = b.NextSize() + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read {{.Name | ToCamel}} size") + } + binaryCtx.arrBuf, binaryCtx.err = b.Slice(binaryCtx.size) + if binaryCtx.err != nil { + return errors.Wrap(binaryCtx.err, "failed to read {{.Name | ToCamel}}") + } + {{- if not (eq .Type.ArrSize 0) }} + binaryCtx.pos = 0 + {{- end}} + for binaryCtx.arrBuf.HasNext() { + {{- $outputVar := printf "el%s" (.Name | ToCamel) -}} + var {{$outputVar}} {{ GoType .Type true }} + {{- template "unmarshal" dict "Type" .Type "Name" .Name "BufName" "binaryCtx.arrBuf" "OutputVar" $outputVar }} + {{if eq .Type.ArrSize 0 -}} + o.{{.Name | ToCamel}} = append(o.{{.Name | ToCamel}}, {{$outputVar}}) + {{ else -}} + o.{{.Name | ToCamel}}[binaryCtx.pos] = {{$outputVar}} + binaryCtx.pos++ + {{- end}} + } + {{- end}} + {{- end}} + + return nil +} diff --git a/gen/templates/object_test.go b/gen/templates/object_test.go new file mode 100644 index 0000000..73041fe --- /dev/null +++ b/gen/templates/object_test.go @@ -0,0 +1,60 @@ +package templates + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/oringik/crypto-chateau/gen/ast" +) + +const exceptedObjectCode = `` + +func TestObjectTemplate_Gen(t *testing.T) { + ot, err := NewObjectTemplate() + + require.NoError(t, err, "failed to create object template") + + def := &ast.ObjectDefinition{ + Name: "MagicRequest", + Fields: []*ast.Field{ + { + Name: "MagicString", + Type: ast.TypeLink{ + Type: ast.String, + }, + }, + { + Name: "MagicUint32", + Type: ast.TypeLink{ + Type: ast.Uint32, + }, + }, + { + Name: "MagicBytes", + Type: ast.TypeLink{ + Type: ast.Byte, + IsArray: true, + ArrSize: 16, + }, + }, + { + Name: "MagicObjects", + Type: ast.TypeLink{ + Type: ast.Object, + ObjectName: "WonderObject", + IsArray: true, + }, + }, + }, + } + + code, err := ot.Gen(def) + + require.NoError(t, err, "failed to generate code") + + //err = os.WriteFile("object_example.go", []byte(code), 0644) + //require.NoError(t, err, "failed to save generated code") + + require.Equal(t, exceptedObjectCode, code, "generated code is not as expected") +} diff --git a/go.mod b/go.mod index 673f7f2..e71ca1f 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/oringik/crypto-chateau go 1.18 require ( + github.com/pkg/errors v0.8.1 github.com/stretchr/testify v1.8.1 go.uber.org/zap v1.23.0 golang.org/x/crypto v0.0.0-20221005025214-4161e89ecf1b @@ -10,8 +11,10 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/iancoleman/strcase v0.2.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect go.uber.org/atomic v1.7.0 // indirect go.uber.org/multierr v1.6.0 // indirect + golang.org/x/text v0.5.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index ce74fd3..91b0fed 100644 --- a/go.sum +++ b/go.sum @@ -2,7 +2,10 @@ github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLj github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/iancoleman/strcase v0.2.0 h1:05I4QRnGpI0m37iZQRuskXh+w77mr6Z41lwQzuHLwW0= +github.com/iancoleman/strcase v0.2.0/go.mod h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47ZCWhYzw7ho= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -22,6 +25,8 @@ go.uber.org/zap v1.23.0 h1:OjGQ5KQDEUawVHxNwQgPpiypGHOxo2mNZsOqTak4fFY= go.uber.org/zap v1.23.0/go.mod h1:D+nX8jyLsMHMYrln8A0rJjFt/T/9/bGgIhAqxv5URuY= golang.org/x/crypto v0.0.0-20221005025214-4161e89ecf1b h1:huxqepDufQpLLIRXiVkTvnxrzJlpwmIWAObmcCcUFr0= golang.org/x/crypto v0.0.0-20221005025214-4161e89ecf1b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/text v0.5.0 h1:OLmvp0KP+FVG99Ct/qFiL/Fhk4zp4QQnZ7b2U+5piUM= +golang.org/x/text v0.5.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/message/message.go b/message/message.go index 4e8490c..e8e8d4e 100644 --- a/message/message.go +++ b/message/message.go @@ -1,6 +1,8 @@ package message +import "github.com/oringik/crypto-chateau/gen/conv" + type Message interface { Marshal() []byte - Unmarshal(map[string][]byte) error + Unmarshal(iterator *conv.BinaryIterator) error } diff --git a/peer/peer.go b/peer/peer.go index 396d013..b3cf3d6 100644 --- a/peer/peer.go +++ b/peer/peer.go @@ -1,10 +1,18 @@ package peer import ( + "errors" "fmt" + "net" + "github.com/oringik/crypto-chateau/gen/conv" + "github.com/oringik/crypto-chateau/gen/hash" "github.com/oringik/crypto-chateau/message" - "net" + "github.com/oringik/crypto-chateau/version" +) + +var ( + ErrBytesPrefix = [2]byte{0x2F, 0x20} ) type Peer struct { @@ -17,10 +25,11 @@ func NewPeer(conn net.Conn) *Peer { } } -func (p *Peer) WriteResponse(handlerName string, msg message.Message) error { +func (p *Peer) WriteResponse(handlerName hash.HandlerHash, msg message.Message) error { var resp []byte - resp = append(resp, []byte(handlerName+"#")...) + resp = append(resp, version.NewProtocolByte()) + resp = append(resp, handlerName[:]...) resp = append(resp, msg.Marshal()...) _, err := p.Conn.Write(resp) @@ -31,41 +40,25 @@ func (p *Peer) ReadMessage(msg message.Message) error { var msgRaw []byte _, err := p.Conn.Read(msgRaw) - if err != nil { - return err - } - _, n, err := conv.GetHandlerName(msgRaw) - if err != nil { - return err - } - - _, reqMsgParams, err := conv.GetParams(msgRaw[n:]) - if err != nil { - return err + if err != nil { + return fmt.Errorf("failed to read from connection: %w", err) } - - err = msg.Unmarshal(reqMsgParams) + _, _, offset, err := conv.GetHandler(msgRaw) if err != nil { - return err + return fmt.Errorf("failed get handler key: %w", err) } - return err -} - -func (p *Peer) ReadMessageClient(msg message.Message) error { - var msgRaw []byte - - _, err := p.Conn.Read(msgRaw) - if err != nil { - return err + // check if error prefix is present + if msgRaw[offset] == ErrBytesPrefix[0] && msgRaw[offset+1] == ErrBytesPrefix[1] { + return fmt.Errorf("chateau rpc: status = error, description = %s", string(msgRaw[offset+2:])) } - _, reqMsgParams, err := conv.GetParams(msgRaw) - if err != nil { - return err + // check if message has a size + if len(msgRaw) < offset+conv.ObjectBytesPrefixLength { + return errors.New("not enough for size and message") } - err = msg.Unmarshal(reqMsgParams) + err = msg.Unmarshal(conv.NewBinaryIterator(msgRaw[offset+conv.ObjectBytesPrefixLength:])) if err != nil { return err } @@ -73,10 +66,15 @@ func (p *Peer) ReadMessageClient(msg message.Message) error { return err } -func (p *Peer) WriteError(handlerName string, err error) error { - msg := fmt.Sprintf("%s# error: %s", handlerName, err.Error()) +func (p *Peer) WriteError(handlerKey hash.HandlerHash, err error) error { + var resp []byte + + resp = append(resp, version.NewProtocolByte()) + resp = append(resp, handlerKey[:]...) + resp = append(resp, ErrBytesPrefix[:]...) + resp = append(resp, []byte(err.Error())...) - _, writeErr := p.Conn.Write([]byte(msg)) + _, writeErr := p.Conn.Write(resp) return writeErr } diff --git a/server/server.go b/server/server.go index dac52b8..9da5b26 100644 --- a/server/server.go +++ b/server/server.go @@ -4,15 +4,18 @@ import ( "context" "errors" "fmt" - "github.com/oringik/crypto-chateau/gen/conv" - "github.com/oringik/crypto-chateau/message" - "github.com/oringik/crypto-chateau/peer" - "github.com/oringik/crypto-chateau/transport" - "go.uber.org/zap" "net" "strconv" "sync" "time" + + "go.uber.org/zap" + + "github.com/oringik/crypto-chateau/gen/conv" + "github.com/oringik/crypto-chateau/gen/hash" + "github.com/oringik/crypto-chateau/message" + "github.com/oringik/crypto-chateau/peer" + "github.com/oringik/crypto-chateau/transport" ) type HandlerFunc func(context.Context, message.Message) (message.Message, error) @@ -34,7 +37,7 @@ type Handler struct { type Server struct { Config *Config - Handlers map[string]*Handler + Handlers map[hash.HandlerHash]*Handler // key: ip address value: client peer Clients map[string]*peer.Peer shutdownCh chan error @@ -49,7 +52,7 @@ type Config struct { ConnWriteDeadline *time.Duration } -func NewServer(cfg *Config, logger *zap.Logger, handlers map[string]*Handler) *Server { +func NewServer(cfg *Config, logger *zap.Logger, handlers map[hash.HandlerHash]*Handler) *Server { return &Server{ Config: cfg, Handlers: handlers, @@ -145,28 +148,24 @@ func (s *Server) handleMethod(ctx context.Context, peer *peer.Peer) error { msg = append(msg, buf...) } - handlerName, n, err := conv.GetHandlerName(msg) + _, handlerKey, offset, err := conv.GetHandler(msg) if err != nil { return err } - handler, ok := s.Handlers[string(handlerName)] + handler, ok := s.Handlers[handlerKey] if !ok { - return errors.New("unknown handler " + string(handlerName)) + return errors.New(fmt.Sprintf("handler not found for key: %v", handlerKey)) } - if n >= len(msg) { - return errors.New("incorrect message") - } - - _, reqMsgParams, err := conv.GetParams(msg[n:]) - if err != nil { - return err + // check if message has a size + if len(msg) < offset+conv.ObjectBytesPrefixLength { + return errors.New("not enough bytes for size and message") } requestMsg := handler.RequestMsgType - err = requestMsg.Unmarshal(reqMsgParams) + err = requestMsg.Unmarshal(conv.NewBinaryIterator(msg[offset+conv.ObjectBytesPrefixLength:])) if err != nil { return err } @@ -175,16 +174,16 @@ func (s *Server) handleMethod(ctx context.Context, peer *peer.Peer) error { case HandlerT: responseMessage, err := handler.CallFuncHandler(ctx, requestMsg) if err != nil { - writeErr := peer.WriteError(string(handlerName), err) + writeErr := peer.WriteError(handlerKey, err) return writeErr } - err = peer.WriteResponse(string(handlerName), responseMessage) + err = peer.WriteResponse(handlerKey, responseMessage) if err != nil { return err } - if val, ok2 := s.Handlers[string(handlerName)].Tags["keep_conn_alive"]; !ok2 || val != "true" { + if val, ok2 := s.Handlers[handlerKey].Tags["keep_conn_alive"]; !ok2 || val != "true" { err = peer.Close() if err != nil { return err @@ -194,7 +193,7 @@ func (s *Server) handleMethod(ctx context.Context, peer *peer.Peer) error { go func() { err = handler.CallFuncStream(ctx, peer, requestMsg) if err != nil { - writeErr := peer.WriteError(string(handlerName), err) + writeErr := peer.WriteError(handlerKey, err) if writeErr != nil { fmt.Println(writeErr) } @@ -202,7 +201,7 @@ func (s *Server) handleMethod(ctx context.Context, peer *peer.Peer) error { return } - if val, ok2 := s.Handlers[string(handlerName)].Tags["keep_conn_alive"]; !ok2 || val != "true" { + if val, ok2 := s.Handlers[handlerKey].Tags["keep_conn_alive"]; !ok2 || val != "true" { err = peer.Close() if err != nil { fmt.Println(err) diff --git a/version/version.go b/version/version.go new file mode 100644 index 0000000..4f74ab2 --- /dev/null +++ b/version/version.go @@ -0,0 +1,10 @@ +package version + +const ( + ProtocolVersion byte = 0b00000001 // 4bits: min 0 max 7 + CodegenVersion = "1.0.0" // TODO: get from git on build +) + +func NewProtocolByte() byte { + return ProtocolVersion | (0b11110000) // first 4 bits reserved +}