From fc60bdae7368d6c0b56af671642e6f7a506d1a49 Mon Sep 17 00:00:00 2001 From: guonaihong Date: Sat, 22 Jun 2024 20:25:17 +0800 Subject: [PATCH] add UpgradeV2 (#29) --- autobahn/config/fuzzingclient.json | 9 ++++++++- autobahn/server/autobahn-server.go | 19 +++++++++++++++++++ client.go | 5 ++++- common_options.go | 16 ++++++++-------- config.go | 4 ++-- config_test.go | 2 +- conn.go | 11 +++++++---- upgrade.go | 18 ++++++++++++++---- 8 files changed, 63 insertions(+), 21 deletions(-) diff --git a/autobahn/config/fuzzingclient.json b/autobahn/config/fuzzingclient.json index 9d2d4b4..c6986b9 100644 --- a/autobahn/config/fuzzingclient.json +++ b/autobahn/config/fuzzingclient.json @@ -1,6 +1,13 @@ { "outdir": "./report/", "servers": [ + { + "agent": "global", + "url": "ws://localhost:9001/global", + "options": { + "version": 18 + } + }, { "agent": "no-context-takeover-decompression-and-compression-no-tls", "url": "ws://localhost:9001/no-context-takeover-decompression-and-compression", @@ -37,4 +44,4 @@ "" ], "exclude-agent-cases": {} -} +} \ No newline at end of file diff --git a/autobahn/server/autobahn-server.go b/autobahn/server/autobahn-server.go index 49b03c6..8ebee53 100644 --- a/autobahn/server/autobahn-server.go +++ b/autobahn/server/autobahn-server.go @@ -136,6 +136,24 @@ func echoReadTime(w http.ResponseWriter, r *http.Request) { _ = c.ReadLoop() } +var upgrade = quickws.NewUpgrade( + quickws.WithServerReplyPing(), + quickws.WithServerDecompression(), + quickws.WithServerIgnorePong(), + quickws.WithServerEnableUTF8Check(), + quickws.WithServerReadTimeout(5*time.Second), +) + +func global(w http.ResponseWriter, r *http.Request) { + c, err := upgrade.UpgradeV2(w, r, &echoHandler{openWriteTimeout: true}) + if err != nil { + fmt.Println("Upgrade fail:", err) + return + } + + _ = c.ReadLoop() +} + func startTLSServer(mux *http.ServeMux) { cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock) @@ -167,6 +185,7 @@ func startServer(mux *http.ServeMux) { func main() { mux := &http.ServeMux{} mux.HandleFunc("/timeout", echoReadTime) + mux.HandleFunc("/global", global) mux.HandleFunc("/no-context-takeover-decompression", echoNoContextDecompression) mux.HandleFunc("/no-context-takeover-decompression-and-compression", echoNoContextDecompressionAndCompression) mux.HandleFunc("/context-takeover-decompression", echoContextTakeoverDecompression) diff --git a/client.go b/client.go index 487ba73..8f3519a 100644 --- a/client.go +++ b/client.go @@ -287,7 +287,10 @@ func (d *DialOption) Dial() (wsCon *Conn, err error) { if err := conn.SetDeadline(time.Time{}); err != nil { return nil, err } - wsCon = newConn(conn, true /* client is true*/, &d.Config, fr, br) + if wsCon, err = newConn(conn, true /* client is true*/, &d.Config, fr, br); err != nil { + return nil, err + } wsCon.pd = pd + wsCon.Callback = d.cb return wsCon, nil } diff --git a/common_options.go b/common_options.go index ad4f932..07bf978 100644 --- a/common_options.go +++ b/common_options.go @@ -23,7 +23,7 @@ import ( // 0. CallbackFunc func WithClientCallbackFunc(open OnOpenFunc, m OnMessageFunc, c OnCloseFunc) ClientOption { return func(o *DialOption) { - o.Callback = &funcToCallback{ + o.cb = &funcToCallback{ onOpen: open, onMessage: m, onClose: c, @@ -34,7 +34,7 @@ func WithClientCallbackFunc(open OnOpenFunc, m OnMessageFunc, c OnCloseFunc) Cli // 配置服务端回调函数 func WithServerCallbackFunc(open OnOpenFunc, m OnMessageFunc, c OnCloseFunc) ServerOption { return func(o *ConnOption) { - o.Callback = &funcToCallback{ + o.cb = &funcToCallback{ onOpen: open, onMessage: m, onClose: c, @@ -46,14 +46,14 @@ func WithServerCallbackFunc(open OnOpenFunc, m OnMessageFunc, c OnCloseFunc) Ser // 配置客户端callback func WithClientCallback(cb Callback) ClientOption { return func(o *DialOption) { - o.Callback = cb + o.cb = cb } } // 配置服务端回调函数 func WithServerCallback(cb Callback) ServerOption { return func(o *ConnOption) { - o.Callback = cb + o.cb = cb } } @@ -90,14 +90,14 @@ func WithClientEnableUTF8Check() ClientOption { // 仅仅配置OnMessae函数 func WithServerOnMessageFunc(cb OnMessageFunc) ServerOption { return func(o *ConnOption) { - o.Callback = OnMessageFunc(cb) + o.cb = OnMessageFunc(cb) } } // 仅仅配置OnMessae函数 func WithClientOnMessageFunc(cb OnMessageFunc) ClientOption { return func(o *DialOption) { - o.Callback = OnMessageFunc(cb) + o.cb = OnMessageFunc(cb) } } @@ -292,14 +292,14 @@ func WithClientReadTimeout(t time.Duration) ClientOption { // 17.1 配置服务端OnClose func WithServerOnCloseFunc(onClose func(c *Conn, err error)) ServerOption { return func(o *ConnOption) { - o.Callback = OnCloseFunc(onClose) + o.cb = OnCloseFunc(onClose) } } // 17.2 配置客户端OnClose func WithClientOnCloseFunc(onClose func(c *Conn, err error)) ClientOption { return func(o *DialOption) { - o.Callback = OnCloseFunc(onClose) + o.cb = OnCloseFunc(onClose) } } diff --git a/config.go b/config.go index 2383476..77500ea 100644 --- a/config.go +++ b/config.go @@ -41,7 +41,7 @@ type DialerTimeout interface { // 一种是声明一个全局的配置,后面不停使用。 // 另外一种是局部声明一个配置,然后使用WithXXX函数设置配置 type Config struct { - Callback + cb Callback deflate.PermessageDeflateConf // 静态配置, 从WithXXX函数中获取 tcpNoDelay bool replyPing bool // 开启自动回复 @@ -67,7 +67,7 @@ func (c *Config) initPayloadSize() int { // 默认设置 func (c *Config) defaultSetting() error { - c.Callback = &DefCallback{} + c.cb = &DefCallback{} c.maxDelayWriteNum = 10 c.windowsMultipleTimesPayloadSize = 1.0 c.delayWriteInitBufferSize = 8 * 1024 diff --git a/config_test.go b/config_test.go index 849cfdf..00c7ec0 100644 --- a/config_test.go +++ b/config_test.go @@ -67,7 +67,7 @@ func TestConfig_defaultSetting(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &Config{ - Callback: tt.fields.Callback, + cb: tt.fields.Callback, tcpNoDelay: tt.fields.tcpNoDelay, replyPing: tt.fields.replyPing, ignorePong: tt.fields.ignorePong, diff --git a/conn.go b/conn.go index 0e2cceb..9ca5f61 100644 --- a/conn.go +++ b/conn.go @@ -58,6 +58,7 @@ type delayWrite struct { type Conn struct { fr fixedreader.FixedReader // 默认使用windows c net.Conn // net.Conn + Callback // callback移至conn中 br *bufio.Reader // read和fr同时只能使用一个 *Config // config 可能是全局,也可能是局部初始化得来的 pd deflate.PermessageDeflateConf // permessageDeflate局部配置 @@ -87,10 +88,12 @@ func setNoDelay(c net.Conn, noDelay bool) error { return nil } -func newConn(c net.Conn, client bool, conf *Config, fr fixedreader.FixedReader, br *bufio.Reader) *Conn { - _ = setNoDelay(c, conf.tcpNoDelay) +func newConn(c net.Conn, client bool, conf *Config, fr fixedreader.FixedReader, br *bufio.Reader) (wsCon *Conn, err error) { + if err = setNoDelay(c, conf.tcpNoDelay); err != nil { + return nil, err + } - con := &Conn{ + wsCon = &Conn{ c: c, client: client, Config: conf, @@ -98,7 +101,7 @@ func newConn(c net.Conn, client bool, conf *Config, fr fixedreader.FixedReader, br: br, } - return con + return wsCon, err } // 返回标准库的net.Conn diff --git a/upgrade.go b/upgrade.go index d50af65..bda0dd1 100644 --- a/upgrade.go +++ b/upgrade.go @@ -43,7 +43,11 @@ func NewUpgrade(opts ...ServerOption) *UpgradeServer { } func (u *UpgradeServer) Upgrade(w http.ResponseWriter, r *http.Request) (c *Conn, err error) { - return upgradeInner(w, r, &u.config) + return upgradeInner(w, r, &u.config, nil) +} + +func (u *UpgradeServer) UpgradeV2(w http.ResponseWriter, r *http.Request, cb Callback) (c *Conn, err error) { + return upgradeInner(w, r, &u.config, cb) } func Upgrade(w http.ResponseWriter, r *http.Request, opts ...ServerOption) (c *Conn, err error) { @@ -54,10 +58,10 @@ func Upgrade(w http.ResponseWriter, r *http.Request, opts ...ServerOption) (c *C for _, o := range opts { o(&conf) } - return upgradeInner(w, r, &conf.Config) + return upgradeInner(w, r, &conf.Config, nil) } -func upgradeInner(w http.ResponseWriter, r *http.Request, conf *Config) (c *Conn, err error) { +func upgradeInner(w http.ResponseWriter, r *http.Request, conf *Config, cb Callback) (wsCon *Conn, err error) { if ecode, err := checkRequest(r); err != nil { http.Error(w, err.Error(), ecode) return nil, err @@ -125,9 +129,15 @@ func upgradeInner(w http.ResponseWriter, r *http.Request, conf *Config) (c *Conn if err := conn.SetDeadline(time.Time{}); err != nil { return nil, err } - wsCon := newConn(conn, false, conf, fr, br) + if wsCon, err = newConn(conn, false, conf, fr, br); err != nil { + return nil, err + } wsCon.pd = pd + wsCon.Callback = cb + if cb == nil { + wsCon.Callback = conf.cb + } return wsCon, nil }