close connection correctly, try fix eof error

addon-dailer
lqqyt2423 2 years ago
parent 8265e52bea
commit 3360b35a2c

1
.gitignore vendored

@ -4,3 +4,4 @@
/dummycert /dummycert
/.idea /.idea
dist/ dist/
.vscode

@ -77,6 +77,7 @@ type ConnContext struct {
proxy *Proxy proxy *Proxy
pipeConn *pipeConn pipeConn *pipeConn
closeAfterResponse bool // after http response, http server will close the connection
} }
func newConnContext(c net.Conn, proxy *Proxy) *ConnContext { func newConnContext(c net.Conn, proxy *Proxy) *ConnContext {
@ -231,10 +232,10 @@ type wrapClientConn struct {
} }
func (c *wrapClientConn) Close() error { func (c *wrapClientConn) Close() error {
log.Debugln("in wrapClientConn close")
if c.closed { if c.closed {
return c.closeErr return c.closeErr
} }
log.Debugln("in wrapClientConn close")
c.closed = true c.closed = true
c.closeErr = c.Conn.Close() c.closeErr = c.Conn.Close()
@ -244,7 +245,7 @@ func (c *wrapClientConn) Close() error {
} }
if c.connCtx.ServerConn != nil && c.connCtx.ServerConn.Conn != nil { if c.connCtx.ServerConn != nil && c.connCtx.ServerConn.Conn != nil {
c.connCtx.ServerConn.Conn.Close() c.connCtx.ServerConn.Conn.(*wrapServerConn).Conn.(*net.TCPConn).CloseRead()
} }
return c.closeErr return c.closeErr
@ -278,10 +279,10 @@ type wrapServerConn struct {
} }
func (c *wrapServerConn) Close() error { func (c *wrapServerConn) Close() error {
log.Debugln("in wrapServerConn close")
if c.closed { if c.closed {
return c.closeErr return c.closeErr
} }
log.Debugln("in wrapServerConn close")
c.closed = true c.closed = true
c.closeErr = c.Conn.Close() c.closeErr = c.Conn.Close()
@ -290,7 +291,14 @@ func (c *wrapServerConn) Close() error {
addon.ServerDisconnected(c.connCtx) addon.ServerDisconnected(c.connCtx)
} }
c.connCtx.ClientConn.Conn.Close() if !c.connCtx.ClientConn.Tls {
c.connCtx.ClientConn.Conn.(*wrapClientConn).Conn.(*net.TCPConn).CloseRead()
} else {
// if keep-alive connection close
if !c.connCtx.closeAfterResponse {
c.connCtx.pipeConn.Close()
}
}
return c.closeErr return c.closeErr
} }

