diff --git a/proxy/mitm_memory.go b/proxy/mitm_memory.go index 7979c91..8a81c39 100644 --- a/proxy/mitm_memory.go +++ b/proxy/mitm_memory.go @@ -1,10 +1,13 @@ package proxy import ( + "bufio" "crypto/tls" "net" "net/http" + "net/http/httputil" "os" + "strings" "time" mock_conn "github.com/jordwest/mock-conn" @@ -37,12 +40,14 @@ type ioRes struct { type conn struct { *mock_conn.End + Host string // remote host readErrChan chan error // Read 方法提前返回时的错误 } -func newConn(end *mock_conn.End) *conn { +func newConn(end *mock_conn.End, host string) *conn { return &conn{ End: end, + Host: host, readErrChan: make(chan error), } } @@ -84,7 +89,9 @@ func (c *conn) Read(data []byte) (int, error) { } func (c *conn) SetDeadline(t time.Time) error { - log.Warnf("SetDeadline %v\n", t) + if !t.Equal(time.Time{}) { + log.WithField("host", c.Host).Warnf("SetDeadline %v\n", t) + } return nil } @@ -103,10 +110,28 @@ func (c *conn) SetReadDeadline(t time.Time) error { } func (c *conn) SetWriteDeadline(t time.Time) error { - log.Warnf("SetWriteDeadline %v\n", t) + log.WithField("host", c.Host).Warnf("SetWriteDeadline %v\n", t) return nil } +// wrap conn for peek +type connBuf struct { + *conn + r *bufio.Reader +} + +func newConnBuf(c *conn) *connBuf { + return &connBuf{c, bufio.NewReader(c)} +} + +func (b *connBuf) Peek(n int) ([]byte, error) { + return b.r.Peek(n) +} + +func (b *connBuf) Read(data []byte) (int, error) { + return b.r.Read(data) +} + type MitmMemory struct { Proxy *Proxy CA *cert.CA @@ -153,12 +178,85 @@ func (m *MitmMemory) Start() error { } func (m *MitmMemory) Dial(host string) (net.Conn, error) { + log := log.WithField("in", "MitmMemory.Dial").WithField("host", host) pipes := mock_conn.NewConn() - m.Listener.(*listener).connChan <- newConn(pipes.Server) - return newConn(pipes.Client), nil + + // 如果是 tls 流量,则进入 listener.Accept => MitmMemory.ServeHTTP + // 否则很可能是 ws 流量,直接转发流量 + go func() { + conn := newConn(pipes.Server, host) + connb := newConnBuf(conn) + buf, err := connb.Peek(3) + if err != nil { + log.Errorf("Peek error: %v\n", err) + connb.Close() + return + } + + // tls + if buf[0] == 0x16 && buf[1] == 0x03 && (buf[2] >= 0x0 || buf[2] <= 0x03) { + m.Listener.(*listener).connChan <- connb + } else { + // websocket ws:// + log.Debug("begin websocket ws://") + defer connb.Close() + remoteConn, err := net.Dial("tcp", host) + if err != nil { + if !ignoreErr(log, err) { + log.Error(err) + } + return + } + defer remoteConn.Close() + transfer(log, connb, remoteConn) + } + }() + + return newConn(pipes.Client, host), nil } func (m *MitmMemory) ServeHTTP(res http.ResponseWriter, req *http.Request) { + log := log.WithField("in", "MitmMemory.ServeHTTP").WithField("host", req.Host) + + // websocket wss:// + if strings.EqualFold(req.Header.Get("Connection"), "Upgrade") && strings.EqualFold(req.Header.Get("Upgrade"), "websocket") { + log.Debug("begin websocket wss://") + + upgradeBuf, err := httputil.DumpRequest(req, false) + if err != nil { + log.Errorf("DumpRequest: %v\n", err) + res.WriteHeader(502) + return + } + + cconn, _, err := res.(http.Hijacker).Hijack() + if err != nil { + log.Errorf("Hijack: %v\n", err) + res.WriteHeader(502) + return + } + defer cconn.Close() + + host := req.Host + if !strings.Contains(host, ":") { + host = host + ":443" + } + conn, err := tls.Dial("tcp", host, nil) + if err != nil { + log.Errorf("tls.Dial: %v\n", err) + return + } + defer conn.Close() + + _, err = conn.Write(upgradeBuf) + if err != nil { + log.Errorf("wss upgrade: %v\n", err) + return + } + transfer(log, conn, cconn) + return + } + if req.URL.Scheme == "" { req.URL.Scheme = "https" } diff --git a/proxy/proxy.go b/proxy/proxy.go index 7b9c2c2..f3a3fd7 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -27,7 +27,7 @@ var ignoreErr = func(log *_log.Entry, err error) bool { for _, str := range strs { if strings.Contains(errs, str) { - log.Debug(str) + log.Debug(err) return true } } @@ -35,6 +35,23 @@ var ignoreErr = func(log *_log.Entry, err error) bool { return false } +func transfer(log *_log.Entry, a, b io.ReadWriter) { + done := make(chan struct{}) + go func() { + _, err := io.Copy(a, b) + if err != nil && !ignoreErr(log, err) { + log.Error(err) + } + close(done) + }() + + _, err := io.Copy(b, a) + if err != nil && !ignoreErr(log, err) { + log.Error(err) + } + <-done +} + type Options struct { Addr string } @@ -70,11 +87,13 @@ func (proxy *Proxy) ServeHTTP(res http.ResponseWriter, req *http.Request) { } log := log.WithFields(_log.Fields{ - "in": "ServeHTTP", + "in": "Proxy.ServeHTTP", "url": req.URL, "method": req.Method, }) + log.Debug("receive request") + if !req.URL.IsAbs() || req.URL.Host == "" { res.WriteHeader(400) _, err := io.WriteString(res, "此为代理服务器,不能直接发起请求") @@ -125,11 +144,11 @@ func (proxy *Proxy) ServeHTTP(res http.ResponseWriter, req *http.Request) { func (proxy *Proxy) handleConnect(res http.ResponseWriter, req *http.Request) { log := log.WithFields(_log.Fields{ - "in": "handleConnect", + "in": "Proxy.handleConnect", "host": req.Host, }) - log.Debug("CONNECT") + log.Debug("receive connect") conn, err := proxy.Mitm.Dial(req.Host) @@ -154,21 +173,7 @@ func (proxy *Proxy) handleConnect(res http.ResponseWriter, req *http.Request) { return } - done := make(chan struct{}) - go func() { - _, err := io.Copy(conn, cconn) - if err != nil && !ignoreErr(log, err) { - log.Error(err) - } - close(done) - }() - - _, err = io.Copy(cconn, conn) - if err != nil && !ignoreErr(log, err) { - log.Error(err) - } - - <-done + transfer(log, conn, cconn) } func NewProxy(opts *Options) (*Proxy, error) {