-
Notifications
You must be signed in to change notification settings - Fork 31
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
202 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# Websocket distributor | ||
|
||
This is completely dumb and just forwards messages to all downstreams. | ||
|
||
Auth is done by the downstreams which will disconnect a socket if it fails to | ||
auth, which will cause this to disconnect the upstream. | ||
|
||
## Usage | ||
|
||
Pass each `server:port` downstream on the command line. | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
package main | ||
|
||
import ( | ||
"flag" | ||
"github.com/theprimeagen/vim-with-me/pkg/v2/distributor" | ||
"os" | ||
"strconv" | ||
"log/slog" | ||
) | ||
|
||
func main() { | ||
var port int | ||
flag.IntVar(&port, "port", 0, "port to listen on") | ||
flag.Parse() | ||
downstreams := os.Args[1:] | ||
|
||
if port == 0 { | ||
portStr := os.Getenv("PORT") | ||
if portStr == "" { | ||
slog.Error("No port specified!") | ||
os.Exit(1) | ||
} | ||
var err error | ||
port, err = strconv.Atoi(portStr) | ||
if err != nil { | ||
slog.Error("Error converting port to int", "port", portStr, "err", err) | ||
os.Exit(1) | ||
} | ||
} | ||
|
||
authId := os.Getenv("AUTH_ID") | ||
if authId == "" { | ||
slog.Error("No auth id specified!") | ||
os.Exit(1) | ||
} | ||
|
||
d := distributor.NewDistributor(port, authId, downstreams) | ||
d.Run() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
package distributor | ||
|
||
import ( | ||
"fmt" | ||
"github.com/gorilla/websocket" | ||
"github.com/theprimeagen/vim-with-me/pkg/v2/assert" | ||
"log" | ||
"log/slog" | ||
"net/http" | ||
"sync" | ||
) | ||
|
||
type Distributor struct { | ||
sync.Mutex | ||
|
||
listenPort int | ||
authId string | ||
upstreamConnection *websocket.Conn | ||
downstreams []string | ||
downstreamConns []*Downstream | ||
msgChan chan []byte | ||
} | ||
|
||
func NewDistributor(listenPort int, authId string, downstreams []string) *Distributor { | ||
return &Distributor{ | ||
listenPort: listenPort, | ||
authId: authId, | ||
downstreams: downstreams, | ||
} | ||
} | ||
|
||
func (d *Distributor) Run() { | ||
for _, addr := range d.downstreams { | ||
d.downstreamConns = append(d.downstreamConns, NewDownstream(addr)) | ||
} | ||
|
||
http.HandleFunc("/ws", d.handleIncomingConnection) | ||
|
||
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { | ||
w.Header().Add("Location", "https://vim-with.me/") | ||
w.WriteHeader(http.StatusTemporaryRedirect) | ||
_, _ = w.Write([]byte("Not here, <a href=\"https://vim-with.me/\">over here</a>!")) | ||
}) | ||
|
||
addr := fmt.Sprintf("0.0.0.0:%d", d.listenPort) | ||
slog.Warn("listening and serving http", "http", addr) | ||
err := http.ListenAndServe(addr, nil) | ||
|
||
log.Fatal(err) | ||
} | ||
|
||
func (d *Distributor) handleIncomingConnection(w http.ResponseWriter, r *http.Request) { | ||
func() { | ||
d.Lock() | ||
defer d.Unlock() | ||
|
||
if d.upstreamConnection != nil { | ||
// One connection at a time! | ||
slog.Warn("Rejected connection, already have one", | ||
"remote", r.RemoteAddr) | ||
w.WriteHeader(http.StatusTooManyRequests) | ||
return | ||
} | ||
|
||
slog.Info("New upstream connection", "remote", r.RemoteAddr) | ||
|
||
upgrader := websocket.Upgrader{} | ||
c, err := upgrader.Upgrade(w, r, nil) | ||
assert.NoError(err, "unable to upgrade connection") | ||
|
||
d.upstreamConnection = c | ||
}() | ||
|
||
for { | ||
mt, msg, err := d.upstreamConnection.ReadMessage() | ||
if mt != websocket.BinaryMessage { | ||
slog.Error("Upstream sent non-binary message, disconnecting") | ||
break | ||
} | ||
|
||
if err != nil { | ||
slog.Error("Upstream error, disconnecting", "err", err) | ||
break | ||
} | ||
|
||
for _, downstream := range d.downstreamConns { | ||
downstream.SendMessage(mt, msg) | ||
} | ||
} | ||
|
||
_ = d.upstreamConnection.Close() | ||
d.upstreamConnection = nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
package distributor | ||
|
||
import ( | ||
"net/url" | ||
"github.com/gorilla/websocket" | ||
"github.com/theprimeagen/vim-with-me/pkg/v2/assert" | ||
"log/slog" | ||
) | ||
|
||
type Downstream struct { | ||
addr string | ||
conn *websocket.Conn | ||
} | ||
|
||
func NewDownstream(addr string) *Downstream { | ||
ds := &Downstream{ | ||
addr: addr, | ||
} | ||
go ds.Run() | ||
return ds | ||
} | ||
|
||
func (ds *Downstream) Run() { | ||
var err error | ||
for { | ||
// Reconnect if necessary | ||
if ds.conn == nil { | ||
slog.Info("Connecting to downstream server", "addr", ds.addr) | ||
u := url.URL{Scheme: "ws", Host: ds.addr, Path: "/ws"} | ||
ds.conn, _, err = websocket.DefaultDialer.Dial(u.String(), nil) | ||
assert.NoError(err, "unable to connect to downstream server "+ds.addr) | ||
} | ||
|
||
for { | ||
// Discard any incoming messages | ||
mt, _, err := ds.conn.ReadMessage() | ||
if err != nil || mt != websocket.BinaryMessage { | ||
// Reconnect | ||
slog.Warn("Downstream connection closed", "addr", ds.addr, "err", err) | ||
break | ||
} | ||
} | ||
|
||
slog.Error("Downstream connection closed, reconnecting", "addr", ds.addr) | ||
|
||
_ = ds.conn.Close() | ||
ds.conn = nil | ||
} | ||
} | ||
|
||
func (ds *Downstream) SendMessage(msgType int, msg []byte) { | ||
err := ds.conn.WriteMessage(msgType, msg) | ||
if err != nil { | ||
// Reconnect | ||
slog.Warn("Failed to send to downstream, closing", "addr", ds.addr, "err", err) | ||
_ = ds.conn.Close() | ||
ds.conn = nil | ||
} | ||
} |