From 12c0de92f9b698daea7fae48682f6fa20b04ac7f Mon Sep 17 00:00:00 2001 From: zachshattuck Date: Sun, 22 Sep 2024 18:11:41 -0600 Subject: [PATCH] Initial commit --- .github/workflows/go.yml | 28 +++++ README.md | 76 ++++++++++++ go.mod | 3 + http.go | 119 +++++++++++++++++++ http_test.go | 142 ++++++++++++++++++++++ websocket.go | 249 +++++++++++++++++++++++++++++++++++++++ websocket_test.go | 104 ++++++++++++++++ 7 files changed, 721 insertions(+) create mode 100644 .github/workflows/go.yml create mode 100644 README.md create mode 100644 go.mod create mode 100644 http.go create mode 100644 http_test.go create mode 100644 websocket.go create mode 100644 websocket_test.go diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml new file mode 100644 index 0000000..b8ad752 --- /dev/null +++ b/.github/workflows/go.yml @@ -0,0 +1,28 @@ +# This workflow will build a golang project +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go + +name: Go + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +jobs: + + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: '1.23.1' + + - name: Build + run: go build -v ./... + + - name: Test + run: go test diff --git a/README.md b/README.md new file mode 100644 index 0000000..281c522 --- /dev/null +++ b/README.md @@ -0,0 +1,76 @@ +# `gows` +Simple WebSocket ([RFC6455](https://datatracker.ietf.org/doc/rfc6455/)) library in Go + +**What it offers:** +- Simple way to upgrade an HTTP connection to a WebSocket connection +- Simple way to serialize and deserialize individual WebSocket frames + +**What it doesn't do:** +- Doesn't handle message fragmentation, but you can do that yourself by reading `Fin` and `Opcode`. For more information, see section 5.4 of [RFC6455](https://datatracker.ietf.org/doc/rfc6455/) +- Doesn't automatically respond to PING frames. + +## Installation +`go get github.com/zachshattuck/gows` + +## Example Usage +```go +import ( + "github.com/zachshattuck/gows" +) + +func main() { + + ln, err := net.Listen("tcp", "127.0.0.1:8080") + if err != nil { + fmt.Println("Failed to `net.Listen`: ", err) + os.Exit(1) + } + + conn, err := ln.Accept() + if err != nil { + fmt.Println("Failed to `Accept` connection: ", err) + os.Exit(1) + } + + // Will `Read` from the connection and send a `101 Switching Protocols` response + // if valid, otherwise sends a `400 Bad Request` response. + err := gows.UpgradeConnection(&conn, buf) + if err != nil { + fmt.Fprintln(os.Stderr, "Failed to upgrade: ", err) + os.Exit(1) + } + + // Listen for WebSocket frames + for { + n, err := conn.Read(buf) + if err != nil { + fmt.Println("Failed to read: ", err) + break + } + + frame, err := gows.DeserializeWebSocketFrame(buf[:n]) + if err != nil { + fmt.Fprintln(os.Stderr, "Failed to deserialize frame: ", err) + continue + } + + switch frame.Opcode { + case gows.WS_OP_TEXT: // Handle text frame.. + case gows.WS_OP_BIN: // Handle binary frame.. + case gows.WS_OP_PING: + fmt.Println("Ping frame, responding with pong...") + pongFrame := gows.SerializeWebSocketFrame(gows.WebSocketFrame{ + Fin: 1, + Rsv1: 0, Rsv2: 0, Rsv3: 0, + Opcode: gows.WS_OP_PONG, + IsMasked: 0, + MaskKey: [4]byte{}, + Payload: frame.Payload, + }) + conn.Write(pongFrame) + } + + } + +} +``` \ No newline at end of file diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..390fccb --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/zachshattuck/gows + +go 1.23.1 diff --git a/http.go b/http.go new file mode 100644 index 0000000..3c62e4a --- /dev/null +++ b/http.go @@ -0,0 +1,119 @@ +package gows + +import ( + "errors" + "net" +) + +/* +Given a param name and a buffer expected to be a valid HTTP request, this function +will return a slice containing the value of that HTTP param, if it is found. +*/ +func getHttpParam(buf []byte, paramName string) ([]byte, error) { + + // Read until we match `paramName` completely, NOT including the ":" + var correctByteCount int = 0 + var valueStartIdx int + for i, b := range buf { + if b != paramName[correctByteCount] { + correctByteCount = 0 + continue + } + + // Previous character has to be start of buffer or '\n' (as part of CRLF) + // NOTE: If the user provided a slice that was partway through a request, this could + // produce wrong results. For example, if there were two params, "Test-Param1: {value}" + // and "Param1: {value}", and the slice started at the 'P' in "Test-Param1", it could + // extract that value as if it was just "Param1". + if correctByteCount == 0 && !(i == 0 || buf[i-1] == '\n') { + correctByteCount = 0 + continue + } + + correctByteCount++ + + if correctByteCount < len(paramName) { + continue + } + + // Following character has to be ":" + if i >= len(buf)-2 || buf[i+1] != ':' { + correctByteCount = 0 + continue + } + + // we found the whole param! + valueStartIdx = i + 2 + break + } + + if correctByteCount < len(paramName) { + return nil, errors.New("param \"" + string(paramName) + "\" not found in buffer") + } + if valueStartIdx >= len(buf)-1 { + return nil, errors.New("nothing in buffer after \"" + string(paramName) + ":\"") + } + + // Read all whitespace + for { + if buf[valueStartIdx] != ' ' { + break + } + valueStartIdx++ + } + + // Read until CRLF + return readUntilCrlf(buf[valueStartIdx:]) +} + +/* Reads from start of slice until CRLF. If no CRLF is found, it will return an error instead of the value so far. */ +func readUntilCrlf(buf []byte) ([]byte, error) { + lastTokenIdx := -1337 + for i, b := range buf { + if b == '\r' { + lastTokenIdx = i + } else if b == '\n' { + if lastTokenIdx == i-1 { + return buf[:lastTokenIdx], nil + } + } + } + + // we never found a valid CRLF + return nil, errors.New("no CRLF found") +} + +func isValidUpgradeRequest(buf []byte) (bool, error) { + // TODO: This doesn't verify a valid HTTP verb at all + + // _, err := GetHttpParam(buf, "Host") + // if err != nil { + // return false, err + // } + + httpConnection, err := getHttpParam(buf, "Connection") + if err != nil || (string(httpConnection) != "Upgrade" && string(httpConnection) != "upgrade") { + return false, errors.New("invalid or nonexistent \"Connection\" param") + } + + httpUpgrade, err := getHttpParam(buf, "Upgrade") + if err != nil || string(httpUpgrade) != "websocket" { + return false, errors.New("invalid or nonexistent \"Upgrade\" param") + } + + httpWebSocketVersion, err := getHttpParam(buf, "Sec-WebSocket-Version") + if err != nil || string(httpWebSocketVersion) != "13" { + return false, errors.New("invalid or nonexistent \"Sec-WebSocket-Version\" param") + } + + _, err = getHttpParam(buf, "Sec-WebSocket-Key") + if err != nil { + return false, errors.New("invalid or nonexistent \"Sec-WebSocket-Key\" param") + } + + return true, nil +} + +func sendBadRequestResponse(conn *net.Conn) (int, error) { + return (*conn).Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) +} diff --git a/http_test.go b/http_test.go new file mode 100644 index 0000000..b929921 --- /dev/null +++ b/http_test.go @@ -0,0 +1,142 @@ +package gows + +import "testing" + +/* Example WebSocket upgrade request, ripped straight from my browser. */ +var exampleHttpRequest = []byte("GET / HTTP/1.1\r\nHost: 127.0.0.1:8081\r\nConnection: Upgrade\r\nPragma: no-cache\r\nCache-Control: no-cache\r\nUser-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36\r\nUpgrade: websocket\r\nOrigin: http://localhost:8080\r\nSec-WebSocket-Version: 13\r\nAccept-Encoding: gzip, deflate, br\r\nAccept-Language: en-US,en;q=0.9\r\nSec-WebSocket-Key: D8KfDxohPIack4T9PAf3Ng==\r\nSec-WebSocket-Extensions: permessage-deflate; client_max_window_bits\r\n\r\n") + +/* A longer WebSocket upgrade request, proxied by nginx. */ +var exampleHttpRequest2 = []byte("GET /ws HTTP/1.1\r\nUpgrade: websocket\r\nConnection: upgrade\r\nHost: 127.0.0.1:8081\r\naccept-encoding: gzip, br\r\nX-Forwarded-For: 1.2.3.4\r\nCF-RAY: 8c3d6a50b90875c8-SEA\r\nX-Forwarded-Proto: https\r\nCF-Visitor: {\"scheme\":\"https\"}\r\nPragma: no-cache\r\nCache-Control: no-cache\r\nUser-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36\r\nOrigin: https://www.website.com\r\nSec-WebSocket-Version: 13\r\nAccept-Language: en-US,en;q=0.9\r\nSec-WebSocket-Key: ZFPbTE+Wekp3z+QNUR4R0Q==\r\nSec-WebSocket-Extensions: permessage-deflate; client_max_window_bits\r\nCF-Connecting-IP: 1.2.3.4\r\ncdn-loop: cloudflare; loops=1\r\nCF-IPCountry: US\r\n\r\n") + +func TestGetHttpParamValidProperty(t *testing.T) { + got, err := getHttpParam(exampleHttpRequest, "Host") + want := "127.0.0.1:8081" + + if err != nil { + t.Error("error:", err) + } + if string(got) != want { + t.Errorf("got %q, wanted %q", got, want) + } +} + +func TestGetHttpParamInvalidProperty(t *testing.T) { + got, err := getHttpParam(exampleHttpRequest, " super duper nonsensical parameter!!!! 89740r3n3yr0932") + + if err == nil { + t.Errorf("did not error, got %q", got) + } +} + +func TestGetHttpParamDuplicateKeyword(t *testing.T) { + got, err := getHttpParam(exampleHttpRequest, "Upgrade") + want := "websocket" + + if err != nil { + t.Error("error:", err) + } + if string(got) != want { + t.Errorf("got %q, wanted %q", got, want) + } +} + +func TestGetHttpParamExtremelyLongParameter(t *testing.T) { + got, err := getHttpParam(exampleHttpRequest, "sd09 fus8-d90f js09df mus90d8f mu09sd8fy um90s8d ynf098sd7f n908sd 7fn90s8d7fn 908sd7n f09s8df7n 908sd7f 098sd7nf 098sd7nf 098sd7nf 098sd7fn 098sd7nf 098sd7fn 098sd7f 098sdf7 n09s8df7n 098sdf7n 09s8df7 n09s8df 709s8df7 n09s8df7 098sdf7 098sdf yunoiusdf hlksjdfh klsjdfkjsdhfkj sdhjflksdjf lksdj f098sd7f 908sduf iujsdhf kjshdf kjysud9f8 7sd98f sdkjf h,sjdhf kjsdfy 98sdf iusdnf kjsdhf kiusdyf 98sdyfi uhsdifu ysd98f sd98f jsd98f jsd9f j9sd8f j9s8df hisudfh lkjsdhf8sdy f98sdhf iujsdhf iousdyuf 98sdhf oijsdhf likudsfyg s98ydfgisu hdfsiog hsdf98g y9sd8fgjh s9d8fg u9isd8fgy 0987sdfg yhioudsfhg oisudfgh o87sdfhg 9sdfgy h098sdfhg isdufhg 98sdfh g9087sdfhg iosdufhg osjkdfhg lkjdsfh giusdfug98dsfgu g9p8sdfjg ;lksdfj g") + want := "" + + if err == nil || string(got) != want { + t.Errorf("got %q, wanted %q", got, want) + } +} +func TestGetHttpParamSelf(t *testing.T) { + got, err := getHttpParam(exampleHttpRequest, "GET / HTTP/1.1\r\nHost: 127.0.0.1:8081\r\nConnection: Upgrade\r\nPragma: no-cache\r\nCache-Control: no-cache\r\nUser-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36\r\nUpgrade: websocket\r\nOrigin: http://localhost:8080\r\nSec-WebSocket-Version: 13\r\nAccept-Encoding: gzip, deflate, br\r\nAccept-Language: en-US,en;q=0.9\r\nSec-WebSocket-Key: D8KfDxohPIack4T9PAf3Ng==\r\nSec-WebSocket-Extensions: permessage-deflate; client_max_window_bits\r\n\r\n") + + if err == nil { + t.Errorf("did not error, got %q", got) + } +} + +func TestReadUntilCrlfHttpVerb(t *testing.T) { + got, err := readUntilCrlf(exampleHttpRequest) + want := "GET / HTTP/1.1" + + if err != nil { + t.Error("error:", err) + } + if string(got) != want { + t.Errorf("got %q, wanted %q", got, want) + } +} + +func TestReadUntilCrlfHttpParamValue(t *testing.T) { + got, err := readUntilCrlf(exampleHttpRequest[22:]) + want := "127.0.0.1:8081" + + if err != nil { + t.Error("error:", err) + } + if string(got) != want { + t.Errorf("got %q, wanted %q", got, want) + } +} + +func TestReadUntilCrlfRandomNoCrlf(t *testing.T) { + got, err := readUntilCrlf([]byte("This is a nice string and all, but it doesn't have a Crlf.")) + + if err == nil { + t.Errorf("did not error, got %q", got) + } +} + +func TestReadUntilCrlfRandomWithCrlf(t *testing.T) { + got, err := readUntilCrlf([]byte("This is a nice string and all, AND it has a Crlf.\r\n")) + want := "This is a nice string and all, AND it has a Crlf." + + if err != nil { + t.Error("error:", err) + } + if string(got) != want { + t.Errorf("got %q, wanted %q", got, want) + } +} + +func TestReadUntilCrlfCursed1(t *testing.T) { + got, err := readUntilCrlf([]byte("\r \n\n\n\n \n\n\n\n\r\n\n\n\n\n\n\n")) + want := "\r \n\n\n\n \n\n\n\n" + + if err != nil { + t.Error("error:", err) + } + if string(got) != want { + t.Errorf("got %q, wanted %q", got, want) + } +} +func TestReadUntilCrlfCursed2(t *testing.T) { + got, err := readUntilCrlf([]byte("\r\r\r\r\r\r\r\r \nr\n\n\n \n\n\n\n\n\n\n\n\n\n\n")) + + if err == nil { + t.Errorf("did not error, got %q", got) + } +} + +func TestIsValidUpgradeRequestBasicGood(t *testing.T) { + got, err := isValidUpgradeRequest(exampleHttpRequest) + + if err != nil { + t.Error("error:", err) + } + if got == false { + t.Errorf("got invalid, expected valid") + } +} + +func TestIsValidUpgradeRequestLongGood(t *testing.T) { + got, err := isValidUpgradeRequest(exampleHttpRequest2) + + if err != nil { + t.Error("error:", err) + } + if got == false { + t.Errorf("got invalid, expected valid") + } +} diff --git a/websocket.go b/websocket.go new file mode 100644 index 0000000..e487eaa --- /dev/null +++ b/websocket.go @@ -0,0 +1,249 @@ +package gows + +import ( + "crypto/sha1" + "encoding/base64" + "encoding/binary" + "errors" + "fmt" + "math" + "net" + "os" +) + +// https://datatracker.ietf.org/doc/rfc6455 + +/* -> request +REQUIRED: +- verb must be GET +- HTTP version must be at least 1.1 +- "Host" +- "Upgrade: websocket" +- "Connection: Upgrade" +- "Sec-WebSocket-Key" +- "Sec-WebSocket-Version: 13" (must be 13?) +- "Origin" if from browser +- +OPTIONAL: +- "Sec-WebSocket-Protocol" +- "Sec-WebSocket-Extensions" +- +*/ + +/* <- response +REQUIRED: +- "HTTP{version} 101 Switching Protocols" +- "Upgrade: websocket" +- "Connection: Upgrade" +- "Sec-WebSocket-Accept: {base64(sha1(key + magicString))}" + +OPTIONAL: +- "Sec-WebSocket-Protocol" +- "Sec-WebSocket-Extensions" +*/ + +/* <-> framing +client must always mask messages sent, if not server sends back opcode 1002 +server never masks + +*/ + +const WS_OP_CONT = 0x0 +const WS_OP_TEXT = 0x1 +const WS_OP_BIN = 0x2 + +// 3-7: reserved for future non-control frames + +const WS_OP_CLOSE = 0x8 +const WS_OP_PING = 0x9 +const WS_OP_PONG = 0xA + +// B-F reserved for future control frames + +const webSocketMagicString = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + +/* +Reads the next message and tries to interpret it as a WebSocket upgrade request. +If invalid, will send a 400 Bad Request response. +If valid, will send a 101 Switching Protocols response. +*/ +func UpgradeConnection(conn *net.Conn, buf []byte) error { + n, err := (*conn).Read(buf) + if err != nil { + return errors.New("failed to read from connection") + } + + isValid, err := isValidUpgradeRequest(buf[:n]) + if err != nil || !isValid { + sendBadRequestResponse(conn) + return err + } + + httpWebSocketKey, err := getHttpParam(buf[:n], "Sec-WebSocket-Key") + if err != nil { + sendBadRequestResponse(conn) + return err + } + + // https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API/Writing_WebSocket_servers#server_handshake_response + sha1Checksum := sha1.Sum([]byte(string(httpWebSocketKey) + webSocketMagicString)) + httpWebSocketAccept := base64.StdEncoding.EncodeToString(sha1Checksum[:]) + + // TODO: This should respond with the same HTTP and Sec-WebSocket-Version that the client set, instead of a hardcoded version + _, err = (*conn).Write([]byte("HTTP/1.1 101 Switching Protocols\r\nSec-WebSocket-Version: 13\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: " + httpWebSocketAccept + "\r\n\r\n")) + if err != nil { + fmt.Fprintln(os.Stderr, "failed to write upgrade response:", err) + return err + } + + return nil +} + +type WebSocketFrame struct { + Fin byte + Rsv1 byte + Rsv2 byte + Rsv3 byte + Opcode uint8 + IsMasked byte + MaskKey [4]byte + Payload []byte +} + +func DeserializeWebSocketFrame(buf []byte) (WebSocketFrame, error) { + + // cannot be a valid websocket frame without 2 bytes + if len(buf) < 2 { + return WebSocketFrame{}, errors.New("frame less than 2 bytes") + } + // fmt.Println("length of WebSocket frame:", len(buf)) + // byte 1: FIN, RSV1-3, OPCODE + // byte 2: mask, 7-bit length + + byte1 := buf[0] + fin := byte1 >> 7 + rsv1 := (byte1 & 0x40) >> 6 + rsv2 := (byte1 & 0x20) >> 5 + rsv3 := (byte1 & 0x10) >> 4 + opcode := byte1 & 0x0F + + byte2 := buf[1] + isMasked := byte2 >> 7 + payloadLen := uint64(byte2 & 0x7F) + + maskKeyStartIdx := 2 + + // Calculate extended payload length... + if payloadLen == 126 { + if len(buf) < 4 { + return WebSocketFrame{}, errors.New("frame not long enough to read extended payload length") + } + s := buf[2:4] + payloadLen = (uint64(s[0]) << 8) | uint64(s[1]) + maskKeyStartIdx = 4 + } else if payloadLen == 127 { + if len(buf) < 10 { + return WebSocketFrame{}, errors.New("frame not long enough to read extended payload length") + } + s := buf[2:10] + payloadLen = (uint64(s[0]) << 56) | + (uint64(s[1]) << 48) | + (uint64(s[2]) << 40) | + (uint64(s[4]) << 32) | + (uint64(s[5]) << 24) | + (uint64(s[6]) << 16) | + (uint64(s[7]) << 8) | + (uint64(s[8]) << 0) + maskKeyStartIdx = 10 + } + + payloadStartIdx := 0 + var maskKey [4]byte + if isMasked == 1 { + if len(buf) < maskKeyStartIdx+4 { + return WebSocketFrame{}, errors.New("frame not long enough to read mask") + } + // Is it necessary to copy the maskKey out of the original buffer to put into the deserialized struct? + // idk + copy(maskKey[:], buf[maskKeyStartIdx:maskKeyStartIdx+4]) + payloadStartIdx = maskKeyStartIdx + 4 + } + + if len(buf) < payloadStartIdx+int(payloadLen) { + return WebSocketFrame{}, errors.New("frame not long enough for payload of given length") + } + + payload := buf[payloadStartIdx : payloadStartIdx+int(payloadLen)] + + if isMasked == 1 { + for i, bite := range payload { + payload[i] = bite ^ maskKey[i%4] + } + } + + return WebSocketFrame{ + Payload: payload, + Fin: fin, + Rsv1: rsv1, + Rsv2: rsv2, + Rsv3: rsv3, + IsMasked: isMasked, + MaskKey: maskKey, + Opcode: opcode}, nil +} + +func SerializeWebSocketFrame(data WebSocketFrame) []byte { + + var frame []byte + + /** Size of the frame, not including the payload. */ + frameHeaderLen := uint64(2) + if data.IsMasked == 1 { + frameHeaderLen += 4 + } + + byte1 := (data.Fin << 7) | + ((data.Rsv1 << 6) & 0x40) | + ((data.Rsv2 << 5) & 0x20) | + ((data.Rsv3 << 4) & 0x10) | + data.Opcode&0x0F + + byte2 := ((data.IsMasked & 0x01) << 7) + + // Note: This will truncate any value larger than a 64-bit unsigned integer. + payloadLen := uint64(len(data.Payload)) + + // Create the frame based on how many bytes we need for payload length + if payloadLen >= 126 { + if payloadLen < math.MaxUint16 { //16-bit extended payload length + byte2 |= (0x7E) + frameHeaderLen += 2 + frame = make([]byte, frameHeaderLen+payloadLen) + binary.BigEndian.PutUint16(frame[2:], uint16(payloadLen)) + } else { // 64-bit extended payload length + byte2 |= (0x7F) + frameHeaderLen += 8 + frame = make([]byte, frameHeaderLen+payloadLen) + binary.BigEndian.PutUint64(frame[2:], payloadLen) + } + } else { + byte2 |= (byte(payloadLen) & 0x7F) + frame = make([]byte, frameHeaderLen+payloadLen) + } + + frame[0] = byte1 + frame[1] = byte2 + + // Copy the payload into the frame + copy(frame[frameHeaderLen:], data.Payload) + + // Copy the mask key into the frame and XOR-encrypt the payload + if data.IsMasked == 1 { + copy(frame[frameHeaderLen-4:], data.MaskKey[:]) + for i := range frame[frameHeaderLen:] { + frame[frameHeaderLen:][i] ^= data.MaskKey[i%4] + } + } + + return frame +} diff --git a/websocket_test.go b/websocket_test.go new file mode 100644 index 0000000..6fc9cc8 --- /dev/null +++ b/websocket_test.go @@ -0,0 +1,104 @@ +package gows + +import ( + "encoding/hex" + "reflect" + "testing" +) + +/* "Hello, world!" +0x818d0e21bf6d4644d301610d9f1a6153d3092f +{1 0 0 0 1 1 [14 33 191 109] [72 101 108 108 111 44 32 119 111 114 108 100 33]} +*/ + +/* 0x01020304 +0x82843283033333810037 +{1 0 0 0 1 2 [50 131 3 51] [1 2 3 4]} +*/ + +func TestDeserializeWebSocketFrameMaskedText(t *testing.T) { + got, err := DeserializeWebSocketFrame([]byte{0x81, 0x8d, 0x0e, 0x21, 0xbf, 0x6d, 0x46, 0x44, 0xd3, 0x01, 0x61, 0x0d, 0x9f, 0x1a, 0x61, 0x53, 0xd3, 0x09, 0x2f}) + want := WebSocketFrame{ + Payload: []byte("Hello, world!"), + Fin: 1, + Rsv1: 0, + Rsv2: 0, + Rsv3: 0, + IsMasked: 1, + MaskKey: [4]byte{14, 33, 191, 109}, + Opcode: 1} + + if err != nil { + t.Errorf("got error: %s", err) + } + + if got.Fin != want.Fin { + t.Errorf("[fin] got %d, wanted %d", got.Fin, want.Fin) + } + if got.Rsv1 != want.Rsv1 { + t.Errorf("[rsv1] got %d, wanted %d", got.Rsv1, want.Rsv1) + } + if got.Rsv2 != want.Rsv2 { + t.Errorf("[rsv2] got %d, wanted %d", got.Rsv2, want.Rsv2) + } + if got.Rsv3 != want.Rsv3 { + t.Errorf("[rsv3] got %d, wanted %d", got.Rsv3, want.Rsv3) + } + if got.Opcode != want.Opcode { + t.Errorf("[opcode] got %d, wanted %d", got.Opcode, want.Opcode) + } + if got.IsMasked != want.IsMasked { + t.Errorf("[isMasked] got %d, wanted %d", got.IsMasked, want.IsMasked) + } + + if string(got.Payload) != string(want.Payload) { + t.Errorf("[payload]\ngot: %q\nwant: %q", string(got.Payload), string(want.Payload)) + } + + if reflect.DeepEqual(got.MaskKey, want.MaskKey) == false { + t.Errorf("[maskKey]\ngot: 0x%s\nwant: 0x%s", hex.EncodeToString(got.MaskKey[:]), hex.EncodeToString(want.MaskKey[:])) + } + +} + +func TestSerializeWebSocketFrameMaskedText(t *testing.T) { + got := SerializeWebSocketFrame(WebSocketFrame{ + Payload: []byte("Hello, world!"), + Fin: 1, + Rsv1: 0, + Rsv2: 0, + Rsv3: 0, + IsMasked: 1, + MaskKey: [4]byte{14, 33, 191, 109}, + Opcode: 1}) + want := []byte{0x81, 0x8d, 0x0e, 0x21, 0xbf, 0x6d, 0x46, 0x44, 0xd3, 0x01, 0x61, 0x0d, 0x9f, 0x1a, 0x61, 0x53, 0xd3, 0x09, 0x2f} + + if len(got) != len(want) { + t.Errorf("got length %d, wanted %d", len(got), len(want)) + } + + if reflect.DeepEqual(got, want) == false { + t.Errorf("\ngot: 0x%s\nwant: 0x%s", hex.EncodeToString(got), hex.EncodeToString(want)) + } +} + +func TestSerializeWebSocketFrameMaskedBinary(t *testing.T) { + got := SerializeWebSocketFrame(WebSocketFrame{ + Payload: []byte{1, 2, 3, 4}, + Fin: 1, + Rsv1: 0, + Rsv2: 0, + Rsv3: 0, + IsMasked: 1, + MaskKey: [4]byte{50, 131, 3, 51}, + Opcode: 2}) + want := []byte{0x82, 0x84, 0x32, 0x83, 0x03, 0x33, 0x33, 0x81, 0x00, 0x37} + + if len(got) != len(want) { + t.Errorf("got length %d, wanted %d", len(got), len(want)) + } + + if reflect.DeepEqual(got, want) == false { + t.Errorf("\ngot: 0x%s\nwant: 0x%s", hex.EncodeToString(got), hex.EncodeToString(want)) + } +}