Skip to content

Commit

Permalink
add UpgradeV2 (#29)
Browse files Browse the repository at this point in the history
  • Loading branch information
guonaihong authored Jun 22, 2024
1 parent 6eaa970 commit fc60bda
Show file tree
Hide file tree
Showing 8 changed files with 63 additions and 21 deletions.
9 changes: 8 additions & 1 deletion autobahn/config/fuzzingclient.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
{
"outdir": "./report/",
"servers": [
{
"agent": "global",
"url": "ws://localhost:9001/global",
"options": {
"version": 18
}
},
{
"agent": "no-context-takeover-decompression-and-compression-no-tls",
"url": "ws://localhost:9001/no-context-takeover-decompression-and-compression",
Expand Down Expand Up @@ -37,4 +44,4 @@
""
],
"exclude-agent-cases": {}
}
}
19 changes: 19 additions & 0 deletions autobahn/server/autobahn-server.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,24 @@ func echoReadTime(w http.ResponseWriter, r *http.Request) {
_ = c.ReadLoop()
}

var upgrade = quickws.NewUpgrade(
quickws.WithServerReplyPing(),
quickws.WithServerDecompression(),
quickws.WithServerIgnorePong(),
quickws.WithServerEnableUTF8Check(),
quickws.WithServerReadTimeout(5*time.Second),
)

func global(w http.ResponseWriter, r *http.Request) {
c, err := upgrade.UpgradeV2(w, r, &echoHandler{openWriteTimeout: true})
if err != nil {
fmt.Println("Upgrade fail:", err)
return
}

_ = c.ReadLoop()
}

func startTLSServer(mux *http.ServeMux) {

cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock)
Expand Down Expand Up @@ -167,6 +185,7 @@ func startServer(mux *http.ServeMux) {
func main() {
mux := &http.ServeMux{}
mux.HandleFunc("/timeout", echoReadTime)
mux.HandleFunc("/global", global)
mux.HandleFunc("/no-context-takeover-decompression", echoNoContextDecompression)
mux.HandleFunc("/no-context-takeover-decompression-and-compression", echoNoContextDecompressionAndCompression)
mux.HandleFunc("/context-takeover-decompression", echoContextTakeoverDecompression)
Expand Down
5 changes: 4 additions & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,10 @@ func (d *DialOption) Dial() (wsCon *Conn, err error) {
if err := conn.SetDeadline(time.Time{}); err != nil {
return nil, err
}
wsCon = newConn(conn, true /* client is true*/, &d.Config, fr, br)
if wsCon, err = newConn(conn, true /* client is true*/, &d.Config, fr, br); err != nil {
return nil, err
}
wsCon.pd = pd
wsCon.Callback = d.cb
return wsCon, nil
}
16 changes: 8 additions & 8 deletions common_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (
// 0. CallbackFunc
func WithClientCallbackFunc(open OnOpenFunc, m OnMessageFunc, c OnCloseFunc) ClientOption {
return func(o *DialOption) {
o.Callback = &funcToCallback{
o.cb = &funcToCallback{
onOpen: open,
onMessage: m,
onClose: c,
Expand All @@ -34,7 +34,7 @@ func WithClientCallbackFunc(open OnOpenFunc, m OnMessageFunc, c OnCloseFunc) Cli
// 配置服务端回调函数
func WithServerCallbackFunc(open OnOpenFunc, m OnMessageFunc, c OnCloseFunc) ServerOption {
return func(o *ConnOption) {
o.Callback = &funcToCallback{
o.cb = &funcToCallback{
onOpen: open,
onMessage: m,
onClose: c,
Expand All @@ -46,14 +46,14 @@ func WithServerCallbackFunc(open OnOpenFunc, m OnMessageFunc, c OnCloseFunc) Ser
// 配置客户端callback
func WithClientCallback(cb Callback) ClientOption {
return func(o *DialOption) {
o.Callback = cb
o.cb = cb
}
}

// 配置服务端回调函数
func WithServerCallback(cb Callback) ServerOption {
return func(o *ConnOption) {
o.Callback = cb
o.cb = cb
}
}

Expand Down Expand Up @@ -90,14 +90,14 @@ func WithClientEnableUTF8Check() ClientOption {
// 仅仅配置OnMessae函数
func WithServerOnMessageFunc(cb OnMessageFunc) ServerOption {
return func(o *ConnOption) {
o.Callback = OnMessageFunc(cb)
o.cb = OnMessageFunc(cb)
}
}

// 仅仅配置OnMessae函数
func WithClientOnMessageFunc(cb OnMessageFunc) ClientOption {
return func(o *DialOption) {
o.Callback = OnMessageFunc(cb)
o.cb = OnMessageFunc(cb)
}
}

Expand Down Expand Up @@ -292,14 +292,14 @@ func WithClientReadTimeout(t time.Duration) ClientOption {
// 17.1 配置服务端OnClose
func WithServerOnCloseFunc(onClose func(c *Conn, err error)) ServerOption {
return func(o *ConnOption) {
o.Callback = OnCloseFunc(onClose)
o.cb = OnCloseFunc(onClose)
}
}

// 17.2 配置客户端OnClose
func WithClientOnCloseFunc(onClose func(c *Conn, err error)) ClientOption {
return func(o *DialOption) {
o.Callback = OnCloseFunc(onClose)
o.cb = OnCloseFunc(onClose)
}
}

Expand Down
4 changes: 2 additions & 2 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type DialerTimeout interface {
// 一种是声明一个全局的配置,后面不停使用。
// 另外一种是局部声明一个配置,然后使用WithXXX函数设置配置
type Config struct {
Callback
cb Callback
deflate.PermessageDeflateConf // 静态配置, 从WithXXX函数中获取
tcpNoDelay bool
replyPing bool // 开启自动回复
Expand All @@ -67,7 +67,7 @@ func (c *Config) initPayloadSize() int {

// 默认设置
func (c *Config) defaultSetting() error {
c.Callback = &DefCallback{}
c.cb = &DefCallback{}
c.maxDelayWriteNum = 10
c.windowsMultipleTimesPayloadSize = 1.0
c.delayWriteInitBufferSize = 8 * 1024
Expand Down
2 changes: 1 addition & 1 deletion config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func TestConfig_defaultSetting(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Config{
Callback: tt.fields.Callback,
cb: tt.fields.Callback,
tcpNoDelay: tt.fields.tcpNoDelay,
replyPing: tt.fields.replyPing,
ignorePong: tt.fields.ignorePong,
Expand Down
11 changes: 7 additions & 4 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ type delayWrite struct {
type Conn struct {
fr fixedreader.FixedReader // 默认使用windows
c net.Conn // net.Conn
Callback // callback移至conn中
br *bufio.Reader // read和fr同时只能使用一个
*Config // config 可能是全局,也可能是局部初始化得来的
pd deflate.PermessageDeflateConf // permessageDeflate局部配置
Expand Down Expand Up @@ -87,18 +88,20 @@ func setNoDelay(c net.Conn, noDelay bool) error {
return nil
}

func newConn(c net.Conn, client bool, conf *Config, fr fixedreader.FixedReader, br *bufio.Reader) *Conn {
_ = setNoDelay(c, conf.tcpNoDelay)
func newConn(c net.Conn, client bool, conf *Config, fr fixedreader.FixedReader, br *bufio.Reader) (wsCon *Conn, err error) {
if err = setNoDelay(c, conf.tcpNoDelay); err != nil {
return nil, err
}

con := &Conn{
wsCon = &Conn{
c: c,
client: client,
Config: conf,
fr: fr,
br: br,
}

return con
return wsCon, err
}

// 返回标准库的net.Conn
Expand Down
18 changes: 14 additions & 4 deletions upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@ func NewUpgrade(opts ...ServerOption) *UpgradeServer {
}

func (u *UpgradeServer) Upgrade(w http.ResponseWriter, r *http.Request) (c *Conn, err error) {
return upgradeInner(w, r, &u.config)
return upgradeInner(w, r, &u.config, nil)
}

func (u *UpgradeServer) UpgradeV2(w http.ResponseWriter, r *http.Request, cb Callback) (c *Conn, err error) {
return upgradeInner(w, r, &u.config, cb)
}

func Upgrade(w http.ResponseWriter, r *http.Request, opts ...ServerOption) (c *Conn, err error) {
Expand All @@ -54,10 +58,10 @@ func Upgrade(w http.ResponseWriter, r *http.Request, opts ...ServerOption) (c *C
for _, o := range opts {
o(&conf)
}
return upgradeInner(w, r, &conf.Config)
return upgradeInner(w, r, &conf.Config, nil)
}

func upgradeInner(w http.ResponseWriter, r *http.Request, conf *Config) (c *Conn, err error) {
func upgradeInner(w http.ResponseWriter, r *http.Request, conf *Config, cb Callback) (wsCon *Conn, err error) {
if ecode, err := checkRequest(r); err != nil {
http.Error(w, err.Error(), ecode)
return nil, err
Expand Down Expand Up @@ -125,9 +129,15 @@ func upgradeInner(w http.ResponseWriter, r *http.Request, conf *Config) (c *Conn
if err := conn.SetDeadline(time.Time{}); err != nil {
return nil, err
}
wsCon := newConn(conn, false, conf, fr, br)
if wsCon, err = newConn(conn, false, conf, fr, br); err != nil {
return nil, err
}

wsCon.pd = pd
wsCon.Callback = cb
if cb == nil {
wsCon.Callback = conf.cb
}
return wsCon, nil
}

Expand Down

0 comments on commit fc60bda

Please sign in to comment.