diff --git a/.gitignore b/.gitignore index a41b35c..c016ae6 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ /dummycert /.idea dist/ +.vscode diff --git a/proxy/connection.go b/proxy/connection.go index c9466cb..b433b8f 100644 --- a/proxy/connection.go +++ b/proxy/connection.go @@ -75,8 +75,9 @@ type ConnContext struct { ClientConn *ClientConn `json:"clientConn"` ServerConn *ServerConn `json:"serverConn"` - proxy *Proxy - pipeConn *pipeConn + proxy *Proxy + pipeConn *pipeConn + closeAfterResponse bool // after http response, http server will close the connection } func newConnContext(c net.Conn, proxy *Proxy) *ConnContext { @@ -231,10 +232,10 @@ type wrapClientConn struct { } func (c *wrapClientConn) Close() error { - log.Debugln("in wrapClientConn close") if c.closed { return c.closeErr } + log.Debugln("in wrapClientConn close") c.closed = true c.closeErr = c.Conn.Close() @@ -244,7 +245,7 @@ func (c *wrapClientConn) Close() error { } if c.connCtx.ServerConn != nil && c.connCtx.ServerConn.Conn != nil { - c.connCtx.ServerConn.Conn.Close() + c.connCtx.ServerConn.Conn.(*wrapServerConn).Conn.(*net.TCPConn).CloseRead() } return c.closeErr @@ -278,10 +279,10 @@ type wrapServerConn struct { } func (c *wrapServerConn) Close() error { - log.Debugln("in wrapServerConn close") if c.closed { return c.closeErr } + log.Debugln("in wrapServerConn close") c.closed = true c.closeErr = c.Conn.Close() @@ -290,7 +291,14 @@ func (c *wrapServerConn) Close() error { addon.ServerDisconnected(c.connCtx) } - c.connCtx.ClientConn.Conn.Close() + if !c.connCtx.ClientConn.Tls { + c.connCtx.ClientConn.Conn.(*wrapClientConn).Conn.(*net.TCPConn).CloseRead() + } else { + // if keep-alive connection close + if !c.connCtx.closeAfterResponse { + c.connCtx.pipeConn.Close() + } + } return c.closeErr } diff --git a/proxy/flow.go b/proxy/flow.go index 2c09c87..b92b6e9 100644 --- a/proxy/flow.go +++ b/proxy/flow.go @@ -99,6 +99,8 @@ type Response struct { Body []byte `json:"-"` BodyReader io.Reader + close bool // connection close + decodedBody []byte decoded bool // decoded reports whether the response was sent compressed but was decoded to decodedBody. decodedErr error diff --git a/proxy/helper.go b/proxy/helper.go index 7555fb7..c6d39eb 100644 --- a/proxy/helper.go +++ b/proxy/helper.go @@ -3,6 +3,7 @@ package proxy import ( "bytes" "io" + "net" "os" "strings" "sync" @@ -38,28 +39,39 @@ func logErr(log *log.Entry, err error) (loged bool) { } // 转发流量 -// Read a => Write b -// Read b => Write a -func transfer(log *log.Entry, a, b io.ReadWriteCloser) { +func transfer(log *log.Entry, server, client io.ReadWriteCloser) { done := make(chan struct{}) defer close(done) - forward := func(dst io.WriteCloser, src io.Reader, ec chan<- error) { - _, err := io.Copy(dst, src) - - dst.Close() // 当一端读结束时,结束另一端的写 - + errChan := make(chan error) + go func() { + _, err := io.Copy(server, client) + log.Debugln("client copy end", err) + client.Close() select { case <-done: return - case ec <- err: + case errChan <- err: return } - } + }() + go func() { + _, err := io.Copy(client, server) + log.Debugln("server copy end", err) + server.Close() + + if clientConn, ok := client.(*wrapClientConn); ok { + err := clientConn.Conn.(*net.TCPConn).CloseRead() + log.Debugln("clientConn.Conn.(*net.TCPConn).CloseRead()", err) + } - errChan := make(chan error) - go forward(a, b, errChan) - go forward(b, a, errChan) + select { + case <-done: + return + case errChan <- err: + return + } + }() for i := 0; i < 2; i++ { if err := <-errChan; err != nil { diff --git a/proxy/proxy.go b/proxy/proxy.go index 1aacf83..13f5d80 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -124,6 +124,9 @@ func (proxy *Proxy) ServeHTTP(res http.ResponseWriter, req *http.Request) { } } } + if response.close { + res.Header().Add("Connection", "close") + } res.WriteHeader(response.StatusCode) if body != nil { @@ -219,11 +222,17 @@ func (proxy *Proxy) ServeHTTP(res http.ResponseWriter, req *http.Request) { res.WriteHeader(502) return } + + if proxyRes.Close { + f.ConnContext.closeAfterResponse = true + } + defer proxyRes.Body.Close() f.Response = &Response{ StatusCode: proxyRes.StatusCode, Header: proxyRes.Header, + close: proxyRes.Close, } // trigger addon event Responseheaders diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index a1fc920..b43055f 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -1,15 +1,16 @@ package proxy import ( + "context" "crypto/tls" "io" "io/ioutil" "net" "net/http" "net/url" - "reflect" "strconv" "strings" + "sync" "testing" "time" @@ -68,40 +69,104 @@ func (addon *interceptAddon) Response(f *Flow) { type testOrderAddon struct { BaseAddon orders []string + mu sync.Mutex +} + +func (addon *testOrderAddon) reset() { + addon.mu.Lock() + defer addon.mu.Unlock() + addon.orders = make([]string, 0) +} + +func (addon *testOrderAddon) contains(t *testing.T, name string) { + t.Helper() + addon.mu.Lock() + defer addon.mu.Unlock() + for _, n := range addon.orders { + if name == n { + return + } + } + t.Fatalf("expected contains %s, but not", name) +} + +func (addon *testOrderAddon) before(t *testing.T, a, b string) { + t.Helper() + addon.mu.Lock() + defer addon.mu.Unlock() + aIndex, bIndex := -1, -1 + for i, n := range addon.orders { + if a == n { + aIndex = i + } else if b == n { + bIndex = i + } + } + if aIndex == -1 { + t.Fatalf("expected contains %s, but not", a) + } + if bIndex == -1 { + t.Fatalf("expected contains %s, but not", b) + } + if aIndex > bIndex { + t.Fatalf("expected %s executed before %s, but not", a, b) + } } func (addon *testOrderAddon) ClientConnected(*ClientConn) { + addon.mu.Lock() + defer addon.mu.Unlock() addon.orders = append(addon.orders, "ClientConnected") } func (addon *testOrderAddon) ClientDisconnected(*ClientConn) { + addon.mu.Lock() + defer addon.mu.Unlock() addon.orders = append(addon.orders, "ClientDisconnected") } func (addon *testOrderAddon) ServerConnected(*ConnContext) { + addon.mu.Lock() + defer addon.mu.Unlock() addon.orders = append(addon.orders, "ServerConnected") } func (addon *testOrderAddon) ServerDisconnected(*ConnContext) { + addon.mu.Lock() + defer addon.mu.Unlock() addon.orders = append(addon.orders, "ServerDisconnected") } func (addon *testOrderAddon) TlsEstablishedServer(*ConnContext) { + addon.mu.Lock() + defer addon.mu.Unlock() addon.orders = append(addon.orders, "TlsEstablishedServer") } func (addon *testOrderAddon) Requestheaders(*Flow) { + addon.mu.Lock() + defer addon.mu.Unlock() addon.orders = append(addon.orders, "Requestheaders") } func (addon *testOrderAddon) Request(*Flow) { + addon.mu.Lock() + defer addon.mu.Unlock() addon.orders = append(addon.orders, "Request") } func (addon *testOrderAddon) Responseheaders(*Flow) { + addon.mu.Lock() + defer addon.mu.Unlock() addon.orders = append(addon.orders, "Responseheaders") } func (addon *testOrderAddon) Response(*Flow) { + addon.mu.Lock() + defer addon.mu.Unlock() addon.orders = append(addon.orders, "Response") } func (addon *testOrderAddon) StreamRequestModifier(f *Flow, in io.Reader) io.Reader { + addon.mu.Lock() + defer addon.mu.Unlock() addon.orders = append(addon.orders, "StreamRequestModifier") return in } func (addon *testOrderAddon) StreamResponseModifier(f *Flow, in io.Reader) io.Reader { + addon.mu.Lock() + defer addon.mu.Unlock() addon.orders = append(addon.orders, "StreamResponseModifier") return in } @@ -167,7 +232,6 @@ func TestProxy(t *testing.T) { SslInsecure: true, }) handleError(t, err) - testProxy.AddAddon(&LogAddon{}) testProxy.AddAddon(&interceptAddon{}) testOrderAddonInstance := &testOrderAddon{ orders: make([]string, 0), @@ -214,10 +278,15 @@ func TestProxy(t *testing.T) { httpEndpoint := "http://some-wrong-host/" testSendRequest(t, httpEndpoint+"intercept-request", proxyClient, "intercept-request") }) - // todo: fail - t.Run("https", func(t *testing.T) { + t.Run("https can't", func(t *testing.T) { httpsEndpoint := "https://some-wrong-host/" - testSendRequest(t, httpsEndpoint+"intercept-request", proxyClient, "intercept-request") + _, err := http.Get(httpsEndpoint + "intercept-request") + if err == nil { + t.Fatal("should have error") + } + if !strings.Contains(err.Error(), "dial tcp") { + t.Fatal("should get dial error, but got", err.Error()) + } }) }) @@ -231,44 +300,185 @@ func TestProxy(t *testing.T) { }) }) - t.Run("test proxy when disable client keep alive", func(t *testing.T) { + t.Run("test proxy when DisableKeepAlives", func(t *testing.T) { + proxyClient := getProxyClient() + proxyClient.Transport.(*http.Transport).DisableKeepAlives = true + + t.Run("http", func(t *testing.T) { + testSendRequest(t, httpEndpoint, proxyClient, "ok") + }) + + t.Run("https", func(t *testing.T) { + testSendRequest(t, httpsEndpoint, proxyClient, "ok") + }) + }) + + t.Run("should trigger disconnect functions when DisableKeepAlives", func(t *testing.T) { proxyClient := getProxyClient() proxyClient.Transport.(*http.Transport).DisableKeepAlives = true - // todo: fail t.Run("http", func(t *testing.T) { + time.Sleep(time.Millisecond * 10) + testOrderAddonInstance.reset() testSendRequest(t, httpEndpoint, proxyClient, "ok") + time.Sleep(time.Millisecond * 10) + testOrderAddonInstance.contains(t, "ClientDisconnected") + testOrderAddonInstance.contains(t, "ServerDisconnected") }) - // todo: fail t.Run("https", func(t *testing.T) { + time.Sleep(time.Millisecond * 10) + testOrderAddonInstance.reset() testSendRequest(t, httpsEndpoint, proxyClient, "ok") + time.Sleep(time.Millisecond * 10) + testOrderAddonInstance.contains(t, "ClientDisconnected") + testOrderAddonInstance.contains(t, "ServerDisconnected") }) }) - t.Run("test addon execute order", func(t *testing.T) { + t.Run("should not have eof error when DisableKeepAlives", func(t *testing.T) { proxyClient := getProxyClient() proxyClient.Transport.(*http.Transport).DisableKeepAlives = true + t.Run("http", func(t *testing.T) { + for i := 0; i < 10; i++ { + testSendRequest(t, httpEndpoint, proxyClient, "ok") + } + }) + t.Run("https", func(t *testing.T) { + for i := 0; i < 10; i++ { + testSendRequest(t, httpsEndpoint, proxyClient, "ok") + } + }) + }) + + t.Run("should trigger disconnect functions when client side trigger off", func(t *testing.T) { + proxyClient := getProxyClient() + var clientConn net.Conn + proxyClient.Transport.(*http.Transport).DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + c, err := (&net.Dialer{}).DialContext(ctx, network, addr) + clientConn = c + return c, err + } - // todo: fail t.Run("http", func(t *testing.T) { - testOrderAddonInstance.orders = make([]string, 0) + time.Sleep(time.Millisecond * 10) + testOrderAddonInstance.reset() testSendRequest(t, httpEndpoint, proxyClient, "ok") - wantOrders := []string{ - "ClientConnected", - "Requestheaders", - "Request", - "StreamRequestModifier", - "ServerConnected", - "Responseheaders", - "Response", - "StreamResponseModifier", - "ClientDisconnected", - "ServerDisconnected", + clientConn.Close() + time.Sleep(time.Millisecond * 10) + testOrderAddonInstance.contains(t, "ClientDisconnected") + testOrderAddonInstance.contains(t, "ServerDisconnected") + testOrderAddonInstance.before(t, "ClientDisconnected", "ServerDisconnected") + }) + + t.Run("https", func(t *testing.T) { + time.Sleep(time.Millisecond * 10) + testOrderAddonInstance.reset() + testSendRequest(t, httpsEndpoint, proxyClient, "ok") + clientConn.Close() + time.Sleep(time.Millisecond * 10) + testOrderAddonInstance.contains(t, "ClientDisconnected") + testOrderAddonInstance.contains(t, "ServerDisconnected") + testOrderAddonInstance.before(t, "ClientDisconnected", "ServerDisconnected") + }) + }) +} + +func TestProxyWhenServerNotKeepAlive(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("ok")) + }) + server := &http.Server{ + Handler: mux, + } + server.SetKeepAlivesEnabled(false) + + // start http server + ln, err := net.Listen("tcp", "127.0.0.1:0") + handleError(t, err) + defer ln.Close() + go server.Serve(ln) + + // start https server + tlsLn, err := net.Listen("tcp", "127.0.0.1:0") + handleError(t, err) + defer tlsLn.Close() + ca, err := cert.NewCAMemory() + handleError(t, err) + cert, err := ca.GetCert("localhost") + handleError(t, err) + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{*cert}, + } + go server.Serve(tls.NewListener(tlsLn, tlsConfig)) + + httpEndpoint := "http://" + ln.Addr().String() + "/" + httpsPort := tlsLn.Addr().(*net.TCPAddr).Port + httpsEndpoint := "https://localhost:" + strconv.Itoa(httpsPort) + "/" + + // start proxy + testProxy, err := NewProxy(&Options{ + Addr: ":29081", // some random port + SslInsecure: true, + }) + handleError(t, err) + testProxy.AddAddon(&interceptAddon{}) + testOrderAddonInstance := &testOrderAddon{ + orders: make([]string, 0), + } + testProxy.AddAddon(testOrderAddonInstance) + go testProxy.Start() + time.Sleep(time.Millisecond * 10) // wait for test proxy startup + + getProxyClient := func() *http.Client { + return &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + Proxy: func(r *http.Request) (*url.URL, error) { + return url.Parse("http://127.0.0.1:29081") + }, + }, + } + } + + t.Run("should not have eof error when server side DisableKeepAlives", func(t *testing.T) { + proxyClient := getProxyClient() + t.Run("http", func(t *testing.T) { + for i := 0; i < 10; i++ { + testSendRequest(t, httpEndpoint, proxyClient, "ok") } - if !reflect.DeepEqual(testOrderAddonInstance.orders, wantOrders) { - t.Fatalf("expected order %v, but got order %v", wantOrders, testOrderAddonInstance.orders) + }) + t.Run("https", func(t *testing.T) { + for i := 0; i < 10; i++ { + testSendRequest(t, httpsEndpoint, proxyClient, "ok") } }) }) + + t.Run("should trigger disconnect functions when server DisableKeepAlives", func(t *testing.T) { + proxyClient := getProxyClient() + + t.Run("http", func(t *testing.T) { + time.Sleep(time.Millisecond * 10) + testOrderAddonInstance.reset() + testSendRequest(t, httpEndpoint, proxyClient, "ok") + time.Sleep(time.Millisecond * 10) + testOrderAddonInstance.contains(t, "ClientDisconnected") + testOrderAddonInstance.contains(t, "ServerDisconnected") + testOrderAddonInstance.before(t, "ServerDisconnected", "ClientDisconnected") + }) + + t.Run("https", func(t *testing.T) { + time.Sleep(time.Millisecond * 10) + testOrderAddonInstance.reset() + testSendRequest(t, httpsEndpoint, proxyClient, "ok") + time.Sleep(time.Millisecond * 10) + testOrderAddonInstance.contains(t, "ClientDisconnected") + testOrderAddonInstance.contains(t, "ServerDisconnected") + testOrderAddonInstance.before(t, "ServerDisconnected", "ClientDisconnected") + }) + }) }