Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
zachshattuck committed Sep 23, 2024
0 parents commit 12c0de9
Show file tree
Hide file tree
Showing 7 changed files with 721 additions and 0 deletions.
28 changes: 28 additions & 0 deletions .github/workflows/go.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# This workflow will build a golang project
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go

name: Go

on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]

jobs:

build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- name: Set up Go
uses: actions/setup-go@v4
with:
go-version: '1.23.1'

- name: Build
run: go build -v ./...

- name: Test
run: go test
76 changes: 76 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# `gows`
Simple WebSocket ([RFC6455](https://datatracker.ietf.org/doc/rfc6455/)) library in Go

**What it offers:**
- Simple way to upgrade an HTTP connection to a WebSocket connection
- Simple way to serialize and deserialize individual WebSocket frames

**What it doesn't do:**
- Doesn't handle message fragmentation, but you can do that yourself by reading `Fin` and `Opcode`. For more information, see section 5.4 of [RFC6455](https://datatracker.ietf.org/doc/rfc6455/)
- Doesn't automatically respond to PING frames.

## Installation
`go get github.com/zachshattuck/gows`

## Example Usage
```go
import (
"github.com/zachshattuck/gows"
)

func main() {

ln, err := net.Listen("tcp", "127.0.0.1:8080")
if err != nil {
fmt.Println("Failed to `net.Listen`: ", err)
os.Exit(1)
}

conn, err := ln.Accept()
if err != nil {
fmt.Println("Failed to `Accept` connection: ", err)
os.Exit(1)
}

// Will `Read` from the connection and send a `101 Switching Protocols` response
// if valid, otherwise sends a `400 Bad Request` response.
err := gows.UpgradeConnection(&conn, buf)
if err != nil {
fmt.Fprintln(os.Stderr, "Failed to upgrade: ", err)
os.Exit(1)
}

// Listen for WebSocket frames
for {
n, err := conn.Read(buf)
if err != nil {
fmt.Println("Failed to read: ", err)
break
}

frame, err := gows.DeserializeWebSocketFrame(buf[:n])
if err != nil {
fmt.Fprintln(os.Stderr, "Failed to deserialize frame: ", err)
continue
}

switch frame.Opcode {
case gows.WS_OP_TEXT: // Handle text frame..
case gows.WS_OP_BIN: // Handle binary frame..
case gows.WS_OP_PING:
fmt.Println("Ping frame, responding with pong...")
pongFrame := gows.SerializeWebSocketFrame(gows.WebSocketFrame{
Fin: 1,
Rsv1: 0, Rsv2: 0, Rsv3: 0,
Opcode: gows.WS_OP_PONG,
IsMasked: 0,
MaskKey: [4]byte{},
Payload: frame.Payload,
})
conn.Write(pongFrame)
}

}

}
```
3 changes: 3 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
module github.com/zachshattuck/gows

go 1.23.1
119 changes: 119 additions & 0 deletions http.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package gows

import (
"errors"
"net"
)

/*
Given a param name and a buffer expected to be a valid HTTP request, this function
will return a slice containing the value of that HTTP param, if it is found.
*/
func getHttpParam(buf []byte, paramName string) ([]byte, error) {

// Read until we match `paramName` completely, NOT including the ":"
var correctByteCount int = 0
var valueStartIdx int
for i, b := range buf {
if b != paramName[correctByteCount] {
correctByteCount = 0
continue
}

// Previous character has to be start of buffer or '\n' (as part of CRLF)
// NOTE: If the user provided a slice that was partway through a request, this could
// produce wrong results. For example, if there were two params, "Test-Param1: {value}"
// and "Param1: {value}", and the slice started at the 'P' in "Test-Param1", it could
// extract that value as if it was just "Param1".
if correctByteCount == 0 && !(i == 0 || buf[i-1] == '\n') {
correctByteCount = 0
continue
}

correctByteCount++

if correctByteCount < len(paramName) {
continue
}

// Following character has to be ":"
if i >= len(buf)-2 || buf[i+1] != ':' {
correctByteCount = 0
continue
}

// we found the whole param!
valueStartIdx = i + 2
break
}

if correctByteCount < len(paramName) {
return nil, errors.New("param \"" + string(paramName) + "\" not found in buffer")
}
if valueStartIdx >= len(buf)-1 {
return nil, errors.New("nothing in buffer after \"" + string(paramName) + ":\"")
}

// Read all whitespace
for {
if buf[valueStartIdx] != ' ' {
break
}
valueStartIdx++
}

// Read until CRLF
return readUntilCrlf(buf[valueStartIdx:])
}

/* Reads from start of slice until CRLF. If no CRLF is found, it will return an error instead of the value so far. */
func readUntilCrlf(buf []byte) ([]byte, error) {
lastTokenIdx := -1337
for i, b := range buf {
if b == '\r' {
lastTokenIdx = i
} else if b == '\n' {
if lastTokenIdx == i-1 {
return buf[:lastTokenIdx], nil
}
}
}

// we never found a valid CRLF
return nil, errors.New("no CRLF found")
}

func isValidUpgradeRequest(buf []byte) (bool, error) {
// TODO: This doesn't verify a valid HTTP verb at all

// _, err := GetHttpParam(buf, "Host")
// if err != nil {
// return false, err
// }

httpConnection, err := getHttpParam(buf, "Connection")
if err != nil || (string(httpConnection) != "Upgrade" && string(httpConnection) != "upgrade") {
return false, errors.New("invalid or nonexistent \"Connection\" param")
}

httpUpgrade, err := getHttpParam(buf, "Upgrade")
if err != nil || string(httpUpgrade) != "websocket" {
return false, errors.New("invalid or nonexistent \"Upgrade\" param")
}

httpWebSocketVersion, err := getHttpParam(buf, "Sec-WebSocket-Version")
if err != nil || string(httpWebSocketVersion) != "13" {
return false, errors.New("invalid or nonexistent \"Sec-WebSocket-Version\" param")
}

_, err = getHttpParam(buf, "Sec-WebSocket-Key")
if err != nil {
return false, errors.New("invalid or nonexistent \"Sec-WebSocket-Key\" param")
}

return true, nil
}

func sendBadRequestResponse(conn *net.Conn) (int, error) {
return (*conn).Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n"))
}
142 changes: 142 additions & 0 deletions http_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
package gows

import "testing"

/* Example WebSocket upgrade request, ripped straight from my browser. */
var exampleHttpRequest = []byte("GET / HTTP/1.1\r\nHost: 127.0.0.1:8081\r\nConnection: Upgrade\r\nPragma: no-cache\r\nCache-Control: no-cache\r\nUser-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36\r\nUpgrade: websocket\r\nOrigin: http://localhost:8080\r\nSec-WebSocket-Version: 13\r\nAccept-Encoding: gzip, deflate, br\r\nAccept-Language: en-US,en;q=0.9\r\nSec-WebSocket-Key: D8KfDxohPIack4T9PAf3Ng==\r\nSec-WebSocket-Extensions: permessage-deflate; client_max_window_bits\r\n\r\n")

/* A longer WebSocket upgrade request, proxied by nginx. */
var exampleHttpRequest2 = []byte("GET /ws HTTP/1.1\r\nUpgrade: websocket\r\nConnection: upgrade\r\nHost: 127.0.0.1:8081\r\naccept-encoding: gzip, br\r\nX-Forwarded-For: 1.2.3.4\r\nCF-RAY: 8c3d6a50b90875c8-SEA\r\nX-Forwarded-Proto: https\r\nCF-Visitor: {\"scheme\":\"https\"}\r\nPragma: no-cache\r\nCache-Control: no-cache\r\nUser-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36\r\nOrigin: https://www.website.com\r\nSec-WebSocket-Version: 13\r\nAccept-Language: en-US,en;q=0.9\r\nSec-WebSocket-Key: ZFPbTE+Wekp3z+QNUR4R0Q==\r\nSec-WebSocket-Extensions: permessage-deflate; client_max_window_bits\r\nCF-Connecting-IP: 1.2.3.4\r\ncdn-loop: cloudflare; loops=1\r\nCF-IPCountry: US\r\n\r\n")

func TestGetHttpParamValidProperty(t *testing.T) {
got, err := getHttpParam(exampleHttpRequest, "Host")
want := "127.0.0.1:8081"

if err != nil {
t.Error("error:", err)
}
if string(got) != want {
t.Errorf("got %q, wanted %q", got, want)
}
}

func TestGetHttpParamInvalidProperty(t *testing.T) {
got, err := getHttpParam(exampleHttpRequest, " super duper nonsensical parameter!!!! 89740r3n3yr0932")

if err == nil {
t.Errorf("did not error, got %q", got)
}
}

func TestGetHttpParamDuplicateKeyword(t *testing.T) {
got, err := getHttpParam(exampleHttpRequest, "Upgrade")
want := "websocket"

if err != nil {
t.Error("error:", err)
}
if string(got) != want {
t.Errorf("got %q, wanted %q", got, want)
}
}

func TestGetHttpParamExtremelyLongParameter(t *testing.T) {
got, err := getHttpParam(exampleHttpRequest, "sd09 fus8-d90f js09df mus90d8f mu09sd8fy um90s8d ynf098sd7f n908sd 7fn90s8d7fn 908sd7n f09s8df7n 908sd7f 098sd7nf 098sd7nf 098sd7nf 098sd7fn 098sd7nf 098sd7fn 098sd7f 098sdf7 n09s8df7n 098sdf7n 09s8df7 n09s8df 709s8df7 n09s8df7 098sdf7 098sdf yunoiusdf hlksjdfh klsjdfkjsdhfkj sdhjflksdjf lksdj f098sd7f 908sduf iujsdhf kjshdf kjysud9f8 7sd98f sdkjf h,sjdhf kjsdfy 98sdf iusdnf kjsdhf kiusdyf 98sdyfi uhsdifu ysd98f sd98f jsd98f jsd9f j9sd8f j9s8df hisudfh lkjsdhf8sdy f98sdhf iujsdhf iousdyuf 98sdhf oijsdhf likudsfyg s98ydfgisu hdfsiog hsdf98g y9sd8fgjh s9d8fg u9isd8fgy 0987sdfg yhioudsfhg oisudfgh o87sdfhg 9sdfgy h098sdfhg isdufhg 98sdfh g9087sdfhg iosdufhg osjkdfhg lkjdsfh giusdfug98dsfgu g9p8sdfjg ;lksdfj g")
want := ""

if err == nil || string(got) != want {
t.Errorf("got %q, wanted %q", got, want)
}
}
func TestGetHttpParamSelf(t *testing.T) {
got, err := getHttpParam(exampleHttpRequest, "GET / HTTP/1.1\r\nHost: 127.0.0.1:8081\r\nConnection: Upgrade\r\nPragma: no-cache\r\nCache-Control: no-cache\r\nUser-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36\r\nUpgrade: websocket\r\nOrigin: http://localhost:8080\r\nSec-WebSocket-Version: 13\r\nAccept-Encoding: gzip, deflate, br\r\nAccept-Language: en-US,en;q=0.9\r\nSec-WebSocket-Key: D8KfDxohPIack4T9PAf3Ng==\r\nSec-WebSocket-Extensions: permessage-deflate; client_max_window_bits\r\n\r\n")

if err == nil {
t.Errorf("did not error, got %q", got)
}
}

func TestReadUntilCrlfHttpVerb(t *testing.T) {
got, err := readUntilCrlf(exampleHttpRequest)
want := "GET / HTTP/1.1"

if err != nil {
t.Error("error:", err)
}
if string(got) != want {
t.Errorf("got %q, wanted %q", got, want)
}
}

func TestReadUntilCrlfHttpParamValue(t *testing.T) {
got, err := readUntilCrlf(exampleHttpRequest[22:])
want := "127.0.0.1:8081"

if err != nil {
t.Error("error:", err)
}
if string(got) != want {
t.Errorf("got %q, wanted %q", got, want)
}
}

func TestReadUntilCrlfRandomNoCrlf(t *testing.T) {
got, err := readUntilCrlf([]byte("This is a nice string and all, but it doesn't have a Crlf."))

if err == nil {
t.Errorf("did not error, got %q", got)
}
}

func TestReadUntilCrlfRandomWithCrlf(t *testing.T) {
got, err := readUntilCrlf([]byte("This is a nice string and all, AND it has a Crlf.\r\n"))
want := "This is a nice string and all, AND it has a Crlf."

if err != nil {
t.Error("error:", err)
}
if string(got) != want {
t.Errorf("got %q, wanted %q", got, want)
}
}

func TestReadUntilCrlfCursed1(t *testing.T) {
got, err := readUntilCrlf([]byte("\r \n\n\n\n \n\n\n\n\r\n\n\n\n\n\n\n"))
want := "\r \n\n\n\n \n\n\n\n"

if err != nil {
t.Error("error:", err)
}
if string(got) != want {
t.Errorf("got %q, wanted %q", got, want)
}
}
func TestReadUntilCrlfCursed2(t *testing.T) {
got, err := readUntilCrlf([]byte("\r\r\r\r\r\r\r\r \nr\n\n\n \n\n\n\n\n\n\n\n\n\n\n"))

if err == nil {
t.Errorf("did not error, got %q", got)
}
}

func TestIsValidUpgradeRequestBasicGood(t *testing.T) {
got, err := isValidUpgradeRequest(exampleHttpRequest)

if err != nil {
t.Error("error:", err)
}
if got == false {
t.Errorf("got invalid, expected valid")
}
}

func TestIsValidUpgradeRequestLongGood(t *testing.T) {
got, err := isValidUpgradeRequest(exampleHttpRequest2)

if err != nil {
t.Error("error:", err)
}
if got == false {
t.Errorf("got invalid, expected valid")
}
}
Loading

0 comments on commit 12c0de9

Please sign in to comment.