diff --git a/proxy/connection.go b/proxy/connection.go index ae9962c..9fde9bb 100644 --- a/proxy/connection.go +++ b/proxy/connection.go @@ -48,7 +48,8 @@ type ConnContext struct { ClientConn *ClientConn ServerConn *ServerConn - proxy *Proxy + proxy *Proxy + pipeConn *pipeConn } func newConnContext(c net.Conn, proxy *Proxy) *ConnContext { @@ -105,36 +106,38 @@ func (connCtx *ConnContext) initHttpServerConn() { connCtx.ServerConn = serverConn } -func (connCtx *ConnContext) initHttpsServerConn() { - if connCtx.ServerConn != nil { - return +func (connCtx *ConnContext) initServerTcpConn() error { + log.Debugln("in initServerTcpConn") + ServerConn := newServerConn() + connCtx.ServerConn = ServerConn + ServerConn.Address = connCtx.pipeConn.host + + plainConn, err := (&net.Dialer{}).DialContext(context.Background(), "tcp", ServerConn.Address) + if err != nil { + return err + } + ServerConn.Conn = &wrapServerConn{ + Conn: plainConn, + proxy: connCtx.proxy, + connCtx: connCtx, + } + + for _, addon := range connCtx.proxy.Addons { + addon.ServerConnected(connCtx) } + + return nil +} + +func (connCtx *ConnContext) initHttpsServerConn() { if !connCtx.ClientConn.Tls { return } - - ServerConn := newServerConn() - ServerConn.client = &http.Client{ + connCtx.ServerConn.client = &http.Client{ Transport: &http.Transport{ Proxy: http.ProxyFromEnvironment, DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { log.Debugln("in https DialTLSContext") - plainConn, err := (&net.Dialer{}).DialContext(ctx, network, addr) - if err != nil { - return nil, err - } - cw := &wrapServerConn{ - Conn: plainConn, - proxy: connCtx.proxy, - connCtx: connCtx, - } - ServerConn.Conn = cw - ServerConn.Address = addr - - for _, addon := range connCtx.proxy.Addons { - addon.ServerConnected(connCtx) - } - firstTLSHost, _, err := net.SplitHostPort(addr) if err != nil { return nil, err @@ -144,7 +147,7 @@ func (connCtx *ConnContext) initHttpsServerConn() { KeyLogWriter: getTlsKeyLogWriter(), ServerName: firstTLSHost, } - tlsConn := tls.Client(cw, cfg) + tlsConn := tls.Client(connCtx.ServerConn.Conn, cfg) return tlsConn, nil }, ForceAttemptHTTP2: false, // disable http2 @@ -155,7 +158,6 @@ func (connCtx *ConnContext) initHttpsServerConn() { return http.ErrUseLastResponse }, } - connCtx.ServerConn = ServerConn } // wrap tcpConn for remote client diff --git a/proxy/interceptor.go b/proxy/interceptor.go index 4c8306d..9984225 100644 --- a/proxy/interceptor.go +++ b/proxy/interceptor.go @@ -44,19 +44,22 @@ func (a *pipeAddr) String() string { return a.remoteAddr } type pipeConn struct { net.Conn r *bufio.Reader - host string - remoteAddr string + host string // server host:port + remoteAddr string // client ip:port connContext *ConnContext } func newPipeConn(c net.Conn, req *http.Request) *pipeConn { - return &pipeConn{ + connContext := req.Context().Value(connContextKey).(*ConnContext) + pipeConn := &pipeConn{ Conn: c, r: bufio.NewReader(c), host: req.Host, remoteAddr: req.RemoteAddr, - connContext: req.Context().Value(connContextKey).(*ConnContext), + connContext: connContext, } + connContext.pipeConn = pipeConn + return pipeConn } func (c *pipeConn) Peek(n int) ([]byte, error) { @@ -116,9 +119,9 @@ func newMiddle(proxy *Proxy) (interceptor, error) { }, TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2 TLSConfig: &tls.Config{ - GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) { - log.Debugf("middle GetCertificate ServerName: %v\n", chi.ServerName) - return ca.GetCert(chi.ServerName) + GetCertificate: func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { + log.Debugf("middle GetCertificate ServerName: %v\n", clientHello.ServerName) + return ca.GetCert(clientHello.ServerName) }, }, } @@ -130,9 +133,14 @@ func (m *middle) Start() error { return m.server.ServeTLS(m.listener, "", "") } -// todo: should block until ServerConnected func (m *middle) Dial(req *http.Request) (net.Conn, error) { pipeClientConn, pipeServerConn := newPipes(req) + err := pipeServerConn.connContext.initServerTcpConn() + if err != nil { + pipeClientConn.Close() + pipeServerConn.Close() + return nil, err + } go m.intercept(pipeServerConn) return pipeClientConn, nil }