optimize middle.Dial

addon-dailer
lqqyt2423 2 years ago
parent 7f55903797
commit 637f752a3d

@ -48,7 +48,8 @@ type ConnContext struct {
ClientConn *ClientConn
ServerConn *ServerConn
proxy *Proxy
proxy *Proxy
pipeConn *pipeConn
}
func newConnContext(c net.Conn, proxy *Proxy) *ConnContext {
@ -105,36 +106,38 @@ func (connCtx *ConnContext) initHttpServerConn() {
connCtx.ServerConn = serverConn
}
func (connCtx *ConnContext) initHttpsServerConn() {
if connCtx.ServerConn != nil {
return
func (connCtx *ConnContext) initServerTcpConn() error {
log.Debugln("in initServerTcpConn")
ServerConn := newServerConn()
connCtx.ServerConn = ServerConn
ServerConn.Address = connCtx.pipeConn.host
plainConn, err := (&net.Dialer{}).DialContext(context.Background(), "tcp", ServerConn.Address)
if err != nil {
return err
}
ServerConn.Conn = &wrapServerConn{
Conn: plainConn,
proxy: connCtx.proxy,
connCtx: connCtx,
}
for _, addon := range connCtx.proxy.Addons {
addon.ServerConnected(connCtx)
}
return nil
}
func (connCtx *ConnContext) initHttpsServerConn() {
if !connCtx.ClientConn.Tls {
return
}
ServerConn := newServerConn()
ServerConn.client = &http.Client{
connCtx.ServerConn.client = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
log.Debugln("in https DialTLSContext")
plainConn, err := (&net.Dialer{}).DialContext(ctx, network, addr)
if err != nil {
return nil, err
}
cw := &wrapServerConn{
Conn: plainConn,
proxy: connCtx.proxy,
connCtx: connCtx,
}
ServerConn.Conn = cw
ServerConn.Address = addr
for _, addon := range connCtx.proxy.Addons {
addon.ServerConnected(connCtx)
}
firstTLSHost, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
@ -144,7 +147,7 @@ func (connCtx *ConnContext) initHttpsServerConn() {
KeyLogWriter: getTlsKeyLogWriter(),
ServerName: firstTLSHost,
}
tlsConn := tls.Client(cw, cfg)
tlsConn := tls.Client(connCtx.ServerConn.Conn, cfg)
return tlsConn, nil
},
ForceAttemptHTTP2: false, // disable http2
@ -155,7 +158,6 @@ func (connCtx *ConnContext) initHttpsServerConn() {
return http.ErrUseLastResponse
},
}
connCtx.ServerConn = ServerConn
}
// wrap tcpConn for remote client

@ -44,19 +44,22 @@ func (a *pipeAddr) String() string { return a.remoteAddr }
type pipeConn struct {
net.Conn
r *bufio.Reader
host string
remoteAddr string
host string // server host:port
remoteAddr string // client ip:port
connContext *ConnContext
}
func newPipeConn(c net.Conn, req *http.Request) *pipeConn {
return &pipeConn{
connContext := req.Context().Value(connContextKey).(*ConnContext)
pipeConn := &pipeConn{
Conn: c,
r: bufio.NewReader(c),
host: req.Host,
remoteAddr: req.RemoteAddr,
connContext: req.Context().Value(connContextKey).(*ConnContext),
connContext: connContext,
}
connContext.pipeConn = pipeConn
return pipeConn
}
func (c *pipeConn) Peek(n int) ([]byte, error) {
@ -116,9 +119,9 @@ func newMiddle(proxy *Proxy) (interceptor, error) {
},
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2
TLSConfig: &tls.Config{
GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
log.Debugf("middle GetCertificate ServerName: %v\n", chi.ServerName)
return ca.GetCert(chi.ServerName)
GetCertificate: func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
log.Debugf("middle GetCertificate ServerName: %v\n", clientHello.ServerName)
return ca.GetCert(clientHello.ServerName)
},
},
}
@ -130,9 +133,14 @@ func (m *middle) Start() error {
return m.server.ServeTLS(m.listener, "", "")
}
// todo: should block until ServerConnected
func (m *middle) Dial(req *http.Request) (net.Conn, error) {
pipeClientConn, pipeServerConn := newPipes(req)
err := pipeServerConn.connContext.initServerTcpConn()
if err != nil {
pipeClientConn.Close()
pipeServerConn.Close()
return nil, err
}
go m.intercept(pipeServerConn)
return pipeClientConn, nil
}

Loading…
Cancel
Save