diff --git a/proxy/interceptor.go b/proxy/interceptor.go index 2e95f2d..60e5130 100644 --- a/proxy/interceptor.go +++ b/proxy/interceptor.go @@ -12,25 +12,6 @@ import ( log "github.com/sirupsen/logrus" ) -// 拦截 https 流量通用接口 -type interceptor interface { - // 初始化 - Start() error - // 传入当前客户端 req - Dial(req *http.Request) (net.Conn, error) -} - -// 直接转发 https 流量 -type forward struct{} - -func (i *forward) Start() error { - return nil -} - -func (i *forward) Dial(req *http.Request) (net.Conn, error) { - return net.Dial("tcp", req.Host) -} - // 模拟了标准库中 server 运行,目的是仅通过当前进程内存转发 socket 数据,不需要经过 tcp 或 unix socket type pipeAddr struct { @@ -84,11 +65,19 @@ func newPipes(req *http.Request) (net.Conn, *pipeConn) { // mock net.Listener type middleListener struct { connChan chan net.Conn + doneChan chan struct{} } -func (l *middleListener) Accept() (net.Conn, error) { return <-l.connChan, nil } -func (l *middleListener) Close() error { return nil } -func (l *middleListener) Addr() net.Addr { return nil } +func (l *middleListener) Accept() (net.Conn, error) { + select { + case c := <-l.connChan: + return c, nil + case <-l.doneChan: + return nil, http.ErrServerClosed + } +} +func (l *middleListener) Close() error { return nil } +func (l *middleListener) Addr() net.Addr { return nil } // middle: man-in-the-middle server type middle struct { @@ -98,7 +87,7 @@ type middle struct { server *http.Server } -func newMiddle(proxy *Proxy) (interceptor, error) { +func newMiddle(proxy *Proxy) (*middle, error) { ca, err := cert.NewCA(proxy.Opts.CaRootPath) if err != nil { return nil, err @@ -109,6 +98,7 @@ func newMiddle(proxy *Proxy) (interceptor, error) { ca: ca, listener: &middleListener{ connChan: make(chan net.Conn), + doneChan: make(chan struct{}), }, } @@ -138,11 +128,17 @@ func newMiddle(proxy *Proxy) (interceptor, error) { return m, nil } -func (m *middle) Start() error { +func (m *middle) start() error { return m.server.ServeTLS(m.listener, "", "") } -func (m *middle) Dial(req *http.Request) (net.Conn, error) { +func (m *middle) close() error { + err := m.server.Close() + close(m.listener.doneChan) + return err +} + +func (m *middle) dial(req *http.Request) (net.Conn, error) { pipeClientConn, pipeServerConn := newPipes(req) err := pipeServerConn.connContext.initServerTcpConn() if err != nil { diff --git a/proxy/proxy.go b/proxy/proxy.go index 3ccef32..adeba1c 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -24,7 +24,7 @@ type Proxy struct { Addons []Addon server *http.Server - interceptor interceptor + interceptor *middle } func NewProxy(opts *Options) (*Proxy, error) { @@ -65,33 +65,34 @@ func (proxy *Proxy) AddAddon(addon Addon) { } func (proxy *Proxy) Start() error { - errChan := make(chan error) + addr := proxy.server.Addr + if addr == "" { + addr = ":http" + } + ln, err := net.Listen("tcp", addr) + if err != nil { + return err + } - go func() { - log.Infof("Proxy start listen at %v\n", proxy.server.Addr) - addr := proxy.server.Addr - if addr == "" { - addr = ":http" - } - ln, err := net.Listen("tcp", addr) - if err != nil { - errChan <- err - return - } - pln := &wrapListener{ - Listener: ln, - proxy: proxy, - } - err = proxy.server.Serve(pln) - errChan <- err - }() + go proxy.interceptor.start() - go func() { - err := proxy.interceptor.Start() - errChan <- err - }() + log.Infof("Proxy start listen at %v\n", proxy.server.Addr) + pln := &wrapListener{ + Listener: ln, + proxy: proxy, + } + return proxy.server.Serve(pln) +} + +func (proxy *Proxy) Close() error { + err := proxy.server.Close() + proxy.interceptor.close() + return err +} - err := <-errChan +func (proxy *Proxy) Shutdown(ctx context.Context) error { + err := proxy.server.Shutdown(ctx) + proxy.interceptor.close() return err } @@ -279,7 +280,7 @@ func (proxy *Proxy) handleConnect(res http.ResponseWriter, req *http.Request) { "host": req.Host, }) - conn, err := proxy.interceptor.Dial(req) + conn, err := proxy.interceptor.dial(req) if err != nil { log.Error(err) res.WriteHeader(502) diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 60d71ed..5d4f101 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -552,3 +552,85 @@ func TestProxyWhenServerKeepAliveButCloseImmediately(t *testing.T) { }) }) } + +func TestProxyClose(t *testing.T) { + helper := &testProxyHelper{ + server: &http.Server{}, + proxyAddr: ":29083", + } + helper.init(t) + httpEndpoint := helper.httpEndpoint + httpsEndpoint := helper.httpsEndpoint + testProxy := helper.testProxy + getProxyClient := helper.getProxyClient + defer helper.ln.Close() + go helper.server.Serve(helper.ln) + defer helper.tlsPlainLn.Close() + go helper.server.Serve(helper.tlsLn) + + errCh := make(chan error) + go func() { + err := testProxy.Start() + errCh <- err + }() + + time.Sleep(time.Millisecond * 10) // wait for test proxy startup + + proxyClient := getProxyClient() + testSendRequest(t, httpEndpoint, proxyClient, "ok") + testSendRequest(t, httpsEndpoint, proxyClient, "ok") + + if err := testProxy.Close(); err != nil { + t.Fatalf("close got error %v", err) + } + + select { + case err := <-errCh: + if err != http.ErrServerClosed { + t.Fatalf("expected ErrServerClosed error, but got %v", err) + } + case <-time.After(time.Millisecond * 10): + t.Fatal("close timeout") + } +} + +func TestProxyShutdown(t *testing.T) { + helper := &testProxyHelper{ + server: &http.Server{}, + proxyAddr: ":29084", + } + helper.init(t) + httpEndpoint := helper.httpEndpoint + httpsEndpoint := helper.httpsEndpoint + testProxy := helper.testProxy + getProxyClient := helper.getProxyClient + defer helper.ln.Close() + go helper.server.Serve(helper.ln) + defer helper.tlsPlainLn.Close() + go helper.server.Serve(helper.tlsLn) + + errCh := make(chan error) + go func() { + err := testProxy.Start() + errCh <- err + }() + + time.Sleep(time.Millisecond * 10) // wait for test proxy startup + + proxyClient := getProxyClient() + testSendRequest(t, httpEndpoint, proxyClient, "ok") + testSendRequest(t, httpsEndpoint, proxyClient, "ok") + + if err := testProxy.Shutdown(context.TODO()); err != nil { + t.Fatalf("shutdown got error %v", err) + } + + select { + case err := <-errCh: + if err != http.ErrServerClosed { + t.Fatalf("expected ErrServerClosed error, but got %v", err) + } + case <-time.After(time.Millisecond * 10): + t.Fatal("shutdown timeout") + } +}