From 730bd208bb837315c642b5d0f2d7d400f6f0f843 Mon Sep 17 00:00:00 2001 From: lqqyt2423 <974923609@qq.com> Date: Wed, 16 Dec 2020 15:18:25 +0800 Subject: [PATCH] =?UTF-8?q?=E9=87=8D=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- addon/addon.go | 54 +++++++++ flow/flow.go | 45 -------- proxy/helper.go | 89 ++++++++++++++ proxy/interceptor.go | 24 ++++ proxy/middle.go | 213 ++++++++++++++++++++++++++++++++++ proxy/mitm.go | 93 --------------- proxy/mitm_memory.go | 269 ------------------------------------------- proxy/proxy.go | 264 +++++++++++++++--------------------------- proxy/websocket.go | 65 +++++++++++ 9 files changed, 535 insertions(+), 581 deletions(-) create mode 100644 addon/addon.go create mode 100644 proxy/interceptor.go create mode 100644 proxy/middle.go delete mode 100644 proxy/mitm.go delete mode 100644 proxy/mitm_memory.go create mode 100644 proxy/websocket.go diff --git a/addon/addon.go b/addon/addon.go new file mode 100644 index 0000000..d1dd587 --- /dev/null +++ b/addon/addon.go @@ -0,0 +1,54 @@ +package addon + +import ( + "time" + + "github.com/lqqyt2423/go-mitmproxy/flow" + _log "github.com/sirupsen/logrus" +) + +var log = _log.WithField("at", "addon") + +type Addon interface { + // HTTP request headers were successfully read. At this point, the body is empty. + Requestheaders(*flow.Flow) + + // The full HTTP request has been read. + Request(*flow.Flow) + + // HTTP response headers were successfully read. At this point, the body is empty. + Responseheaders(*flow.Flow) + + // The full HTTP response has been read. + Response(*flow.Flow) +} + +// Base do nothing +type Base struct{} + +func (addon *Base) Requestheaders(*flow.Flow) {} +func (addon *Base) Request(*flow.Flow) {} +func (addon *Base) Responseheaders(*flow.Flow) {} +func (addon *Base) Response(*flow.Flow) {} + +// Log log http record +type Log struct { + Base +} + +func (addon *Log) Requestheaders(f *flow.Flow) { + log := log.WithField("in", "Log") + start := time.Now() + go func() { + <-f.Done() + var StatusCode int + if f.Response != nil { + StatusCode = f.Response.StatusCode + } + var contentLen int + if f.Response != nil && f.Response.Body != nil { + contentLen = len(f.Response.Body) + } + log.Infof("%v %v %v %v - %v ms\n", f.Request.Method, f.Request.URL.String(), StatusCode, contentLen, time.Since(start).Milliseconds()) + }() +} diff --git a/flow/flow.go b/flow/flow.go index e5a4fd6..2a099a1 100644 --- a/flow/flow.go +++ b/flow/flow.go @@ -3,7 +3,6 @@ package flow import ( "net/http" "net/url" - "time" _log "github.com/sirupsen/logrus" ) @@ -45,47 +44,3 @@ func (f *Flow) Done() <-chan struct{} { func (f *Flow) Finish() { close(f.done) } - -type Addon interface { - // HTTP request headers were successfully read. At this point, the body is empty. - Requestheaders(*Flow) - - // The full HTTP request has been read. - Request(*Flow) - - // HTTP response headers were successfully read. At this point, the body is empty. - Responseheaders(*Flow) - - // The full HTTP response has been read. - Response(*Flow) -} - -// BaseAddon do nothing -type BaseAddon struct{} - -func (addon *BaseAddon) Requestheaders(*Flow) {} -func (addon *BaseAddon) Request(*Flow) {} -func (addon *BaseAddon) Responseheaders(*Flow) {} -func (addon *BaseAddon) Response(*Flow) {} - -// LogAddon log http record -type LogAddon struct { - BaseAddon -} - -func (addon *LogAddon) Requestheaders(flo *Flow) { - log := log.WithField("in", "LogAddon") - start := time.Now() - go func() { - <-flo.Done() - var StatusCode int - if flo.Response != nil { - StatusCode = flo.Response.StatusCode - } - var contentLen int - if flo.Response != nil && flo.Response.Body != nil { - contentLen = len(flo.Response.Body) - } - log.Infof("%v %v %v %v - %v ms\n", flo.Request.Method, flo.Request.URL.String(), StatusCode, contentLen, time.Since(start).Milliseconds()) - }() -} diff --git a/proxy/helper.go b/proxy/helper.go index 33808cf..495d7ed 100644 --- a/proxy/helper.go +++ b/proxy/helper.go @@ -3,9 +3,76 @@ package proxy import ( "bytes" "io" + "os" + "strings" + "sync" + + _log "github.com/sirupsen/logrus" ) +var NormalErrMsgs []string = []string{ + "read: connection reset by peer", + "write: broken pipe", + "i/o timeout", + "net/http: TLS handshake timeout", + "io: read/write on closed pipe", + "connect: connection refused", + "connect: connection reset by peer", +} + +// 仅打印预料之外的错误信息 +func LogErr(log *_log.Entry, err error) (loged bool) { + msg := err.Error() + + for _, str := range NormalErrMsgs { + if strings.Contains(msg, str) { + log.Debug(err) + return + } + } + + log.Error(err) + loged = true + return +} + +// 转发流量 +// Read a => Write b +// Read b => Write a +func Transfer(log *_log.Entry, a, b io.ReadWriter) { + done := make(chan struct{}) + defer close(done) + + forward := func(dst io.Writer, src io.Reader, ec chan<- error) { + _, err := io.Copy(dst, src) + + if v, ok := dst.(*conn); ok { + // 避免内存泄漏 + _ = v.Writer.CloseWithError(nil) + } + + select { + case <-done: + return + case ec <- err: + } + } + + errChan := make(chan error) + go forward(a, b, errChan) + go forward(b, a, errChan) + + for i := 0; i < 2; i++ { + if err := <-errChan; err != nil { + LogErr(log, err) + return // 如果有错误,直接返回 + } + } +} + // 尝试将 Reader 读取至 buffer 中 +// 如果未达到 limit,则成功读取进入 buffer +// 否则 buffer 返回 nil,且返回新 Reader,状态为未读取前 func ReaderToBuffer(r io.Reader, limit int64) ([]byte, io.Reader, error) { buf := bytes.NewBuffer(make([]byte, 0)) lr := io.LimitReader(r, limit) @@ -24,3 +91,25 @@ func ReaderToBuffer(r io.Reader, limit int64) ([]byte, io.Reader, error) { // 返回 buffer return buf.Bytes(), nil, nil } + +// Wireshark 解析 https 设置 +var tlsKeyLogWriter io.Writer +var tlsKeyLogOnce sync.Once + +func GetTlsKeyLogWriter() io.Writer { + tlsKeyLogOnce.Do(func() { + logfile := os.Getenv("SSLKEYLOGFILE") + if logfile == "" { + return + } + + writer, err := os.OpenFile(logfile, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0666) + if err != nil { + log.WithField("in", "GetTlsKeyLogWriter").Debug(err) + return + } + + tlsKeyLogWriter = writer + }) + return tlsKeyLogWriter +} diff --git a/proxy/interceptor.go b/proxy/interceptor.go new file mode 100644 index 0000000..bb546e1 --- /dev/null +++ b/proxy/interceptor.go @@ -0,0 +1,24 @@ +package proxy + +import ( + "net" +) + +// 拦截 https 流量通用接口 +type Interceptor interface { + // 初始化 + Start() error + // 针对每个 host 连接 + Dial(host string) (net.Conn, error) +} + +// 直接转发 https 流量 +type Forward struct{} + +func (i *Forward) Start() error { + return nil +} + +func (i *Forward) Dial(host string) (net.Conn, error) { + return net.Dial("tcp", host) +} diff --git a/proxy/middle.go b/proxy/middle.go new file mode 100644 index 0000000..48b0ca9 --- /dev/null +++ b/proxy/middle.go @@ -0,0 +1,213 @@ +package proxy + +import ( + "bufio" + "crypto/tls" + "net" + "net/http" + "os" + "strings" + "time" + + mock_conn "github.com/jordwest/mock-conn" + "github.com/lqqyt2423/go-mitmproxy/cert" +) + +// 模拟了标准库中 server 运行,目的是仅通过当前进程内存转发 socket 数据,不需要经过 tcp 或 unix socket + +// mock net.Listener +type listener struct { + connChan chan net.Conn +} + +func (l *listener) Accept() (net.Conn, error) { return <-l.connChan, nil } +func (l *listener) Close() error { return nil } +func (l *listener) Addr() net.Addr { return nil } + +type ioRes struct { + n int + err error +} + +// mock net.Conn +type conn struct { + mock_conn.End + host string // remote host + readErrChan chan error // Read 方法提前返回时的错误 +} + +// 建立客户端和服务端通信的通道 +func newPipes(host string) (client *conn, server *connBuf) { + pipes := mock_conn.NewConn() + client = &conn{*pipes.Client, host, nil} + serverConn := &conn{*pipes.Server, host, make(chan error)} + server = newConnBuf(serverConn) + return client, server +} + +// 当接收到 readErrChan 时,可提前返回 +func (c *conn) Read(data []byte) (int, error) { + select { + case err := <-c.readErrChan: + return 0, err + default: + } + + resChan := make(chan *ioRes) + done := make(chan struct{}) + defer close(done) + + go func() { + select { + case <-done: + return + default: + } + + n, err := c.End.Read(data) + select { + case resChan <- &ioRes{n, err}: + return + case <-done: + close(resChan) + } + }() + + select { + case res := <-resChan: + return res.n, res.err + case err := <-c.readErrChan: + return 0, err + } +} + +func (c *conn) SetDeadline(t time.Time) error { + if !t.Equal(time.Time{}) { + log.WithField("host", c.host).Warnf("SetDeadline %v\n", t) + } + return nil +} + +// http server 会在连接快结束时调用此方法 +func (c *conn) SetReadDeadline(t time.Time) error { + if !t.Equal(time.Time{}) { + if !t.After(time.Now()) { + // 使当前 Read 尽快返回 + c.readErrChan <- os.ErrDeadlineExceeded + } else { + log.WithField("host", c.host).Warnf("SetReadDeadline %v\n", t) + } + } + + return nil +} + +func (c *conn) SetWriteDeadline(t time.Time) error { + log.WithField("host", c.host).Warnf("SetWriteDeadline %v\n", t) + return nil +} + +// add Peek method for conn +type connBuf struct { + *conn + r *bufio.Reader +} + +func newConnBuf(c *conn) *connBuf { + return &connBuf{c, bufio.NewReader(c)} +} + +func (b *connBuf) Peek(n int) ([]byte, error) { + return b.r.Peek(n) +} + +func (b *connBuf) Read(data []byte) (int, error) { + return b.r.Read(data) +} + +// Middle: man-in-the-middle +type Middle struct { + Proxy *Proxy + CA *cert.CA + Listener net.Listener + Server *http.Server +} + +func NewMiddle(proxy *Proxy) (Interceptor, error) { + ca, err := cert.NewCA("") + if err != nil { + return nil, err + } + + m := &Middle{ + Proxy: proxy, + CA: ca, + } + + server := &http.Server{ + Handler: m, + TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2 + TLSConfig: &tls.Config{ + GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) { + log.Debugf("Middle GetCertificate ServerName: %v\n", chi.ServerName) + return ca.GetCert(chi.ServerName) + }, + }, + } + + // 每次连接尽快结束,因为连接并无开销 + server.SetKeepAlivesEnabled(false) + + m.Server = server + + return m, nil +} + +func (m *Middle) Start() error { + m.Listener = &listener{make(chan net.Conn)} + return m.Server.ServeTLS(m.Listener, "", "") +} + +func (m *Middle) Dial(host string) (net.Conn, error) { + clientConn, serverConn := newPipes(host) + go m.intercept(serverConn) + return clientConn, nil +} + +func (m *Middle) ServeHTTP(res http.ResponseWriter, req *http.Request) { + if strings.EqualFold(req.Header.Get("Connection"), "Upgrade") && strings.EqualFold(req.Header.Get("Upgrade"), "websocket") { + // wss + DefaultWebSocket.WSS(res, req) + return + } + + if req.URL.Scheme == "" { + req.URL.Scheme = "https" + } + if req.URL.Host == "" { + req.URL.Host = req.Host + } + m.Proxy.ServeHTTP(res, req) +} + +// 解析 connect 流量 +// 如果是 tls 流量,则进入 listener.Accept => Middle.ServeHTTP +// 否则很可能是 ws 流量 +func (m *Middle) intercept(serverConn *connBuf) { + log := log.WithField("in", "Middle.intercept").WithField("host", serverConn.host) + + buf, err := serverConn.Peek(3) + if err != nil { + log.Errorf("Peek error: %v\n", err) + serverConn.Close() + return + } + + if buf[0] == 0x16 && buf[1] == 0x03 && (buf[2] >= 0x0 || buf[2] <= 0x03) { + // tls + m.Listener.(*listener).connChan <- serverConn + } else { + // ws + DefaultWebSocket.WS(serverConn, serverConn.host) + } +} diff --git a/proxy/mitm.go b/proxy/mitm.go deleted file mode 100644 index e111e35..0000000 --- a/proxy/mitm.go +++ /dev/null @@ -1,93 +0,0 @@ -package proxy - -import ( - "crypto/tls" - "net" - "net/http" - - "github.com/lqqyt2423/go-mitmproxy/cert" -) - -type Mitm interface { - Start() error - Dial(host string) (net.Conn, error) -} - -// 直接转发 https 流量 -type MitmForward struct{} - -func (m *MitmForward) Start() error { - return nil -} - -func (m *MitmForward) Dial(host string) (net.Conn, error) { - return net.Dial("tcp", host) -} - -// 内部解析 https 流量 -// 每个连接都会消耗掉两个文件描述符,可能会达到打开文件上限 -type MitmServer struct { - Proxy *Proxy - CA *cert.CA - Listener net.Listener - Server *http.Server -} - -func NewMitmServer(proxy *Proxy) (Mitm, error) { - ca, err := cert.NewCA("") - if err != nil { - return nil, err - } - - m := &MitmServer{ - Proxy: proxy, - CA: ca, - } - - server := &http.Server{ - Handler: m, - TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2 - TLSConfig: &tls.Config{ - GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) { - log.Debugf("MitmServer GetCertificate ServerName: %v\n", chi.ServerName) - return ca.GetCert(chi.ServerName) - }, - }, - } - - // 尽快关闭内部的连接,释放文件描述符 - server.SetKeepAlivesEnabled(false) - - m.Server = server - - return m, nil -} - -func (m *MitmServer) Start() error { - ln, err := net.Listen("tcp", "127.0.0.1:") // port number is automatically chosen - if err != nil { - return err - } - m.Listener = ln - m.Server.Addr = ln.Addr().String() - log.Infof("MitmServer Server Addr is %v\n", m.Server.Addr) - defer ln.Close() - - return m.Server.ServeTLS(ln, "", "") -} - -func (m *MitmServer) Dial(host string) (net.Conn, error) { - return net.Dial("tcp", m.Server.Addr) -} - -func (m *MitmServer) ServeHTTP(res http.ResponseWriter, req *http.Request) { - if req.URL.Scheme == "" { - req.URL.Scheme = "https" - } - - if req.URL.Host == "" { - req.URL.Host = req.Host - } - - m.Proxy.ServeHTTP(res, req) -} diff --git a/proxy/mitm_memory.go b/proxy/mitm_memory.go deleted file mode 100644 index 8a81c39..0000000 --- a/proxy/mitm_memory.go +++ /dev/null @@ -1,269 +0,0 @@ -package proxy - -import ( - "bufio" - "crypto/tls" - "net" - "net/http" - "net/http/httputil" - "os" - "strings" - "time" - - mock_conn "github.com/jordwest/mock-conn" - "github.com/lqqyt2423/go-mitmproxy/cert" -) - -// 模拟实现 net - -type listener struct { - connChan chan net.Conn -} - -func (l *listener) Accept() (net.Conn, error) { - return <-l.connChan, nil -} - -func (l *listener) Close() error { - return nil -} - -func (l *listener) Addr() net.Addr { - return nil -} - -type ioRes struct { - n int - err error -} - -type conn struct { - *mock_conn.End - - Host string // remote host - readErrChan chan error // Read 方法提前返回时的错误 -} - -func newConn(end *mock_conn.End, host string) *conn { - return &conn{ - End: end, - Host: host, - readErrChan: make(chan error), - } -} - -// 当接收到 readErrChan 时,可提前返回 -func (c *conn) Read(data []byte) (int, error) { - select { - case err := <-c.readErrChan: - return 0, err - default: - } - - resChan := make(chan *ioRes) - done := make(chan struct{}) - defer close(done) - - go func() { - select { - case <-done: - return - default: - } - - n, err := c.End.Read(data) - select { - case resChan <- &ioRes{n, err}: - return - case <-done: - close(resChan) - } - }() - - select { - case res := <-resChan: - return res.n, res.err - case err := <-c.readErrChan: - return 0, err - } -} - -func (c *conn) SetDeadline(t time.Time) error { - if !t.Equal(time.Time{}) { - log.WithField("host", c.Host).Warnf("SetDeadline %v\n", t) - } - return nil -} - -// http server 会在连接快结束时调用此方法 -func (c *conn) SetReadDeadline(t time.Time) error { - if !t.Equal(time.Time{}) { - if !t.After(time.Now()) { - // 使当前 Read 尽快返回 - c.readErrChan <- os.ErrDeadlineExceeded - } else { - log.Warnf("SetReadDeadline %v\n", t) - } - } - - return nil -} - -func (c *conn) SetWriteDeadline(t time.Time) error { - log.WithField("host", c.Host).Warnf("SetWriteDeadline %v\n", t) - return nil -} - -// wrap conn for peek -type connBuf struct { - *conn - r *bufio.Reader -} - -func newConnBuf(c *conn) *connBuf { - return &connBuf{c, bufio.NewReader(c)} -} - -func (b *connBuf) Peek(n int) ([]byte, error) { - return b.r.Peek(n) -} - -func (b *connBuf) Read(data []byte) (int, error) { - return b.r.Read(data) -} - -type MitmMemory struct { - Proxy *Proxy - CA *cert.CA - Listener net.Listener - Server *http.Server -} - -func NewMitmMemory(proxy *Proxy) (Mitm, error) { - ca, err := cert.NewCA("") - if err != nil { - return nil, err - } - - m := &MitmMemory{ - Proxy: proxy, - CA: ca, - } - - server := &http.Server{ - Handler: m, - TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2 - TLSConfig: &tls.Config{ - GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) { - log.Debugf("MitmMemory GetCertificate ServerName: %v\n", chi.ServerName) - return ca.GetCert(chi.ServerName) - }, - }, - } - - // 每次连接尽快结束,因为连接并无开销 - server.SetKeepAlivesEnabled(false) - - m.Server = server - - return m, nil -} - -func (m *MitmMemory) Start() error { - ln := &listener{ - connChan: make(chan net.Conn), - } - m.Listener = ln - return m.Server.ServeTLS(ln, "", "") -} - -func (m *MitmMemory) Dial(host string) (net.Conn, error) { - log := log.WithField("in", "MitmMemory.Dial").WithField("host", host) - pipes := mock_conn.NewConn() - - // 如果是 tls 流量,则进入 listener.Accept => MitmMemory.ServeHTTP - // 否则很可能是 ws 流量,直接转发流量 - go func() { - conn := newConn(pipes.Server, host) - connb := newConnBuf(conn) - buf, err := connb.Peek(3) - if err != nil { - log.Errorf("Peek error: %v\n", err) - connb.Close() - return - } - - // tls - if buf[0] == 0x16 && buf[1] == 0x03 && (buf[2] >= 0x0 || buf[2] <= 0x03) { - m.Listener.(*listener).connChan <- connb - } else { - // websocket ws:// - log.Debug("begin websocket ws://") - defer connb.Close() - remoteConn, err := net.Dial("tcp", host) - if err != nil { - if !ignoreErr(log, err) { - log.Error(err) - } - return - } - defer remoteConn.Close() - transfer(log, connb, remoteConn) - } - }() - - return newConn(pipes.Client, host), nil -} - -func (m *MitmMemory) ServeHTTP(res http.ResponseWriter, req *http.Request) { - log := log.WithField("in", "MitmMemory.ServeHTTP").WithField("host", req.Host) - - // websocket wss:// - if strings.EqualFold(req.Header.Get("Connection"), "Upgrade") && strings.EqualFold(req.Header.Get("Upgrade"), "websocket") { - log.Debug("begin websocket wss://") - - upgradeBuf, err := httputil.DumpRequest(req, false) - if err != nil { - log.Errorf("DumpRequest: %v\n", err) - res.WriteHeader(502) - return - } - - cconn, _, err := res.(http.Hijacker).Hijack() - if err != nil { - log.Errorf("Hijack: %v\n", err) - res.WriteHeader(502) - return - } - defer cconn.Close() - - host := req.Host - if !strings.Contains(host, ":") { - host = host + ":443" - } - conn, err := tls.Dial("tcp", host, nil) - if err != nil { - log.Errorf("tls.Dial: %v\n", err) - return - } - defer conn.Close() - - _, err = conn.Write(upgradeBuf) - if err != nil { - log.Errorf("wss upgrade: %v\n", err) - return - } - transfer(log, conn, cconn) - return - } - - if req.URL.Scheme == "" { - req.URL.Scheme = "https" - } - - if req.URL.Host == "" { - req.URL.Host = req.Host - } - - m.Proxy.ServeHTTP(res, req) -} diff --git a/proxy/proxy.go b/proxy/proxy.go index 2700467..7dc43b9 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -6,86 +6,80 @@ import ( "io" "net" "net/http" - "os" - "strings" - "sync" "time" + "github.com/lqqyt2423/go-mitmproxy/addon" "github.com/lqqyt2423/go-mitmproxy/flow" _log "github.com/sirupsen/logrus" ) var log = _log.WithField("at", "proxy") -var ignoreErr = func(log *_log.Entry, err error) bool { - errs := err.Error() - strs := []string{ - "read: connection reset by peer", - "write: broken pipe", - "i/o timeout", - "net/http: TLS handshake timeout", - "io: read/write on closed pipe", - "connect: connection refused", - "connect: connection reset by peer", - } - - for _, str := range strs { - if strings.Contains(errs, str) { - log.Debug(err) - return true - } - } +type Options struct { + Addr string + StreamLargeBodies int64 +} - return false +type Proxy struct { + Server *http.Server + Client *http.Client + Interceptor Interceptor + StreamLargeBodies int64 // 当请求或响应体大于此字节时,转为 stream 模式 + Addons []addon.Addon } -func transfer(log *_log.Entry, a, b io.ReadWriter) { - done := make(chan struct{}) - defer close(done) +func NewProxy(opts *Options) (*Proxy, error) { + proxy := new(Proxy) - forward := func(dst io.Writer, src io.Reader, ec chan<- error) { - _, err := io.Copy(dst, src) + proxy.Server = &http.Server{ + Addr: opts.Addr, + Handler: proxy, + } - if v, ok := dst.(*conn); ok { - // 避免内存泄漏的关键 - _ = v.Writer.CloseWithError(nil) - } + proxy.Client = &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + DualStack: true, + }).DialContext, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, - select { - case <-done: - return - case ec <- err: - } + ForceAttemptHTTP2: false, // disable http2 + DisableCompression: true, + TLSClientConfig: &tls.Config{ + KeyLogWriter: GetTlsKeyLogWriter(), + }, + }, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + // 禁止自动重定向 + return http.ErrUseLastResponse + }, } - errChan := make(chan error) - go forward(a, b, errChan) - go forward(b, a, errChan) + interceptor, err := NewMiddle(proxy) + if err != nil { + return nil, err + } + proxy.Interceptor = interceptor - for i := 0; i < 2; i++ { - if err := <-errChan; err != nil { - if !ignoreErr(log, err) { - log.Error(err) - } - return // 如果有错误,直接返回 - } + if opts.StreamLargeBodies > 0 { + proxy.StreamLargeBodies = opts.StreamLargeBodies + } else { + proxy.StreamLargeBodies = 1024 * 1024 * 5 // default: 5mb } -} -type Options struct { - Addr string - StreamLargeBodies int64 -} + proxy.Addons = make([]addon.Addon, 0) + proxy.AddAddon(&addon.Log{}) -type Proxy struct { - Server *http.Server - Client *http.Client - Mitm Mitm - StreamLargeBodies int64 // 当请求或响应体大于此字节时,转为 stream 模式 - Addons []flow.Addon + return proxy, nil } -func (proxy *Proxy) AddAddon(addon flow.Addon) { +func (proxy *Proxy) AddAddon(addon addon.Addon) { proxy.Addons = append(proxy.Addons, addon) } @@ -99,7 +93,7 @@ func (proxy *Proxy) Start() error { }() go func() { - err := proxy.Mitm.Start() + err := proxy.Interceptor.Start() errChan <- err }() @@ -130,7 +124,7 @@ func (proxy *Proxy) ServeHTTP(res http.ResponseWriter, req *http.Request) { return } - endRes := func(response *flow.Response, body io.Reader) { + reply := func(response *flow.Response, body io.Reader) { if response.Header != nil { for key, value := range response.Header { for _, v := range value { @@ -142,13 +136,13 @@ func (proxy *Proxy) ServeHTTP(res http.ResponseWriter, req *http.Request) { if body != nil { _, err := io.Copy(res, body) - if err != nil && !ignoreErr(log, err) { - log.Error(err) + if err != nil { + LogErr(log, err) } } else if response.Body != nil && len(response.Body) > 0 { _, err := res.Write(response.Body) - if err != nil && !ignoreErr(log, err) { - log.Error(err) + if err != nil { + LogErr(log, err) } } } @@ -160,27 +154,27 @@ func (proxy *Proxy) ServeHTTP(res http.ResponseWriter, req *http.Request) { } }() - flo := flow.NewFlow() - flo.Request = &flow.Request{ + f := flow.NewFlow() + f.Request = &flow.Request{ Method: req.Method, URL: req.URL, Proto: req.Proto, Header: req.Header, } - defer flo.Finish() + defer f.Finish() // trigger addon event Requestheaders for _, addon := range proxy.Addons { - addon.Requestheaders(flo) - if flo.Response != nil { - endRes(flo.Response, nil) + addon.Requestheaders(f) + if f.Response != nil { + reply(f.Response, nil) return } } - // 读 request body + // Read request body var reqBody io.Reader = req.Body - if !flo.Stream { + if !f.Stream { reqBuf, r, err := ReaderToBuffer(req.Body, proxy.StreamLargeBodies) reqBody = r if err != nil { @@ -188,65 +182,62 @@ func (proxy *Proxy) ServeHTTP(res http.ResponseWriter, req *http.Request) { res.WriteHeader(502) return } + if reqBuf == nil { log.Warnf("request body size >= %v\n", proxy.StreamLargeBodies) - flo.Stream = true + f.Stream = true } else { - flo.Request.Body = reqBuf - } + f.Request.Body = reqBuf - // trigger addon event Request - if !flo.Stream { + // trigger addon event Request for _, addon := range proxy.Addons { - addon.Request(flo) - if flo.Response != nil { - endRes(flo.Response, nil) + addon.Request(f) + if f.Response != nil { + reply(f.Response, nil) return } } - reqBody = bytes.NewReader(flo.Request.Body) + reqBody = bytes.NewReader(f.Request.Body) } } - proxyReq, err := http.NewRequest(flo.Request.Method, flo.Request.URL.String(), reqBody) + proxyReq, err := http.NewRequest(f.Request.Method, f.Request.URL.String(), reqBody) if err != nil { log.Error(err) res.WriteHeader(502) return } - for key, value := range flo.Request.Header { + for key, value := range f.Request.Header { for _, v := range value { proxyReq.Header.Add(key, v) } } proxyRes, err := proxy.Client.Do(proxyReq) if err != nil { - if !ignoreErr(log, err) { - log.Error(err) - } + LogErr(log, err) res.WriteHeader(502) return } defer proxyRes.Body.Close() - flo.Response = &flow.Response{ + f.Response = &flow.Response{ StatusCode: proxyRes.StatusCode, Header: proxyRes.Header, } // trigger addon event Responseheaders for _, addon := range proxy.Addons { - addon.Responseheaders(flo) - if flo.Response.Body != nil { - endRes(flo.Response, nil) + addon.Responseheaders(f) + if f.Response.Body != nil { + reply(f.Response, nil) return } } - // 读 response body + // Read response body var resBody io.Reader = proxyRes.Body - if !flo.Stream { + if !f.Stream { resBuf, r, err := ReaderToBuffer(proxyRes.Body, proxy.StreamLargeBodies) resBody = r if err != nil { @@ -256,20 +247,18 @@ func (proxy *Proxy) ServeHTTP(res http.ResponseWriter, req *http.Request) { } if resBuf == nil { log.Warnf("response body size >= %v\n", proxy.StreamLargeBodies) - flo.Stream = true + f.Stream = true } else { - flo.Response.Body = resBuf - } + f.Response.Body = resBuf - // trigger addon event Response - if !flo.Stream { + // trigger addon event Response for _, addon := range proxy.Addons { - addon.Response(flo) + addon.Response(f) } } } - endRes(flo.Response, resBody) + reply(f.Response, resBody) } func (proxy *Proxy) handleConnect(res http.ResponseWriter, req *http.Request) { @@ -280,8 +269,7 @@ func (proxy *Proxy) handleConnect(res http.ResponseWriter, req *http.Request) { log.Debug("receive connect") - conn, err := proxy.Mitm.Dial(req.Host) - + conn, err := proxy.Interceptor.Dial(req.Host) if err != nil { log.Error(err) res.WriteHeader(502) @@ -303,77 +291,5 @@ func (proxy *Proxy) handleConnect(res http.ResponseWriter, req *http.Request) { return } - transfer(log, conn, cconn) -} - -func NewProxy(opts *Options) (*Proxy, error) { - proxy := new(Proxy) - proxy.Server = &http.Server{ - Addr: opts.Addr, - Handler: proxy, - } - - proxy.Client = &http.Client{ - Transport: &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - DualStack: true, - }).DialContext, - MaxIdleConns: 100, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - - ForceAttemptHTTP2: false, // disable http2 - DisableCompression: true, - TLSClientConfig: &tls.Config{ - KeyLogWriter: GetTlsKeyLogWriter(), - }, - }, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - // 禁止自动重定向 - return http.ErrUseLastResponse - }, - } - - mitm, err := NewMitmMemory(proxy) - if err != nil { - return nil, err - } - - proxy.Mitm = mitm - - if opts.StreamLargeBodies > 0 { - proxy.StreamLargeBodies = opts.StreamLargeBodies - } else { - proxy.StreamLargeBodies = 1024 * 1024 * 5 // default: 5mb - } - proxy.Addons = make([]flow.Addon, 0) - proxy.AddAddon(&flow.LogAddon{}) - - return proxy, nil -} - -var tlsKeyLogWriter io.Writer -var tlsKeyLogOnce sync.Once - -// Wireshark 解析 https 设置 -func GetTlsKeyLogWriter() io.Writer { - tlsKeyLogOnce.Do(func() { - logfile := os.Getenv("SSLKEYLOGFILE") - if logfile == "" { - return - } - - writer, err := os.OpenFile(logfile, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0666) - if err != nil { - log.WithField("in", "GetTlsKeyLogWriter").Debug(err) - return - } - - tlsKeyLogWriter = writer - }) - return tlsKeyLogWriter + Transfer(log, conn, cconn) } diff --git a/proxy/websocket.go b/proxy/websocket.go new file mode 100644 index 0000000..a03d3b3 --- /dev/null +++ b/proxy/websocket.go @@ -0,0 +1,65 @@ +package proxy + +import ( + "crypto/tls" + "net" + "net/http" + "net/http/httputil" + "strings" +) + +// 当前仅做了转发 websocket 流量 + +type WebSocket struct{} + +var DefaultWebSocket WebSocket + +func (s *WebSocket) WS(conn net.Conn, host string) { + log := log.WithField("in", "WebSocket.WS").WithField("host", host) + + defer conn.Close() + remoteConn, err := net.Dial("tcp", host) + if err != nil { + LogErr(log, err) + return + } + defer remoteConn.Close() + Transfer(log, conn, remoteConn) +} + +func (s *WebSocket) WSS(res http.ResponseWriter, req *http.Request) { + log := log.WithField("in", "WebSocket.WSS").WithField("host", req.Host) + + upgradeBuf, err := httputil.DumpRequest(req, false) + if err != nil { + log.Errorf("DumpRequest: %v\n", err) + res.WriteHeader(502) + return + } + + cconn, _, err := res.(http.Hijacker).Hijack() + if err != nil { + log.Errorf("Hijack: %v\n", err) + res.WriteHeader(502) + return + } + defer cconn.Close() + + host := req.Host + if !strings.Contains(host, ":") { + host = host + ":443" + } + conn, err := tls.Dial("tcp", host, nil) + if err != nil { + log.Errorf("tls.Dial: %v\n", err) + return + } + defer conn.Close() + + _, err = conn.Write(upgradeBuf) + if err != nil { + log.Errorf("wss upgrade: %v\n", err) + return + } + Transfer(log, conn, cconn) +}