From 896ea2997fadaf07124692e0f02d8f73278cc946 Mon Sep 17 00:00:00 2001 From: lqqyt2423 <974923609@qq.com> Date: Sat, 18 Jun 2022 22:42:45 +0800 Subject: [PATCH] optimize tlsHandshaked --- proxy/connection.go | 69 +++++++++++++++++++++++++++++++++++--------- proxy/interceptor.go | 5 ++-- 2 files changed, 58 insertions(+), 16 deletions(-) diff --git a/proxy/connection.go b/proxy/connection.go index 9fde9bb..54750e7 100644 --- a/proxy/connection.go +++ b/proxy/connection.go @@ -31,15 +31,25 @@ type ServerConn struct { Address string Conn net.Conn - client *http.Client + tlsHandshaked chan struct{} + tlsHandshakeErr error + tlsConn *tls.Conn + tlsState *tls.ConnectionState + client *http.Client } func newServerConn() *ServerConn { return &ServerConn{ - Id: uuid.NewV4(), + Id: uuid.NewV4(), + tlsHandshaked: make(chan struct{}), } } +func (c *ServerConn) TlsState() *tls.ConnectionState { + <-c.tlsHandshaked + return c.tlsState +} + // connection context ctx key var connContextKey = new(struct{}) @@ -137,18 +147,8 @@ func (connCtx *ConnContext) initHttpsServerConn() { Transport: &http.Transport{ Proxy: http.ProxyFromEnvironment, DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - log.Debugln("in https DialTLSContext") - firstTLSHost, _, err := net.SplitHostPort(addr) - if err != nil { - return nil, err - } - cfg := &tls.Config{ - InsecureSkipVerify: connCtx.proxy.Opts.SslInsecure, - KeyLogWriter: getTlsKeyLogWriter(), - ServerName: firstTLSHost, - } - tlsConn := tls.Client(connCtx.ServerConn.Conn, cfg) - return tlsConn, nil + <-connCtx.ServerConn.tlsHandshaked + return connCtx.ServerConn.tlsConn, connCtx.ServerConn.tlsHandshakeErr }, ForceAttemptHTTP2: false, // disable http2 DisableCompression: true, // To get the original response from the server, set Transport.DisableCompression to true. @@ -160,6 +160,47 @@ func (connCtx *ConnContext) initHttpsServerConn() { } } +func (connCtx *ConnContext) getCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { + cfg := &tls.Config{ + InsecureSkipVerify: connCtx.proxy.Opts.SslInsecure, + KeyLogWriter: getTlsKeyLogWriter(), + ServerName: clientHello.ServerName, + NextProtos: []string{"http/1.1"}, // todo: h2 + // CurvePreferences: clientHello.SupportedCurves, // todo: 如果打开会出错 + CipherSuites: clientHello.CipherSuites, + } + if len(clientHello.SupportedVersions) > 0 { + minVersion := clientHello.SupportedVersions[0] + maxVersion := clientHello.SupportedVersions[0] + for _, version := range clientHello.SupportedVersions { + if version < minVersion { + minVersion = version + } + if version > maxVersion { + maxVersion = version + } + } + cfg.MinVersion = minVersion + cfg.MaxVersion = maxVersion + } + + tlsConn := tls.Client(connCtx.ServerConn.Conn, cfg) + err := tlsConn.HandshakeContext(context.Background()) + if err != nil { + connCtx.ServerConn.tlsHandshakeErr = err + close(connCtx.ServerConn.tlsHandshaked) + return nil, err + } + + connCtx.ServerConn.tlsConn = tlsConn + tlsState := tlsConn.ConnectionState() + connCtx.ServerConn.tlsState = &tlsState + close(connCtx.ServerConn.tlsHandshaked) + + // todo: change here + return connCtx.proxy.interceptor.(*middle).ca.GetCert(clientHello.ServerName) +} + // wrap tcpConn for remote client type wrapClientConn struct { net.Conn diff --git a/proxy/interceptor.go b/proxy/interceptor.go index 9984225..fc0004b 100644 --- a/proxy/interceptor.go +++ b/proxy/interceptor.go @@ -119,9 +119,10 @@ func newMiddle(proxy *Proxy) (interceptor, error) { }, TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2 TLSConfig: &tls.Config{ + SessionTicketsDisabled: true, // 设置此值为 true ,确保每次都会调用下面的 GetCertificate 方法 GetCertificate: func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { - log.Debugf("middle GetCertificate ServerName: %v\n", clientHello.ServerName) - return ca.GetCert(clientHello.ServerName) + connCtx := clientHello.Context().Value(connContextKey).(*ConnContext) + return connCtx.getCertificate(clientHello) }, }, }