diff --git a/flow/conncontext.go b/flow/conncontext.go index cbdfeb1..0165d01 100644 --- a/flow/conncontext.go +++ b/flow/conncontext.go @@ -23,7 +23,7 @@ func NewConnContext(c net.Conn) *ConnContext { } } -func (connCtx *ConnContext) InitHttpServer(sslInsecure bool, connWrap func(net.Conn) net.Conn, whenServerConnected func()) { +func (connCtx *ConnContext) InitHttpServer(sslInsecure bool, connWrap func(net.Conn) net.Conn, whenConnected func()) { if connCtx.Server != nil { return } @@ -48,7 +48,7 @@ func (connCtx *ConnContext) InitHttpServer(sslInsecure bool, connWrap func(net.C cw := connWrap(c) server.Conn = cw - defer whenServerConnected() + defer whenConnected() return cw, nil }, ForceAttemptHTTP2: false, // disable http2 @@ -67,7 +67,7 @@ func (connCtx *ConnContext) InitHttpServer(sslInsecure bool, connWrap func(net.C connCtx.Server = server } -func (connCtx *ConnContext) InitHttpsServer(sslInsecure bool) { +func (connCtx *ConnContext) InitHttpsServer(sslInsecure bool, connWrap func(net.Conn) net.Conn, whenConnected func()) { if connCtx.Server != nil { return } @@ -80,18 +80,33 @@ func (connCtx *ConnContext) InitHttpsServer(sslInsecure bool) { Transport: &http.Transport{ Proxy: http.ProxyFromEnvironment, - // todo: change here - DialContext: (&net.Dialer{ - // Timeout: 30 * time.Second, - // KeepAlive: 30 * time.Second, - }).DialContext, + DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + log.Debugln("in https DialTLSContext") + + plainConn, err := (&net.Dialer{}).DialContext(ctx, network, addr) + if err != nil { + return nil, err + } + + cw := connWrap(plainConn) + server.Conn = cw + whenConnected() + + firstTLSHost, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + cfg := &tls.Config{ + InsecureSkipVerify: sslInsecure, + KeyLogWriter: GetTlsKeyLogWriter(), + ServerName: firstTLSHost, + } + tlsConn := tls.Client(cw, cfg) + return tlsConn, nil + }, ForceAttemptHTTP2: false, // disable http2 DisableCompression: true, // To get the original response from the server, set Transport.DisableCompression to true. - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: sslInsecure, - KeyLogWriter: GetTlsKeyLogWriter(), - }, }, CheckRedirect: func(req *http.Request, via []*http.Request) error { // 禁止自动重定向 diff --git a/proxy/middle.go b/proxy/middle.go index 7f9cc02..b71ce45 100644 --- a/proxy/middle.go +++ b/proxy/middle.go @@ -32,14 +32,14 @@ func (pipeAddr) Network() string { return "pipe" } func (a *pipeAddr) String() string { return a.remoteAddr } // 建立客户端和服务端通信的通道 -func newPipes(req *http.Request) (net.Conn, *connBuf) { +func newPipes(req *http.Request) (net.Conn, *pipeConn) { client, srv := net.Pipe() - server := newConnBuf(srv, req) + server := newPipeConn(srv, req) return client, server } // add Peek method for conn -type connBuf struct { +type pipeConn struct { net.Conn r *bufio.Reader host string @@ -47,8 +47,8 @@ type connBuf struct { connContext *flow.ConnContext } -func newConnBuf(c net.Conn, req *http.Request) *connBuf { - return &connBuf{ +func newPipeConn(c net.Conn, req *http.Request) *pipeConn { + return &pipeConn{ Conn: c, r: bufio.NewReader(c), host: req.Host, @@ -57,16 +57,16 @@ func newConnBuf(c net.Conn, req *http.Request) *connBuf { } } -func (b *connBuf) Peek(n int) ([]byte, error) { - return b.r.Peek(n) +func (c *pipeConn) Peek(n int) ([]byte, error) { + return c.r.Peek(n) } -func (b *connBuf) Read(data []byte) (int, error) { - return b.r.Read(data) +func (c *pipeConn) Read(data []byte) (int, error) { + return c.r.Read(data) } -func (b *connBuf) RemoteAddr() net.Addr { - return &pipeAddr{remoteAddr: b.remoteAddr} +func (c *pipeConn) RemoteAddr() net.Addr { + return &pipeAddr{remoteAddr: c.remoteAddr} } // Middle: man-in-the-middle @@ -93,7 +93,7 @@ func NewMiddle(proxy *Proxy, caPath string) (Interceptor, error) { IdleTimeout: 5 * time.Second, ConnContext: func(ctx context.Context, c net.Conn) context.Context { - return context.WithValue(ctx, flow.ConnContextKey, c.(*tls.Conn).NetConn().(*connBuf).connContext) + return context.WithValue(ctx, flow.ConnContextKey, c.(*tls.Conn).NetConn().(*pipeConn).connContext) }, TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2 @@ -116,9 +116,9 @@ func (m *Middle) Start() error { } func (m *Middle) Dial(req *http.Request) (net.Conn, error) { - clientConn, serverConn := newPipes(req) - go m.intercept(serverConn) - return clientConn, nil + pipeClientConn, pipeServerConn := newPipes(req) + go m.intercept(pipeServerConn) + return pipeClientConn, nil } func (m *Middle) ServeHTTP(res http.ResponseWriter, req *http.Request) { @@ -140,24 +140,38 @@ func (m *Middle) ServeHTTP(res http.ResponseWriter, req *http.Request) { // 解析 connect 流量 // 如果是 tls 流量,则进入 listener.Accept => Middle.ServeHTTP // 否则很可能是 ws 流量 -func (m *Middle) intercept(serverConn *connBuf) { - log := log.WithField("in", "Middle.intercept").WithField("host", serverConn.host) +func (m *Middle) intercept(pipeServerConn *pipeConn) { + log := log.WithField("in", "Middle.intercept").WithField("host", pipeServerConn.host) - buf, err := serverConn.Peek(3) + buf, err := pipeServerConn.Peek(3) if err != nil { log.Errorf("Peek error: %v\n", err) - serverConn.Close() + pipeServerConn.Close() return } // https://github.com/mitmproxy/mitmproxy/blob/main/mitmproxy/net/tls.py is_tls_record_magic if buf[0] == 0x16 && buf[1] == 0x03 && buf[2] <= 0x03 { // tls - serverConn.connContext.Client.Tls = true - serverConn.connContext.InitHttpsServer(m.Proxy.Opts.SslInsecure) - m.Listener.(*middleListener).connChan <- serverConn + pipeServerConn.connContext.Client.Tls = true + pipeServerConn.connContext.InitHttpsServer( + m.Proxy.Opts.SslInsecure, + func(c net.Conn) net.Conn { + return &serverConn{ + Conn: c, + proxy: m.Proxy, + connCtx: pipeServerConn.connContext, + } + }, + func() { + for _, addon := range m.Proxy.Addons { + addon.ServerConnected(pipeServerConn.connContext) + } + }, + ) + m.Listener.(*middleListener).connChan <- pipeServerConn } else { // ws - DefaultWebSocket.WS(serverConn, serverConn.host) + DefaultWebSocket.WS(pipeServerConn, pipeServerConn.host) } }