diff --git a/conn.go b/conn.go index fe2a5c7..78628a8 100644 --- a/conn.go +++ b/conn.go @@ -3,6 +3,7 @@ package shlev import ( "bytes" "golang.org/x/sys/unix" + "io" "net" "shlev/internal/netpoll" ) @@ -53,14 +54,69 @@ func (c *Conn) open(buf []byte) error { return err } +// 读数据 +func (c *Conn) Read(p []byte) (n int, err error) { + if c.recvBuffer.Len() == 0 { + n = copy(p, c.buffer) + c.buffer = c.buffer[n:] + if n == 0 && len(p) > 0 { + err = io.EOF + } + return n, err + } + + n, _ = c.recvBuffer.Read(p) + if n == len(p) { + return + } + + m := copy(p[n:], c.buffer) + n += m + c.buffer = c.buffer[m:] + return n, err +} + +// 写数据 +func (c *Conn) Write(data []byte) (n int, err error) { + n = len(data) + + // 连接发送缓冲区不为0时,说明此时套接字的发送缓冲区已经满了,没有必要向套接字写。 + if c.sendBuffer.Len() != 0 { + c.sendBuffer.Write(data) + return n, nil + } + + var send int + if send, err = unix.Write(c.fd, data); err != nil { + // 写入错误,释放内存,关闭连接 + return -1, c.loop.closeConnection(c) + } + + // 当套接字写缓冲区写满时,写入连接的发送缓冲区 + if send < n { + c.sendBuffer.Write(data[send:]) + // 监听写事件 + err = c.loop.netpoll.ModReadWrite(c.fd) + } + return n, err +} + +// TODO AsyncWrite + // 创建新的tcp连接 func newTCPConn(fd int, e *EventLoop, sa unix.Sockaddr, localAddr, remoteAddr net.Addr) (c *Conn) { c = &Conn{ fd: fd, + lnIndex: 0, + context: nil, remotePeer: sa, - loop: e, localAddr: localAddr, remoteAddr: remoteAddr, + loop: e, + buffer: nil, + recvBuffer: bytes.NewBuffer(make([]byte, e.server.opts.SocketRecvBuffer)), + sendBuffer: bytes.NewBuffer(make([]byte, e.server.opts.SocketSendBuffer)), + opened: false, } c.sendBuffer = bytes.NewBuffer(make([]byte, 0)) return diff --git a/event_loop.go b/event_loop.go index 19ec208..69342b0 100644 --- a/event_loop.go +++ b/event_loop.go @@ -48,10 +48,6 @@ func (e *EventLoop) closeConnection(c *Conn) (err error) { if !c.opened { return } - // 连接 - if addr := c.localAddr; addr != nil { - return - } // 如果发送缓冲不为空,说明还有数据要发送,需要先发送完数据再关闭连接 if c.sendBuffer.Len() != 0 { @@ -118,16 +114,17 @@ func (e *EventLoop) open(c *Conn) error { return e.handleResult(c, result) } +// 封装read系统调用 func (e *EventLoop) read(c *Conn) error { n, err := unix.Read(c.fd, e.buffer) - if err == nil || n == 0 { + if err != nil || n == 0 { if err == unix.EAGAIN { return nil } if n == 0 { err = unix.ECONNRESET } - logger.Error(fmt.Sprintf("EventLoop event_loop idx:%d read err:%v", c.fd, os.NewSyscallError("read", err))) + logger.Error(fmt.Sprintf("EventLoop event_loop fd:%d read err:%v", c.fd, os.NewSyscallError("read", err))) return e.closeConnection(c) } @@ -267,6 +264,7 @@ func (e *EventLoop) run(lockOSThread bool) { return err } } + // 当套接字有 unix.EPOLLIN 事件,且读到的数据长度为0时,说明对方已经关闭连接。 if (ev & netpoll.InEvents) != 0 { return e.read(c) } diff --git a/internal/netpoll/Inetpoll.go b/internal/netpoll/Inetpoll.go index a07976a..632db26 100644 --- a/internal/netpoll/Inetpoll.go +++ b/internal/netpoll/Inetpoll.go @@ -11,6 +11,8 @@ type Netpoller interface { AddWrite(fd int) error // ModRead 改为读 ModRead(fd int) error + // ModReadWrite 读写事件 + ModReadWrite(fd int) error // Polling 轮询事件 Polling(func(int, uint32) error) error // Close 关闭事件循环 diff --git a/internal/netpoll/epoll_events.go b/internal/netpoll/epoll_events.go index 5d774e6..004b5c6 100644 --- a/internal/netpoll/epoll_events.go +++ b/internal/netpoll/epoll_events.go @@ -14,7 +14,7 @@ import "golang.org/x/sys/unix" */ const ( - readEvents = unix.EPOLLPRI | unix.EPOLLIN + readEvents = unix.EPOLLIN writeEvents = unix.EPOLLOUT readWriteEvents = readEvents | writeEvents @@ -35,10 +35,10 @@ const ( ** 有写需要时才通过epoll_ctl添加相应fd,不然在LT模式下会频繁触发; ** 对于写操作,大部分情况下都处于可写状态,可先直接调用write来发送数据,直到返回 EAGAIN后再使能EPOLLOUT,待触发后再继续write。 */ - // OutEvents combines EPOLLOUT event and some exceptional events. + // OutEvents 包含错误,挂断以及可写事件 OutEvents = ErrEvents | unix.EPOLLOUT - // InEvents combines EPOLLIN/EPOLLPRI events and some exceptional events. - InEvents = ErrEvents | unix.EPOLLIN | unix.EPOLLPRI + // InEvents 包含错误,挂断以及可读事件 + InEvents = ErrEvents | unix.EPOLLIN ) const ( diff --git a/shlev_test.go b/shlev_test.go index 8b38572..7e99735 100644 --- a/shlev_test.go +++ b/shlev_test.go @@ -1,19 +1,24 @@ package shlev import ( + "context" "fmt" "golang.org/x/sys/unix" "shlev/tools/logger" "testing" + "time" ) func TestServer(t *testing.T) { s := &testServer{} fmt.Println("server run") - Run(s, "47.103.116.215:10001", WithNumEventLoop(3), WithLoadBalancing(RoundRobin)) - for true { - - } + go func() { + fmt.Println("server after run") + time.Sleep(time.Second * 30) + fmt.Println("server close") + Stop(context.Background(), "127.0.0.1:10001") + }() + Run(s, "127.0.0.1:10001", WithNumEventLoop(3), WithLoadBalancing(RoundRobin)) } type testServer struct { @@ -32,14 +37,24 @@ func (s *testServer) OnOpen(c *Conn, err error) (b []byte, e HandleResult) { c.SetContext(c) logger.Debug("OnOpen localAddr:", c.LocalAddr(), "; remoteAddr:", c.RemoteAddr()) unix.Write(c.fd, []byte("fuck off\n")) - unix.Close(c.fd) return []byte{}, None } -func (s *testServer) OnConnectionClose(_ *Conn, _ error) { - //s.eng = eng +func (s *testServer) OnConnectionClose(c *Conn, _ error) { + if c.recvBuffer.Len() != 0 { + b := c.recvBuffer.Bytes() + fmt.Println(string(b)) + } + //logger.Debug("OnConnectionClose localAddr:", c.LocalAddr(), "; remoteAddr:", c.RemoteAddr()) return } -func (s *testServer) OnTraffic(_ *Conn) HandleResult { + +func (s *testServer) OnTraffic(c *Conn) HandleResult { + b := make([]byte, 100000) + n, err := c.Read(b) + if err != nil { + return 0 + } + fmt.Println("read data:", string(b[:n])) return None } diff --git a/tools/logger/logger.go b/tools/logger/logger.go index ed2b87e..fcc274e 100644 --- a/tools/logger/logger.go +++ b/tools/logger/logger.go @@ -16,7 +16,7 @@ func init() { panic(err) } - f, err := os.OpenFile(path+"/net.log", os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0666) + f, err := os.OpenFile(path+"/shlev_net.log", os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0666) if err != nil { panic(err) } @@ -42,26 +42,51 @@ func Init(v ...any) { logger.Println(v...) } +func DebugF(fmt string, v ...any) { + setPrefix("DEBUG") + logger.Printf(fmt, v...) +} + func Debug(v ...any) { setPrefix("DEBUG") logger.Println(v...) } +func WarnF(fmt string, v ...any) { + setPrefix("WARN") + logger.Printf(fmt, v...) +} + func Warn(v ...any) { setPrefix("WARN") logger.Println(v...) } +func ErrorF(fmt string, v ...any) { + setPrefix("ERROR") + logger.Printf(fmt, v...) +} + func Error(v ...any) { setPrefix("ERROR") logger.Println(v...) } +func InfoF(fmt string, v ...any) { + setPrefix("INFO") + logger.Printf(fmt, v...) +} + func Info(v ...any) { setPrefix("INFO") logger.Println(v...) } +func FatalF(fmt string, v ...any) { + setPrefix("FATAL") + logger.Printf(fmt, v...) +} + func Fatal(v ...any) { setPrefix("FATAL") logger.Fatalln(v...)