diff --git a/proxy/connection.go b/proxy/connection.go index 54750e7..a5e54a7 100644 --- a/proxy/connection.go +++ b/proxy/connection.go @@ -160,7 +160,7 @@ func (connCtx *ConnContext) initHttpsServerConn() { } } -func (connCtx *ConnContext) getCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { +func (connCtx *ConnContext) tlsHandshake(clientHello *tls.ClientHelloInfo) error { cfg := &tls.Config{ InsecureSkipVerify: connCtx.proxy.Opts.SslInsecure, KeyLogWriter: getTlsKeyLogWriter(), @@ -189,7 +189,7 @@ func (connCtx *ConnContext) getCertificate(clientHello *tls.ClientHelloInfo) (*t if err != nil { connCtx.ServerConn.tlsHandshakeErr = err close(connCtx.ServerConn.tlsHandshaked) - return nil, err + return err } connCtx.ServerConn.tlsConn = tlsConn @@ -197,8 +197,7 @@ func (connCtx *ConnContext) getCertificate(clientHello *tls.ClientHelloInfo) (*t connCtx.ServerConn.tlsState = &tlsState close(connCtx.ServerConn.tlsHandshaked) - // todo: change here - return connCtx.proxy.interceptor.(*middle).ca.GetCert(clientHello.ServerName) + return nil } // wrap tcpConn for remote client diff --git a/proxy/interceptor.go b/proxy/interceptor.go index fc0004b..5044f79 100644 --- a/proxy/interceptor.go +++ b/proxy/interceptor.go @@ -122,7 +122,10 @@ func newMiddle(proxy *Proxy) (interceptor, error) { SessionTicketsDisabled: true, // 设置此值为 true ,确保每次都会调用下面的 GetCertificate 方法 GetCertificate: func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { connCtx := clientHello.Context().Value(connContextKey).(*ConnContext) - return connCtx.getCertificate(clientHello) + if err := connCtx.tlsHandshake(clientHello); err != nil { + return nil, err + } + return ca.GetCert(clientHello.ServerName) }, }, }