diff --git a/connection/connection.go b/connection/connection.go index ac6eff3..ec29189 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -2,6 +2,7 @@ package connection import ( "net" + "net/http" uuid "github.com/satori/go.uuid" ) @@ -19,3 +20,14 @@ func NewClient(c net.Conn) *Client { Tls: false, } } + +type Server struct { + Id uuid.UUID + Client *http.Client +} + +func NewServer() *Server { + return &Server{ + Id: uuid.NewV4(), + } +} diff --git a/flow/conncontext.go b/flow/conncontext.go index 175c8dc..b9f3819 100644 --- a/flow/conncontext.go +++ b/flow/conncontext.go @@ -1,9 +1,85 @@ package flow -import "github.com/lqqyt2423/go-mitmproxy/connection" +import ( + "crypto/tls" + "net" + "net/http" + "time" + + "github.com/lqqyt2423/go-mitmproxy/connection" +) + +var ConnContextKey = new(struct{}) type ConnContext struct { Client *connection.Client + Server *connection.Server } -var ConnContextKey = new(struct{}) +func (connCtx *ConnContext) InitHttpServer(SslInsecure bool) { + if connCtx.Server != nil { + return + } + if connCtx.Client.Tls { + return + } + + server := connection.NewServer() + server.Client = &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyFromEnvironment, + + // todo: change here + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + ForceAttemptHTTP2: false, // disable http2 + + DisableCompression: true, // To get the original response from the server, set Transport.DisableCompression to true. + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: SslInsecure, + KeyLogWriter: GetTlsKeyLogWriter(), + }, + }, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + // 禁止自动重定向 + return http.ErrUseLastResponse + }, + } + connCtx.Server = server +} + +func (connCtx *ConnContext) InitHttpsServer(SslInsecure bool) { + if connCtx.Server != nil { + return + } + if !connCtx.Client.Tls { + return + } + + server := connection.NewServer() + server.Client = &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyFromEnvironment, + + // todo: change here + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + ForceAttemptHTTP2: false, // disable http2 + + DisableCompression: true, // To get the original response from the server, set Transport.DisableCompression to true. + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: SslInsecure, + KeyLogWriter: GetTlsKeyLogWriter(), + }, + }, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + // 禁止自动重定向 + return http.ErrUseLastResponse + }, + } + connCtx.Server = server +} diff --git a/flow/helper.go b/flow/helper.go new file mode 100644 index 0000000..6dee72d --- /dev/null +++ b/flow/helper.go @@ -0,0 +1,29 @@ +package flow + +import ( + "io" + "os" + "sync" +) + +// Wireshark 解析 https 设置 +var tlsKeyLogWriter io.Writer +var tlsKeyLogOnce sync.Once + +func GetTlsKeyLogWriter() io.Writer { + tlsKeyLogOnce.Do(func() { + logfile := os.Getenv("SSLKEYLOGFILE") + if logfile == "" { + return + } + + writer, err := os.OpenFile(logfile, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0666) + if err != nil { + log.WithField("in", "GetTlsKeyLogWriter").Debug(err) + return + } + + tlsKeyLogWriter = writer + }) + return tlsKeyLogWriter +} diff --git a/proxy/helper.go b/proxy/helper.go index ffe8d82..4778404 100644 --- a/proxy/helper.go +++ b/proxy/helper.go @@ -3,9 +3,7 @@ package proxy import ( "bytes" "io" - "os" "strings" - "sync" _log "github.com/sirupsen/logrus" ) @@ -90,25 +88,3 @@ func ReaderToBuffer(r io.Reader, limit int64) ([]byte, io.Reader, error) { // 返回 buffer return buf.Bytes(), nil, nil } - -// Wireshark 解析 https 设置 -var tlsKeyLogWriter io.Writer -var tlsKeyLogOnce sync.Once - -func GetTlsKeyLogWriter() io.Writer { - tlsKeyLogOnce.Do(func() { - logfile := os.Getenv("SSLKEYLOGFILE") - if logfile == "" { - return - } - - writer, err := os.OpenFile(logfile, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0666) - if err != nil { - log.WithField("in", "GetTlsKeyLogWriter").Debug(err) - return - } - - tlsKeyLogWriter = writer - }) - return tlsKeyLogWriter -} diff --git a/proxy/middle.go b/proxy/middle.go index 16bcca6..c00d866 100644 --- a/proxy/middle.go +++ b/proxy/middle.go @@ -154,6 +154,7 @@ func (m *Middle) intercept(serverConn *connBuf) { if buf[0] == 0x16 && buf[1] == 0x03 && buf[2] <= 0x03 { // tls serverConn.connContext.Client.Tls = true + serverConn.connContext.InitHttpsServer(m.Proxy.Opts.SslInsecure) m.Listener.(*listener).connChan <- serverConn } else { // ws diff --git a/proxy/proxy.go b/proxy/proxy.go index 90dc5c4..d586833 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -3,7 +3,6 @@ package proxy import ( "bytes" "context" - "crypto/tls" "io" "net" "net/http" @@ -19,24 +18,24 @@ var log = _log.WithField("at", "proxy") type Options struct { Addr string - StreamLargeBodies int64 + StreamLargeBodies int64 // 当请求或响应体大于此字节时,转为 stream 模式 SslInsecure bool CaRootPath string } type Proxy struct { - Version string - Server *http.Server - Client *http.Client - Interceptor Interceptor - StreamLargeBodies int64 // 当请求或响应体大于此字节时,转为 stream 模式 - Addons []addon.Addon + Opts *Options + Version string + Server *http.Server + Interceptor Interceptor + Addons []addon.Addon activeConn map[net.Conn]*flow.ConnContext } func NewProxy(opts *Options) (*Proxy, error) { proxy := new(Proxy) + proxy.Opts = opts proxy.Version = "0.2.0" proxy.Server = &http.Server{ @@ -65,41 +64,14 @@ func NewProxy(opts *Options) (*Proxy, error) { }, } - proxy.Client = &http.Client{ - Transport: &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: (&net.Dialer{ - Timeout: 15 * time.Second, - KeepAlive: 30 * time.Second, - DualStack: true, - }).DialContext, - MaxIdleConns: 100, - IdleConnTimeout: 5 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - ForceAttemptHTTP2: false, // disable http2 - DisableCompression: true, // To get the original response from the server, set Transport.DisableCompression to true. - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: opts.SslInsecure, - KeyLogWriter: GetTlsKeyLogWriter(), - }, - }, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - // 禁止自动重定向 - return http.ErrUseLastResponse - }, - } - interceptor, err := NewMiddle(proxy, opts.CaRootPath) if err != nil { return nil, err } proxy.Interceptor = interceptor - if opts.StreamLargeBodies > 0 { - proxy.StreamLargeBodies = opts.StreamLargeBodies - } else { - proxy.StreamLargeBodies = 1024 * 1024 * 5 // default: 5mb + if opts.StreamLargeBodies <= 0 { + opts.StreamLargeBodies = 1024 * 1024 * 5 // default: 5mb } proxy.Addons = make([]addon.Addon, 0) @@ -201,7 +173,7 @@ func (proxy *Proxy) ServeHTTP(res http.ResponseWriter, req *http.Request) { // Read request body var reqBody io.Reader = req.Body if !f.Stream { - reqBuf, r, err := ReaderToBuffer(req.Body, proxy.StreamLargeBodies) + reqBuf, r, err := ReaderToBuffer(req.Body, proxy.Opts.StreamLargeBodies) reqBody = r if err != nil { log.Error(err) @@ -210,7 +182,7 @@ func (proxy *Proxy) ServeHTTP(res http.ResponseWriter, req *http.Request) { } if reqBuf == nil { - log.Warnf("request body size >= %v\n", proxy.StreamLargeBodies) + log.Warnf("request body size >= %v\n", proxy.Opts.StreamLargeBodies) f.Stream = true } else { f.Request.Body = reqBuf @@ -239,7 +211,10 @@ func (proxy *Proxy) ServeHTTP(res http.ResponseWriter, req *http.Request) { proxyReq.Header.Add(key, v) } } - proxyRes, err := proxy.Client.Do(proxyReq) + + f.ConnContext.InitHttpServer(proxy.Opts.SslInsecure) + + proxyRes, err := f.ConnContext.Server.Client.Do(proxyReq) if err != nil { LogErr(log, err) res.WriteHeader(502) @@ -264,7 +239,7 @@ func (proxy *Proxy) ServeHTTP(res http.ResponseWriter, req *http.Request) { // Read response body var resBody io.Reader = proxyRes.Body if !f.Stream { - resBuf, r, err := ReaderToBuffer(proxyRes.Body, proxy.StreamLargeBodies) + resBuf, r, err := ReaderToBuffer(proxyRes.Body, proxy.Opts.StreamLargeBodies) resBody = r if err != nil { log.Error(err) @@ -272,7 +247,7 @@ func (proxy *Proxy) ServeHTTP(res http.ResponseWriter, req *http.Request) { return } if resBuf == nil { - log.Warnf("response body size >= %v\n", proxy.StreamLargeBodies) + log.Warnf("response body size >= %v\n", proxy.Opts.StreamLargeBodies) f.Stream = true } else { f.Response.Body = resBuf @@ -327,10 +302,13 @@ func (proxy *Proxy) handleConnect(res http.ResponseWriter, req *http.Request) { } func (proxy *Proxy) whenClientConnClose(c net.Conn) { - client := proxy.activeConn[c].Client + connCtx := proxy.activeConn[c] + for _, addon := range proxy.Addons { - addon.ClientDisconnected(client) + addon.ClientDisconnected(connCtx.Client) } + connCtx.Server.Client.CloseIdleConnections() + delete(proxy.activeConn, c) }