diff --git a/proxy/middle.go b/proxy/middle.go index 48b0ca9..84cd367 100644 --- a/proxy/middle.go +++ b/proxy/middle.go @@ -3,6 +3,7 @@ package proxy import ( "bufio" "crypto/tls" + "errors" "net" "net/http" "os" @@ -24,87 +25,110 @@ func (l *listener) Accept() (net.Conn, error) { return <-l.connChan, nil } func (l *listener) Close() error { return nil } func (l *listener) Addr() net.Addr { return nil } -type ioRes struct { - n int - err error -} - // mock net.Conn type conn struct { mock_conn.End - host string // remote host - readErrChan chan error // Read 方法提前返回时的错误 + host string // remote host + + // 以下为实现 SetReadDeadline 所需字段:需要确保 Read 方法可以提前返回 + // connection: keep-alive 相关 + readCanCancel bool // 是否可取消 Read + firstRead bool // 首次调用 Read 初始化 + pendingRead bool // 当前是否有 Read 操作在阻塞中 + readErrChan chan error // Read 方法提前返回时的错误,总是 os.ErrDeadlineExceeded + readErr error // 底层 End 返回的错误 + readDeadline time.Time // SetReadDeadline 设置的时间 + chunk chan []byte // Read 和 beginRead 的交互 channel } +var connUnexpected = errors.New("unexpected read error") + // 建立客户端和服务端通信的通道 func newPipes(host string) (client *conn, server *connBuf) { pipes := mock_conn.NewConn() - client = &conn{*pipes.Client, host, nil} - serverConn := &conn{*pipes.Server, host, make(chan error)} + client = &conn{ + End: *pipes.Client, + host: host, + } + serverConn := &conn{ + End: *pipes.Server, + host: host, + readCanCancel: true, + readErrChan: make(chan error), + chunk: make(chan []byte), + } server = newConnBuf(serverConn) return client, server } -// 当接收到 readErrChan 时,可提前返回 +func (c *conn) beginRead(size int) { + buf := make([]byte, size) + for { + n, err := c.End.Read(buf) + if err != nil { + c.readErr = err + close(c.chunk) + return + } + chunk := make([]byte, n) + copy(chunk, buf[:n]) + c.chunk <- chunk + } +} + func (c *conn) Read(data []byte) (int, error) { - select { - case err := <-c.readErrChan: - return 0, err - default: + if !c.readCanCancel { + return c.End.Read(data) } - resChan := make(chan *ioRes) - done := make(chan struct{}) - defer close(done) + if !c.firstRead { + go c.beginRead(len(data)) + } + c.firstRead = true - go func() { - select { - case <-done: - return - default: + if !c.readDeadline.Equal(time.Time{}) { + if !c.readDeadline.After(time.Now()) { + return 0, os.ErrDeadlineExceeded + } else { + log.WithField("host", c.host).Warnf("c.readDeadline is future %v\n", c.readDeadline) + return 0, connUnexpected } + } - n, err := c.End.Read(data) - select { - case resChan <- &ioRes{n, err}: - return - case <-done: - close(resChan) - } + c.pendingRead = true + defer func() { + c.pendingRead = false }() select { - case res := <-resChan: - return res.n, res.err case err := <-c.readErrChan: return 0, err + case chunk, ok := <-c.chunk: + if !ok { + return 0, c.readErr + } + copy(data, chunk) + return len(chunk), nil } } func (c *conn) SetDeadline(t time.Time) error { - if !t.Equal(time.Time{}) { - log.WithField("host", c.host).Warnf("SetDeadline %v\n", t) - } - return nil + log.WithField("host", c.host).Warnf("SetDeadline %v\n", t) + return connUnexpected } -// http server 会在连接快结束时调用此方法 +// http server 标准库实现时,当多个 http 复用底层 socke 时,会调用此方法 func (c *conn) SetReadDeadline(t time.Time) error { - if !t.Equal(time.Time{}) { - if !t.After(time.Now()) { - // 使当前 Read 尽快返回 - c.readErrChan <- os.ErrDeadlineExceeded - } else { - log.WithField("host", c.host).Warnf("SetReadDeadline %v\n", t) - } + c.readDeadline = t + if c.pendingRead && !t.Equal(time.Time{}) && !t.After(time.Now()) { + c.readErrChan <- os.ErrDeadlineExceeded } - return nil } func (c *conn) SetWriteDeadline(t time.Time) error { log.WithField("host", c.host).Warnf("SetWriteDeadline %v\n", t) - return nil + return connUnexpected } // add Peek method for conn @@ -155,9 +179,6 @@ func NewMiddle(proxy *Proxy) (Interceptor, error) { }, } - // 每次连接尽快结束,因为连接并无开销 - server.SetKeepAlivesEnabled(false) - m.Server = server return m, nil