optimize tlsHandshaked

addon-dailer
lqqyt2423 2 years ago
parent 637f752a3d
commit 896ea2997f

@ -31,15 +31,25 @@ type ServerConn struct {
Address string
Conn net.Conn
client *http.Client
tlsHandshaked chan struct{}
tlsHandshakeErr error
tlsConn *tls.Conn
tlsState *tls.ConnectionState
client *http.Client
}
func newServerConn() *ServerConn {
return &ServerConn{
Id: uuid.NewV4(),
Id: uuid.NewV4(),
tlsHandshaked: make(chan struct{}),
}
}
func (c *ServerConn) TlsState() *tls.ConnectionState {
<-c.tlsHandshaked
return c.tlsState
}
// connection context ctx key
var connContextKey = new(struct{})
@ -137,18 +147,8 @@ func (connCtx *ConnContext) initHttpsServerConn() {
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
log.Debugln("in https DialTLSContext")
firstTLSHost, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
cfg := &tls.Config{
InsecureSkipVerify: connCtx.proxy.Opts.SslInsecure,
KeyLogWriter: getTlsKeyLogWriter(),
ServerName: firstTLSHost,
}
tlsConn := tls.Client(connCtx.ServerConn.Conn, cfg)
return tlsConn, nil
<-connCtx.ServerConn.tlsHandshaked
return connCtx.ServerConn.tlsConn, connCtx.ServerConn.tlsHandshakeErr
},
ForceAttemptHTTP2: false, // disable http2
DisableCompression: true, // To get the original response from the server, set Transport.DisableCompression to true.
@ -160,6 +160,47 @@ func (connCtx *ConnContext) initHttpsServerConn() {
}
}
func (connCtx *ConnContext) getCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
cfg := &tls.Config{
InsecureSkipVerify: connCtx.proxy.Opts.SslInsecure,
KeyLogWriter: getTlsKeyLogWriter(),
ServerName: clientHello.ServerName,
NextProtos: []string{"http/1.1"}, // todo: h2
// CurvePreferences: clientHello.SupportedCurves, // todo: 如果打开会出错
CipherSuites: clientHello.CipherSuites,
}
if len(clientHello.SupportedVersions) > 0 {
minVersion := clientHello.SupportedVersions[0]
maxVersion := clientHello.SupportedVersions[0]
for _, version := range clientHello.SupportedVersions {
if version < minVersion {
minVersion = version
}
if version > maxVersion {
maxVersion = version
}
}
cfg.MinVersion = minVersion
cfg.MaxVersion = maxVersion
}
tlsConn := tls.Client(connCtx.ServerConn.Conn, cfg)
err := tlsConn.HandshakeContext(context.Background())
if err != nil {
connCtx.ServerConn.tlsHandshakeErr = err
close(connCtx.ServerConn.tlsHandshaked)
return nil, err
}
connCtx.ServerConn.tlsConn = tlsConn
tlsState := tlsConn.ConnectionState()
connCtx.ServerConn.tlsState = &tlsState
close(connCtx.ServerConn.tlsHandshaked)
// todo: change here
return connCtx.proxy.interceptor.(*middle).ca.GetCert(clientHello.ServerName)
}
// wrap tcpConn for remote client
type wrapClientConn struct {
net.Conn

@ -119,9 +119,10 @@ func newMiddle(proxy *Proxy) (interceptor, error) {
},
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2
TLSConfig: &tls.Config{
SessionTicketsDisabled: true, // 设置此值为 true ,确保每次都会调用下面的 GetCertificate 方法
GetCertificate: func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
log.Debugf("middle GetCertificate ServerName: %v\n", clientHello.ServerName)
return ca.GetCert(clientHello.ServerName)
connCtx := clientHello.Context().Value(connContextKey).(*ConnContext)
return connCtx.getCertificate(clientHello)
},
},
}

Loading…
Cancel
Save