Skip to content

Commit

Permalink
Implemented a websocket distributor
Browse files Browse the repository at this point in the history
  • Loading branch information
stut committed May 26, 2024
1 parent 15402ec commit 9561470
Show file tree
Hide file tree
Showing 4 changed files with 202 additions and 0 deletions.
11 changes: 11 additions & 0 deletions pkg/v2/distributor/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Websocket distributor

This is completely dumb and just forwards messages to all downstreams.

Auth is done by the downstreams which will disconnect a socket if it fails to
auth, which will cause this to disconnect the upstream.

## Usage

Pass each `server:port` downstream on the command line.

39 changes: 39 additions & 0 deletions pkg/v2/distributor/cmd/run/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package main

import (
"flag"
"github.com/theprimeagen/vim-with-me/pkg/v2/distributor"
"os"
"strconv"
"log/slog"
)

func main() {
var port int
flag.IntVar(&port, "port", 0, "port to listen on")
flag.Parse()
downstreams := os.Args[1:]

if port == 0 {
portStr := os.Getenv("PORT")
if portStr == "" {
slog.Error("No port specified!")
os.Exit(1)
}
var err error
port, err = strconv.Atoi(portStr)
if err != nil {
slog.Error("Error converting port to int", "port", portStr, "err", err)
os.Exit(1)
}
}

authId := os.Getenv("AUTH_ID")
if authId == "" {
slog.Error("No auth id specified!")
os.Exit(1)
}

d := distributor.NewDistributor(port, authId, downstreams)
d.Run()
}
93 changes: 93 additions & 0 deletions pkg/v2/distributor/distributor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package distributor

import (
"fmt"
"github.com/gorilla/websocket"
"github.com/theprimeagen/vim-with-me/pkg/v2/assert"
"log"
"log/slog"
"net/http"
"sync"
)

type Distributor struct {
sync.Mutex

listenPort int
authId string
upstreamConnection *websocket.Conn
downstreams []string
downstreamConns []*Downstream
msgChan chan []byte
}

func NewDistributor(listenPort int, authId string, downstreams []string) *Distributor {
return &Distributor{
listenPort: listenPort,
authId: authId,
downstreams: downstreams,
}
}

func (d *Distributor) Run() {
for _, addr := range d.downstreams {
d.downstreamConns = append(d.downstreamConns, NewDownstream(addr))
}

http.HandleFunc("/ws", d.handleIncomingConnection)

http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Location", "https://vim-with.me/")
w.WriteHeader(http.StatusTemporaryRedirect)
_, _ = w.Write([]byte("Not here, <a href=\"https://vim-with.me/\">over here</a>!"))
})

addr := fmt.Sprintf("0.0.0.0:%d", d.listenPort)
slog.Warn("listening and serving http", "http", addr)
err := http.ListenAndServe(addr, nil)

log.Fatal(err)
}

func (d *Distributor) handleIncomingConnection(w http.ResponseWriter, r *http.Request) {
func() {
d.Lock()
defer d.Unlock()

if d.upstreamConnection != nil {
// One connection at a time!
slog.Warn("Rejected connection, already have one",
"remote", r.RemoteAddr)
w.WriteHeader(http.StatusTooManyRequests)
return
}

slog.Info("New upstream connection", "remote", r.RemoteAddr)

upgrader := websocket.Upgrader{}
c, err := upgrader.Upgrade(w, r, nil)
assert.NoError(err, "unable to upgrade connection")

d.upstreamConnection = c
}()

for {
mt, msg, err := d.upstreamConnection.ReadMessage()
if mt != websocket.BinaryMessage {
slog.Error("Upstream sent non-binary message, disconnecting")
break
}

if err != nil {
slog.Error("Upstream error, disconnecting", "err", err)
break
}

for _, downstream := range d.downstreamConns {
downstream.SendMessage(mt, msg)
}
}

_ = d.upstreamConnection.Close()
d.upstreamConnection = nil
}
59 changes: 59 additions & 0 deletions pkg/v2/distributor/downstream.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package distributor

import (
"net/url"
"github.com/gorilla/websocket"
"github.com/theprimeagen/vim-with-me/pkg/v2/assert"
"log/slog"
)

type Downstream struct {
addr string
conn *websocket.Conn
}

func NewDownstream(addr string) *Downstream {
ds := &Downstream{
addr: addr,
}
go ds.Run()
return ds
}

func (ds *Downstream) Run() {
var err error
for {
// Reconnect if necessary
if ds.conn == nil {
slog.Info("Connecting to downstream server", "addr", ds.addr)
u := url.URL{Scheme: "ws", Host: ds.addr, Path: "/ws"}
ds.conn, _, err = websocket.DefaultDialer.Dial(u.String(), nil)
assert.NoError(err, "unable to connect to downstream server "+ds.addr)
}

for {
// Discard any incoming messages
mt, _, err := ds.conn.ReadMessage()
if err != nil || mt != websocket.BinaryMessage {
// Reconnect
slog.Warn("Downstream connection closed", "addr", ds.addr, "err", err)
break
}
}

slog.Error("Downstream connection closed, reconnecting", "addr", ds.addr)

_ = ds.conn.Close()
ds.conn = nil
}
}

func (ds *Downstream) SendMessage(msgType int, msg []byte) {
err := ds.conn.WriteMessage(msgType, msg)
if err != nil {
// Reconnect
slog.Warn("Failed to send to downstream, closing", "addr", ds.addr, "err", err)
_ = ds.conn.Close()
ds.conn = nil
}
}

0 comments on commit 9561470

Please sign in to comment.