From 6807cce67d97d4aee344ad40d299c123d2cef251 Mon Sep 17 00:00:00 2001 From: lqqyt2423 <974923609@qq.com> Date: Thu, 4 Feb 2021 16:46:47 +0800 Subject: [PATCH] fix Race conditions --- cert/cert.go | 8 ++++++++ proxy/middle.go | 22 +++++++++++++--------- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/cert/cert.go b/cert/cert.go index a567414..6934008 100644 --- a/cert/cert.go +++ b/cert/cert.go @@ -15,6 +15,7 @@ import ( "net" "os" "path/filepath" + "sync" "time" "github.com/golang/groupcache/lru" @@ -37,6 +38,8 @@ type CA struct { cache *lru.Cache group *singleflight.Group + + cacheMu sync.Mutex } func NewCA(path string) (*CA, error) { @@ -246,15 +249,20 @@ func (ca *CA) saveCert() error { } func (ca *CA) GetCert(commonName string) (*tls.Certificate, error) { + ca.cacheMu.Lock() if val, ok := ca.cache.Get(commonName); ok { + ca.cacheMu.Unlock() log.WithField("commonName", commonName).Debug("GetCert") return val.(*tls.Certificate), nil } + ca.cacheMu.Unlock() val, err := ca.group.Do(commonName, func() (interface{}, error) { cert, err := ca.DummyCert(commonName) if err == nil { + ca.cacheMu.Lock() ca.cache.Add(commonName, cert) + ca.cacheMu.Unlock() } return cert, err }) diff --git a/proxy/middle.go b/proxy/middle.go index 84cd367..c373925 100644 --- a/proxy/middle.go +++ b/proxy/middle.go @@ -8,6 +8,7 @@ import ( "net/http" "os" "strings" + "sync" "time" mock_conn "github.com/jordwest/mock-conn" @@ -34,11 +35,12 @@ type conn struct { // connection: keep-alive 相关 readCanCancel bool // 是否可取消 Read firstRead bool // 首次调用 Read 初始化 - pendingRead bool // 当前是否有 Read 操作在阻塞中 readErrChan chan error // Read 方法提前返回时的错误,总是 os.ErrDeadlineExceeded readErr error // 底层 End 返回的错误 readDeadline time.Time // SetReadDeadline 设置的时间 chunk chan []byte // Read 和 beginRead 的交互 channel + + readDeadlineMu sync.RWMutex } var connUnexpected = errors.New("unexpected read error") @@ -86,19 +88,18 @@ func (c *conn) Read(data []byte) (int, error) { } c.firstRead = true + c.readDeadlineMu.RLock() if !c.readDeadline.Equal(time.Time{}) { if !c.readDeadline.After(time.Now()) { + c.readDeadlineMu.RUnlock() return 0, os.ErrDeadlineExceeded } else { + c.readDeadlineMu.RUnlock() log.WithField("host", c.host).Warnf("c.readDeadline is future %v\n", c.readDeadline) return 0, connUnexpected } } - - c.pendingRead = true - defer func() { - c.pendingRead = false - }() + c.readDeadlineMu.RUnlock() select { case err := <-c.readErrChan: @@ -117,10 +118,13 @@ func (c *conn) SetDeadline(t time.Time) error { return connUnexpected } -// http server 标准库实现时,当多个 http 复用底层 socke 时,会调用此方法 +// http server 标准库实现时,当多个 http 复用底层 socket 时,会调用此方法 func (c *conn) SetReadDeadline(t time.Time) error { + c.readDeadlineMu.Lock() c.readDeadline = t - if c.pendingRead && !t.Equal(time.Time{}) && !t.After(time.Now()) { + c.readDeadlineMu.Unlock() + + if !t.Equal(time.Time{}) && !t.After(time.Now()) { c.readErrChan <- os.ErrDeadlineExceeded } return nil @@ -180,12 +184,12 @@ func NewMiddle(proxy *Proxy) (Interceptor, error) { } m.Server = server + m.Listener = &listener{make(chan net.Conn)} return m, nil } func (m *Middle) Start() error { - m.Listener = &listener{make(chan net.Conn)} return m.Server.ServeTLS(m.Listener, "", "") }