fix Race conditions

addon-dailer
lqqyt2423 4 years ago
parent 8b18fe465c
commit 6807cce67d

@ -15,6 +15,7 @@ import (
"net" "net"
"os" "os"
"path/filepath" "path/filepath"
"sync"
"time" "time"
"github.com/golang/groupcache/lru" "github.com/golang/groupcache/lru"
@ -37,6 +38,8 @@ type CA struct {
cache *lru.Cache cache *lru.Cache
group *singleflight.Group group *singleflight.Group
cacheMu sync.Mutex
} }
func NewCA(path string) (*CA, error) { func NewCA(path string) (*CA, error) {
@ -246,15 +249,20 @@ func (ca *CA) saveCert() error {
} }
func (ca *CA) GetCert(commonName string) (*tls.Certificate, error) { func (ca *CA) GetCert(commonName string) (*tls.Certificate, error) {
ca.cacheMu.Lock()
if val, ok := ca.cache.Get(commonName); ok { if val, ok := ca.cache.Get(commonName); ok {
ca.cacheMu.Unlock()
log.WithField("commonName", commonName).Debug("GetCert") log.WithField("commonName", commonName).Debug("GetCert")
return val.(*tls.Certificate), nil return val.(*tls.Certificate), nil
} }
ca.cacheMu.Unlock()
val, err := ca.group.Do(commonName, func() (interface{}, error) { val, err := ca.group.Do(commonName, func() (interface{}, error) {
cert, err := ca.DummyCert(commonName) cert, err := ca.DummyCert(commonName)
if err == nil { if err == nil {
ca.cacheMu.Lock()
ca.cache.Add(commonName, cert) ca.cache.Add(commonName, cert)
ca.cacheMu.Unlock()
} }
return cert, err return cert, err
}) })

@ -8,6 +8,7 @@ import (
"net/http" "net/http"
"os" "os"
"strings" "strings"
"sync"
"time" "time"
mock_conn "github.com/jordwest/mock-conn" mock_conn "github.com/jordwest/mock-conn"
@ -34,11 +35,12 @@ type conn struct {
// connection: keep-alive 相关 // connection: keep-alive 相关
readCanCancel bool // 是否可取消 Read readCanCancel bool // 是否可取消 Read
firstRead bool // 首次调用 Read 初始化 firstRead bool // 首次调用 Read 初始化
pendingRead bool // 当前是否有 Read 操作在阻塞中
readErrChan chan error // Read 方法提前返回时的错误,总是 os.ErrDeadlineExceeded readErrChan chan error // Read 方法提前返回时的错误,总是 os.ErrDeadlineExceeded
readErr error // 底层 End 返回的错误 readErr error // 底层 End 返回的错误
readDeadline time.Time // SetReadDeadline 设置的时间 readDeadline time.Time // SetReadDeadline 设置的时间
chunk chan []byte // Read 和 beginRead 的交互 channel chunk chan []byte // Read 和 beginRead 的交互 channel
readDeadlineMu sync.RWMutex
} }
var connUnexpected = errors.New("unexpected read error") var connUnexpected = errors.New("unexpected read error")
@ -86,19 +88,18 @@ func (c *conn) Read(data []byte) (int, error) {
} }
c.firstRead = true c.firstRead = true
c.readDeadlineMu.RLock()
if !c.readDeadline.Equal(time.Time{}) { if !c.readDeadline.Equal(time.Time{}) {
if !c.readDeadline.After(time.Now()) { if !c.readDeadline.After(time.Now()) {
c.readDeadlineMu.RUnlock()
return 0, os.ErrDeadlineExceeded return 0, os.ErrDeadlineExceeded
} else { } else {
c.readDeadlineMu.RUnlock()
log.WithField("host", c.host).Warnf("c.readDeadline is future %v\n", c.readDeadline) log.WithField("host", c.host).Warnf("c.readDeadline is future %v\n", c.readDeadline)
return 0, connUnexpected return 0, connUnexpected
} }
} }
c.readDeadlineMu.RUnlock()
c.pendingRead = true
defer func() {
c.pendingRead = false
}()
select { select {
case err := <-c.readErrChan: case err := <-c.readErrChan:
@ -117,10 +118,13 @@ func (c *conn) SetDeadline(t time.Time) error {
return connUnexpected return connUnexpected
} }
// http server 标准库实现时,当多个 http 复用底层 socke 时,会调用此方法 // http server 标准库实现时,当多个 http 复用底层 socket 时,会调用此方法
func (c *conn) SetReadDeadline(t time.Time) error { func (c *conn) SetReadDeadline(t time.Time) error {
c.readDeadlineMu.Lock()
c.readDeadline = t c.readDeadline = t
if c.pendingRead && !t.Equal(time.Time{}) && !t.After(time.Now()) { c.readDeadlineMu.Unlock()
if !t.Equal(time.Time{}) && !t.After(time.Now()) {
c.readErrChan <- os.ErrDeadlineExceeded c.readErrChan <- os.ErrDeadlineExceeded
} }
return nil return nil
@ -180,12 +184,12 @@ func NewMiddle(proxy *Proxy) (Interceptor, error) {
} }
m.Server = server m.Server = server
m.Listener = &listener{make(chan net.Conn)}
return m, nil return m, nil
} }
func (m *Middle) Start() error { func (m *Middle) Start() error {
m.Listener = &listener{make(chan net.Conn)}
return m.Server.ServeTLS(m.Listener, "", "") return m.Server.ServeTLS(m.Listener, "", "")
} }

Loading…
Cancel
Save