@ -99,6 +99,8 @@ type Response struct {
Body []byte `json:"-"` Body []byte `json:"-"`
BodyReader io.Reader BodyReader io.Reader
close bool // connection close
decodedBody []byte decodedBody []byte
decoded bool // decoded reports whether the response was sent compressed but was decoded to decodedBody. decoded bool // decoded reports whether the response was sent compressed but was decoded to decodedBody.
decodedErr error decodedErr error

@ -3,6 +3,7 @@ package proxy
import ( import (
"bytes" "bytes"
"io" "io"
"net"
"os" "os"
"strings" "strings"
"sync" "sync"
@ -38,28 +39,39 @@ func logErr(log *log.Entry, err error) (loged bool) {
} }
// 转发流量 // 转发流量
// Read a => Write b func transfer(log *log.Entry, server, client io.ReadWriteCloser) {
// Read b => Write a
func transfer(log *log.Entry, a, b io.ReadWriteCloser) {
done := make(chan struct{}) done := make(chan struct{})
defer close(done) defer close(done)
forward := func(dst io.WriteCloser, src io.Reader, ec chan<- error) { errChan := make(chan error)
_, err := io.Copy(dst, src) go func() {
_, err := io.Copy(server, client)
dst.Close() // 当一端读结束时,结束另一端的写 log.Debugln("client copy end", err)
client.Close()
select { select {
case <-done: case <-done:
return return
case ec <- err: case errChan <- err:
return return
} }
}()
go func() {
_, err := io.Copy(client, server)
log.Debugln("server copy end", err)
server.Close()
if clientConn, ok := client.(*wrapClientConn); ok {
err := clientConn.Conn.(*net.TCPConn).CloseRead()
log.Debugln("clientConn.Conn.(*net.TCPConn).CloseRead()", err)
} }
errChan := make(chan error) select {
go forward(a, b, errChan) case <-done:
go forward(b, a, errChan) return
case errChan <- err:
return
}
}()
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
if err := <-errChan; err != nil { if err := <-errChan; err != nil {

@ -124,6 +124,9 @@ func (proxy *Proxy) ServeHTTP(res http.ResponseWriter, req *http.Request) {
} }
} }
} }
if response.close {
res.Header().Add("Connection", "close")
}
res.WriteHeader(response.StatusCode) res.WriteHeader(response.StatusCode)
if body != nil { if body != nil {
@ -219,11 +222,17 @@ func (proxy *Proxy) ServeHTTP(res http.ResponseWriter, req *http.Request) {
res.WriteHeader(502) res.WriteHeader(502)
return return
} }
if proxyRes.Close {
f.ConnContext.closeAfterResponse = true
}
defer proxyRes.Body.Close() defer proxyRes.Body.Close()
f.Response = &Response{ f.Response = &Response{
StatusCode: proxyRes.StatusCode, StatusCode: proxyRes.StatusCode,
Header: proxyRes.Header, Header: proxyRes.Header,
close: proxyRes.Close,
} }
// trigger addon event Responseheaders // trigger addon event Responseheaders

@ -1,15 +1,16 @@
package proxy package proxy
import ( import (
"context"
"crypto/tls" "crypto/tls"
"io" "io"
"io/ioutil" "io/ioutil"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"reflect"
"strconv" "strconv"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
@ -68,40 +69,104 @@ func (addon *interceptAddon) Response(f *Flow) {
type testOrderAddon struct { type testOrderAddon struct {
BaseAddon BaseAddon
orders []string orders []string
mu sync.Mutex
}
func (addon *testOrderAddon) reset() {
addon.mu.Lock()
defer addon.mu.Unlock()
addon.orders = make([]string, 0)
}
func (addon *testOrderAddon) contains(t *testing.T, name string) {
t.Helper()
addon.mu.Lock()
defer addon.mu.Unlock()
for _, n := range addon.orders {
if name == n {
return
}
}
t.Fatalf("expected contains %s, but not", name)
}
func (addon *testOrderAddon) before(t *testing.T, a, b string) {
t.Helper()
addon.mu.Lock()
defer addon.mu.Unlock()
aIndex, bIndex := -1, -1
for i, n := range addon.orders {
if a == n {
aIndex = i
} else if b == n {
bIndex = i
}
}
if aIndex == -1 {
t.Fatalf("expected contains %s, but not", a)
}
if bIndex == -1 {
t.Fatalf("expected contains %s, but not", b)
}
if aIndex > bIndex {
t.Fatalf("expected %s executed before %s, but not", a, b)
}
} }
func (addon *testOrderAddon) ClientConnected(*ClientConn) { func (addon *testOrderAddon) ClientConnected(*ClientConn) {
addon.mu.Lock()
defer addon.mu.Unlock()
addon.orders = append(addon.orders, "ClientConnected") addon.orders = append(addon.orders, "ClientConnected")
} }
func (addon *testOrderAddon) ClientDisconnected(*ClientConn) { func (addon *testOrderAddon) ClientDisconnected(*ClientConn) {
addon.mu.Lock()
defer addon.mu.Unlock()
addon.orders = append(addon.orders, "ClientDisconnected") addon.orders = append(addon.orders, "ClientDisconnected")
} }
func (addon *testOrderAddon) ServerConnected(*ConnContext) { func (addon *testOrderAddon) ServerConnected(*ConnContext) {
addon.mu.Lock()
defer addon.mu.Unlock()
addon.orders = append(addon.orders, "ServerConnected") addon.orders = append(addon.orders, "ServerConnected")
} }
func (addon *testOrderAddon) ServerDisconnected(*ConnContext) { func (addon *testOrderAddon) ServerDisconnected(*ConnContext) {
addon.mu.Lock()
defer addon.mu.Unlock()
addon.orders = append(addon.orders, "ServerDisconnected") addon.orders = append(addon.orders, "ServerDisconnected")
} }
func (addon *testOrderAddon) TlsEstablishedServer(*ConnContext) { func (addon *testOrderAddon) TlsEstablishedServer(*ConnContext) {
addon.mu.Lock()
defer addon.mu.Unlock()
addon.orders = append(addon.orders, "TlsEstablishedServer") addon.orders = append(addon.orders, "TlsEstablishedServer")
} }
func (addon *testOrderAddon) Requestheaders(*Flow) { func (addon *testOrderAddon) Requestheaders(*Flow) {
addon.mu.Lock()
defer addon.mu.Unlock()
addon.orders = append(addon.orders, "Requestheaders") addon.orders = append(addon.orders, "Requestheaders")
} }
func (addon *testOrderAddon) Request(*Flow) { func (addon *testOrderAddon) Request(*Flow) {
addon.mu.Lock()
defer addon.mu.Unlock()
addon.orders = append(addon.orders, "Request") addon.orders = append(addon.orders, "Request")
} }
func (addon *testOrderAddon) Responseheaders(*Flow) { func (addon *testOrderAddon) Responseheaders(*Flow) {
addon.mu.Lock()
defer addon.mu.Unlock()
addon.orders = append(addon.orders, "Responseheaders") addon.orders = append(addon.orders, "Responseheaders")
} }
func (addon *testOrderAddon) Response(*Flow) { func (addon *testOrderAddon) Response(*Flow) {
addon.mu.Lock()
defer addon.mu.Unlock()
addon.orders = append(addon.orders, "Response") addon.orders = append(addon.orders, "Response")
} }
func (addon *testOrderAddon) StreamRequestModifier(f *Flow, in io.Reader) io.Reader { func (addon *testOrderAddon) StreamRequestModifier(f *Flow, in io.Reader) io.Reader {
addon.mu.Lock()
defer addon.mu.Unlock()
addon.orders = append(addon.orders, "StreamRequestModifier") addon.orders = append(addon.orders, "StreamRequestModifier")
return in return in
} }
func (addon *testOrderAddon) StreamResponseModifier(f *Flow, in io.Reader) io.Reader { func (addon *testOrderAddon) StreamResponseModifier(f *Flow, in io.Reader) io.Reader {
addon.mu.Lock()
defer addon.mu.Unlock()
addon.orders = append(addon.orders, "StreamResponseModifier") addon.orders = append(addon.orders, "StreamResponseModifier")
return in return in
} }
@ -167,7 +232,6 @@ func TestProxy(t *testing.T) {
SslInsecure: true, SslInsecure: true,
}) })
handleError(t, err) handleError(t, err)
testProxy.AddAddon(&LogAddon{})
testProxy.AddAddon(&interceptAddon{}) testProxy.AddAddon(&interceptAddon{})
testOrderAddonInstance := &testOrderAddon{ testOrderAddonInstance := &testOrderAddon{
orders: make([]string, 0), orders: make([]string, 0),
@ -214,10 +278,15 @@ func TestProxy(t *testing.T) {
httpEndpoint := "http://some-wrong-host/" httpEndpoint := "http://some-wrong-host/"
testSendRequest(t, httpEndpoint+"intercept-request", proxyClient, "intercept-request") testSendRequest(t, httpEndpoint+"intercept-request", proxyClient, "intercept-request")
}) })
// todo: fail t.Run("https can't", func(t *testing.T) {
t.Run("https", func(t *testing.T) {
httpsEndpoint := "https://some-wrong-host/" httpsEndpoint := "https://some-wrong-host/"
testSendRequest(t, httpsEndpoint+"intercept-request", proxyClient, "intercept-request") _, err := http.Get(httpsEndpoint + "intercept-request")
if err == nil {
t.Fatal("should have error")
}
if !strings.Contains(err.Error(), "dial tcp") {
t.Fatal("should get dial error, but got", err.Error())
}
}) })
}) })
@ -231,44 +300,185 @@ func TestProxy(t *testing.T) {
}) })
}) })
t.Run("test proxy when disable client keep alive", func(t *testing.T) { t.Run("test proxy when DisableKeepAlives", func(t *testing.T) {
proxyClient := getProxyClient() proxyClient := getProxyClient()
proxyClient.Transport.(*http.Transport).DisableKeepAlives = true proxyClient.Transport.(*http.Transport).DisableKeepAlives = true
// todo: fail
t.Run("http", func(t *testing.T) { t.Run("http", func(t *testing.T) {
testSendRequest(t, httpEndpoint, proxyClient, "ok") testSendRequest(t, httpEndpoint, proxyClient, "ok")
}) })
// todo: fail
t.Run("https", func(t *testing.T) { t.Run("https", func(t *testing.T) {
testSendRequest(t, httpsEndpoint, proxyClient, "ok") testSendRequest(t, httpsEndpoint, proxyClient, "ok")
}) })
}) })
t.Run("test addon execute order", func(t *testing.T) { t.Run("should trigger disconnect functions when DisableKeepAlives", func(t *testing.T) {
proxyClient := getProxyClient() proxyClient := getProxyClient()
proxyClient.Transport.(*http.Transport).DisableKeepAlives = true proxyClient.Transport.(*http.Transport).DisableKeepAlives = true
// todo: fail
t.Run("http", func(t *testing.T) { t.Run("http", func(t *testing.T) {
testOrderAddonInstance.orders = make([]string, 0) time.Sleep(time.Millisecond * 10)
testOrderAddonInstance.reset()
testSendRequest(t, httpEndpoint, proxyClient, "ok") testSendRequest(t, httpEndpoint, proxyClient, "ok")
wantOrders := []string{ time.Sleep(time.Millisecond * 10)
"ClientConnected", testOrderAddonInstance.contains(t, "ClientDisconnected")
"Requestheaders", testOrderAddonInstance.contains(t, "ServerDisconnected")
"Request", })
"StreamRequestModifier",
"ServerConnected", t.Run("https", func(t *testing.T) {
"Responseheaders", time.Sleep(time.Millisecond * 10)
"Response", testOrderAddonInstance.reset()
"StreamResponseModifier", testSendRequest(t, httpsEndpoint, proxyClient, "ok")
"ClientDisconnected", time.Sleep(time.Millisecond * 10)
"ServerDisconnected", testOrderAddonInstance.contains(t, "ClientDisconnected")
} testOrderAddonInstance.contains(t, "ServerDisconnected")
if !reflect.DeepEqual(testOrderAddonInstance.orders, wantOrders) { })
t.Fatalf("expected order %v, but got order %v", wantOrders, testOrderAddonInstance.orders) })
t.Run("should not have eof error when DisableKeepAlives", func(t *testing.T) {
proxyClient := getProxyClient()
proxyClient.Transport.(*http.Transport).DisableKeepAlives = true
t.Run("http", func(t *testing.T) {
for i := 0; i < 10; i++ {
testSendRequest(t, httpEndpoint, proxyClient, "ok")
}
})
t.Run("https", func(t *testing.T) {
for i := 0; i < 10; i++ {
testSendRequest(t, httpsEndpoint, proxyClient, "ok")
}
})
})
t.Run("should trigger disconnect functions when client side trigger off", func(t *testing.T) {
proxyClient := getProxyClient()
var clientConn net.Conn
proxyClient.Transport.(*http.Transport).DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
c, err := (&net.Dialer{}).DialContext(ctx, network, addr)
clientConn = c
return c, err
}
t.Run("http", func(t *testing.T) {
time.Sleep(time.Millisecond * 10)
testOrderAddonInstance.reset()
testSendRequest(t, httpEndpoint, proxyClient, "ok")
clientConn.Close()
time.Sleep(time.Millisecond * 10)
testOrderAddonInstance.contains(t, "ClientDisconnected")
testOrderAddonInstance.contains(t, "ServerDisconnected")
testOrderAddonInstance.before(t, "ClientDisconnected", "ServerDisconnected")
})
t.Run("https", func(t *testing.T) {
time.Sleep(time.Millisecond * 10)
testOrderAddonInstance.reset()
testSendRequest(t, httpsEndpoint, proxyClient, "ok")
clientConn.Close()
time.Sleep(time.Millisecond * 10)
testOrderAddonInstance.contains(t, "ClientDisconnected")
testOrderAddonInstance.contains(t, "ServerDisconnected")
testOrderAddonInstance.before(t, "ClientDisconnected", "ServerDisconnected")
})
})
} }
func TestProxyWhenServerNotKeepAlive(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("ok"))
})
server := &http.Server{
Handler: mux,
}
server.SetKeepAlivesEnabled(false)
// start http server
ln, err := net.Listen("tcp", "127.0.0.1:0")
handleError(t, err)
defer ln.Close()
go server.Serve(ln)
// start https server
tlsLn, err := net.Listen("tcp", "127.0.0.1:0")
handleError(t, err)
defer tlsLn.Close()
ca, err := cert.NewCAMemory()
handleError(t, err)
cert, err := ca.GetCert("localhost")
handleError(t, err)
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{*cert},
}
go server.Serve(tls.NewListener(tlsLn, tlsConfig))
httpEndpoint := "http://" + ln.Addr().String() + "/"
httpsPort := tlsLn.Addr().(*net.TCPAddr).Port
httpsEndpoint := "https://localhost:" + strconv.Itoa(httpsPort) + "/"
// start proxy
testProxy, err := NewProxy(&Options{
Addr: ":29081", // some random port
SslInsecure: true,
})
handleError(t, err)
testProxy.AddAddon(&interceptAddon{})
testOrderAddonInstance := &testOrderAddon{
orders: make([]string, 0),
}
testProxy.AddAddon(testOrderAddonInstance)
go testProxy.Start()
time.Sleep(time.Millisecond * 10) // wait for test proxy startup
getProxyClient := func() *http.Client {
return &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
Proxy: func(r *http.Request) (*url.URL, error) {
return url.Parse("http://127.0.0.1:29081")
},
},
}
}
t.Run("should not have eof error when server side DisableKeepAlives", func(t *testing.T) {
proxyClient := getProxyClient()
t.Run("http", func(t *testing.T) {
for i := 0; i < 10; i++ {
testSendRequest(t, httpEndpoint, proxyClient, "ok")
}
})
t.Run("https", func(t *testing.T) {
for i := 0; i < 10; i++ {
testSendRequest(t, httpsEndpoint, proxyClient, "ok")
}
})
})
t.Run("should trigger disconnect functions when server DisableKeepAlives", func(t *testing.T) {
proxyClient := getProxyClient()
t.Run("http", func(t *testing.T) {
time.Sleep(time.Millisecond * 10)
testOrderAddonInstance.reset()
testSendRequest(t, httpEndpoint, proxyClient, "ok")
time.Sleep(time.Millisecond * 10)
testOrderAddonInstance.contains(t, "ClientDisconnected")
testOrderAddonInstance.contains(t, "ServerDisconnected")
testOrderAddonInstance.before(t, "ServerDisconnected", "ClientDisconnected")
})
t.Run("https", func(t *testing.T) {
time.Sleep(time.Millisecond * 10)
testOrderAddonInstance.reset()
testSendRequest(t, httpsEndpoint, proxyClient, "ok")
time.Sleep(time.Millisecond * 10)
testOrderAddonInstance.contains(t, "ClientDisconnected")
testOrderAddonInstance.contains(t, "ServerDisconnected")
testOrderAddonInstance.before(t, "ServerDisconnected", "ClientDisconnected")
}) })
}) })
} }

Loading…
Cancel
Save