diff --git a/go.mod b/go.mod index fc97630..3b59357 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,6 @@ require ( github.com/andybalholm/brotli v1.0.1 github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e github.com/gorilla/websocket v1.4.2 - github.com/jordwest/mock-conn v0.0.0-20180617021051-4896c6bd1641 github.com/satori/go.uuid v1.2.0 github.com/sirupsen/logrus v1.7.0 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect diff --git a/go.sum b/go.sum index 22d4e6a..744126a 100644 --- a/go.sum +++ b/go.sum @@ -6,8 +6,6 @@ github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18h github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/jordwest/mock-conn v0.0.0-20180617021051-4896c6bd1641 h1:ChkB2s4mFDekyUUmbNE7qNhennP0rfqF2YZUOGxbhFk= -github.com/jordwest/mock-conn v0.0.0-20180617021051-4896c6bd1641/go.mod h1:AJFEOPtj5Z5z3MAy+0uvjQAH02iRnQr6fnvuHYp/Jek= github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= diff --git a/proxy/helper.go b/proxy/helper.go index 495d7ed..f659bec 100644 --- a/proxy/helper.go +++ b/proxy/helper.go @@ -39,22 +39,27 @@ func LogErr(log *_log.Entry, err error) (loged bool) { // 转发流量 // Read a => Write b // Read b => Write a -func Transfer(log *_log.Entry, a, b io.ReadWriter) { +func Transfer(log *_log.Entry, a, b io.ReadWriteCloser) { done := make(chan struct{}) defer close(done) - forward := func(dst io.Writer, src io.Reader, ec chan<- error) { + forward := func(dst io.WriteCloser, src io.Reader, ec chan<- error) { _, err := io.Copy(dst, src) - - if v, ok := dst.(*conn); ok { - // 避免内存泄漏 - _ = v.Writer.CloseWithError(nil) + if err != nil { + select { + case <-done: + return + case ec <- err: + return + } } + err = dst.Close() select { case <-done: return case ec <- err: + return } } diff --git a/proxy/middle.go b/proxy/middle.go index c373925..55984ef 100644 --- a/proxy/middle.go +++ b/proxy/middle.go @@ -3,15 +3,10 @@ package proxy import ( "bufio" "crypto/tls" - "errors" "net" "net/http" - "os" "strings" - "sync" - "time" - mock_conn "github.com/jordwest/mock-conn" "github.com/lqqyt2423/go-mitmproxy/cert" ) @@ -26,123 +21,26 @@ func (l *listener) Accept() (net.Conn, error) { return <-l.connChan, nil } func (l *listener) Close() error { return nil } func (l *listener) Addr() net.Addr { return nil } -// mock net.Conn -type conn struct { - mock_conn.End - host string // remote host - - // 以下为实现 SetReadDeadline 所需字段:需要确保 Read 方法可以提前返回 - // connection: keep-alive 相关 - readCanCancel bool // 是否可取消 Read - firstRead bool // 首次调用 Read 初始化 - readErrChan chan error // Read 方法提前返回时的错误,总是 os.ErrDeadlineExceeded - readErr error // 底层 End 返回的错误 - readDeadline time.Time // SetReadDeadline 设置的时间 - chunk chan []byte // Read 和 beginRead 的交互 channel - - readDeadlineMu sync.RWMutex -} - -var connUnexpected = errors.New("unexpected read error") - // 建立客户端和服务端通信的通道 -func newPipes(host string) (client *conn, server *connBuf) { - pipes := mock_conn.NewConn() - client = &conn{ - End: *pipes.Client, - host: host, - } - serverConn := &conn{ - End: *pipes.Server, - host: host, - readCanCancel: true, - readErrChan: make(chan error), - chunk: make(chan []byte), - } - server = newConnBuf(serverConn) +func newPipes(host string) (net.Conn, *connBuf) { + client, srv := net.Pipe() + server := newConnBuf(srv, host) return client, server } -func (c *conn) beginRead(size int) { - buf := make([]byte, size) - for { - n, err := c.End.Read(buf) - if err != nil { - c.readErr = err - close(c.chunk) - return - } - chunk := make([]byte, n) - copy(chunk, buf[:n]) - c.chunk <- chunk - } -} - -func (c *conn) Read(data []byte) (int, error) { - if !c.readCanCancel { - return c.End.Read(data) - } - - if !c.firstRead { - go c.beginRead(len(data)) - } - c.firstRead = true - - c.readDeadlineMu.RLock() - if !c.readDeadline.Equal(time.Time{}) { - if !c.readDeadline.After(time.Now()) { - c.readDeadlineMu.RUnlock() - return 0, os.ErrDeadlineExceeded - } else { - c.readDeadlineMu.RUnlock() - log.WithField("host", c.host).Warnf("c.readDeadline is future %v\n", c.readDeadline) - return 0, connUnexpected - } - } - c.readDeadlineMu.RUnlock() - - select { - case err := <-c.readErrChan: - return 0, err - case chunk, ok := <-c.chunk: - if !ok { - return 0, c.readErr - } - copy(data, chunk) - return len(chunk), nil - } -} - -func (c *conn) SetDeadline(t time.Time) error { - log.WithField("host", c.host).Warnf("SetDeadline %v\n", t) - return connUnexpected -} - -// http server 标准库实现时,当多个 http 复用底层 socket 时,会调用此方法 -func (c *conn) SetReadDeadline(t time.Time) error { - c.readDeadlineMu.Lock() - c.readDeadline = t - c.readDeadlineMu.Unlock() - - if !t.Equal(time.Time{}) && !t.After(time.Now()) { - c.readErrChan <- os.ErrDeadlineExceeded - } - return nil -} - -func (c *conn) SetWriteDeadline(t time.Time) error { - log.WithField("host", c.host).Warnf("SetWriteDeadline %v\n", t) - return connUnexpected -} - // add Peek method for conn type connBuf struct { - *conn - r *bufio.Reader + net.Conn + r *bufio.Reader + host string } -func newConnBuf(c *conn) *connBuf { - return &connBuf{c, bufio.NewReader(c)} +func newConnBuf(c net.Conn, host string) *connBuf { + return &connBuf{ + Conn: c, + r: bufio.NewReader(c), + host: host, + } } func (b *connBuf) Peek(n int) ([]byte, error) {