diff --git a/pkg/v2/distributor/README.md b/pkg/v2/distributor/README.md new file mode 100644 index 0000000..bb121b9 --- /dev/null +++ b/pkg/v2/distributor/README.md @@ -0,0 +1,11 @@ +# Websocket distributor + +This is completely dumb and just forwards messages to all downstreams. + +Auth is done by the downstreams which will disconnect a socket if it fails to +auth, which will cause this to disconnect the upstream. + +## Usage + +Pass each `server:port` downstream on the command line. + diff --git a/pkg/v2/distributor/cmd/run/main.go b/pkg/v2/distributor/cmd/run/main.go new file mode 100644 index 0000000..117f0f1 --- /dev/null +++ b/pkg/v2/distributor/cmd/run/main.go @@ -0,0 +1,39 @@ +package main + +import ( + "flag" + "github.com/theprimeagen/vim-with-me/pkg/v2/distributor" + "os" + "strconv" + "log/slog" +) + +func main() { + var port int + flag.IntVar(&port, "port", 0, "port to listen on") + flag.Parse() + downstreams := os.Args[1:] + + if port == 0 { + portStr := os.Getenv("PORT") + if portStr == "" { + slog.Error("No port specified!") + os.Exit(1) + } + var err error + port, err = strconv.Atoi(portStr) + if err != nil { + slog.Error("Error converting port to int", "port", portStr, "err", err) + os.Exit(1) + } + } + + authId := os.Getenv("AUTH_ID") + if authId == "" { + slog.Error("No auth id specified!") + os.Exit(1) + } + + d := distributor.NewDistributor(port, authId, downstreams) + d.Run() +} diff --git a/pkg/v2/distributor/distributor.go b/pkg/v2/distributor/distributor.go new file mode 100644 index 0000000..2fb06b3 --- /dev/null +++ b/pkg/v2/distributor/distributor.go @@ -0,0 +1,93 @@ +package distributor + +import ( + "fmt" + "github.com/gorilla/websocket" + "github.com/theprimeagen/vim-with-me/pkg/v2/assert" + "log" + "log/slog" + "net/http" + "sync" +) + +type Distributor struct { + sync.Mutex + + listenPort int + authId string + upstreamConnection *websocket.Conn + downstreams []string + downstreamConns []*Downstream + msgChan chan []byte +} + +func NewDistributor(listenPort int, authId string, downstreams []string) *Distributor { + return &Distributor{ + listenPort: listenPort, + authId: authId, + downstreams: downstreams, + } +} + +func (d *Distributor) Run() { + for _, addr := range d.downstreams { + d.downstreamConns = append(d.downstreamConns, NewDownstream(addr)) + } + + http.HandleFunc("/ws", d.handleIncomingConnection) + + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("Location", "https://vim-with.me/") + w.WriteHeader(http.StatusTemporaryRedirect) + _, _ = w.Write([]byte("Not here, over here!")) + }) + + addr := fmt.Sprintf("0.0.0.0:%d", d.listenPort) + slog.Warn("listening and serving http", "http", addr) + err := http.ListenAndServe(addr, nil) + + log.Fatal(err) +} + +func (d *Distributor) handleIncomingConnection(w http.ResponseWriter, r *http.Request) { + func() { + d.Lock() + defer d.Unlock() + + if d.upstreamConnection != nil { + // One connection at a time! + slog.Warn("Rejected connection, already have one", + "remote", r.RemoteAddr) + w.WriteHeader(http.StatusTooManyRequests) + return + } + + slog.Info("New upstream connection", "remote", r.RemoteAddr) + + upgrader := websocket.Upgrader{} + c, err := upgrader.Upgrade(w, r, nil) + assert.NoError(err, "unable to upgrade connection") + + d.upstreamConnection = c + }() + + for { + mt, msg, err := d.upstreamConnection.ReadMessage() + if mt != websocket.BinaryMessage { + slog.Error("Upstream sent non-binary message, disconnecting") + break + } + + if err != nil { + slog.Error("Upstream error, disconnecting", "err", err) + break + } + + for _, downstream := range d.downstreamConns { + downstream.SendMessage(mt, msg) + } + } + + _ = d.upstreamConnection.Close() + d.upstreamConnection = nil +} diff --git a/pkg/v2/distributor/downstream.go b/pkg/v2/distributor/downstream.go new file mode 100644 index 0000000..c984458 --- /dev/null +++ b/pkg/v2/distributor/downstream.go @@ -0,0 +1,59 @@ +package distributor + +import ( + "net/url" + "github.com/gorilla/websocket" + "github.com/theprimeagen/vim-with-me/pkg/v2/assert" + "log/slog" +) + +type Downstream struct { + addr string + conn *websocket.Conn +} + +func NewDownstream(addr string) *Downstream { + ds := &Downstream{ + addr: addr, + } + go ds.Run() + return ds +} + +func (ds *Downstream) Run() { + var err error + for { + // Reconnect if necessary + if ds.conn == nil { + slog.Info("Connecting to downstream server", "addr", ds.addr) + u := url.URL{Scheme: "ws", Host: ds.addr, Path: "/ws"} + ds.conn, _, err = websocket.DefaultDialer.Dial(u.String(), nil) + assert.NoError(err, "unable to connect to downstream server "+ds.addr) + } + + for { + // Discard any incoming messages + mt, _, err := ds.conn.ReadMessage() + if err != nil || mt != websocket.BinaryMessage { + // Reconnect + slog.Warn("Downstream connection closed", "addr", ds.addr, "err", err) + break + } + } + + slog.Error("Downstream connection closed, reconnecting", "addr", ds.addr) + + _ = ds.conn.Close() + ds.conn = nil + } +} + +func (ds *Downstream) SendMessage(msgType int, msg []byte) { + err := ds.conn.WriteMessage(msgType, msg) + if err != nil { + // Reconnect + slog.Warn("Failed to send to downstream, closing", "addr", ds.addr, "err", err) + _ = ds.conn.Close() + ds.conn = nil + } +}