add proxy.Close and proxy.Shutdown

addon-dailer
lqqyt2423 2 years ago
parent 750c013fb4
commit 914c3674d6

@ -12,25 +12,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
// 拦截 https 流量通用接口
type interceptor interface {
// 初始化
Start() error
// 传入当前客户端 req
Dial(req *http.Request) (net.Conn, error)
}
// 直接转发 https 流量
type forward struct{}
func (i *forward) Start() error {
return nil
}
func (i *forward) Dial(req *http.Request) (net.Conn, error) {
return net.Dial("tcp", req.Host)
}
// 模拟了标准库中 server 运行,目的是仅通过当前进程内存转发 socket 数据,不需要经过 tcp 或 unix socket // 模拟了标准库中 server 运行,目的是仅通过当前进程内存转发 socket 数据,不需要经过 tcp 或 unix socket
type pipeAddr struct { type pipeAddr struct {
@ -84,9 +65,17 @@ func newPipes(req *http.Request) (net.Conn, *pipeConn) {
// mock net.Listener // mock net.Listener
type middleListener struct { type middleListener struct {
connChan chan net.Conn connChan chan net.Conn
doneChan chan struct{}
} }
func (l *middleListener) Accept() (net.Conn, error) { return <-l.connChan, nil } func (l *middleListener) Accept() (net.Conn, error) {
select {
case c := <-l.connChan:
return c, nil
case <-l.doneChan:
return nil, http.ErrServerClosed
}
}
func (l *middleListener) Close() error { return nil } func (l *middleListener) Close() error { return nil }
func (l *middleListener) Addr() net.Addr { return nil } func (l *middleListener) Addr() net.Addr { return nil }
@ -98,7 +87,7 @@ type middle struct {
server *http.Server server *http.Server
} }
func newMiddle(proxy *Proxy) (interceptor, error) { func newMiddle(proxy *Proxy) (*middle, error) {
ca, err := cert.NewCA(proxy.Opts.CaRootPath) ca, err := cert.NewCA(proxy.Opts.CaRootPath)
if err != nil { if err != nil {
return nil, err return nil, err
@ -109,6 +98,7 @@ func newMiddle(proxy *Proxy) (interceptor, error) {
ca: ca, ca: ca,
listener: &middleListener{ listener: &middleListener{
connChan: make(chan net.Conn), connChan: make(chan net.Conn),
doneChan: make(chan struct{}),
}, },
} }
@ -138,11 +128,17 @@ func newMiddle(proxy *Proxy) (interceptor, error) {
return m, nil return m, nil
} }
func (m *middle) Start() error { func (m *middle) start() error {
return m.server.ServeTLS(m.listener, "", "") return m.server.ServeTLS(m.listener, "", "")
} }
func (m *middle) Dial(req *http.Request) (net.Conn, error) { func (m *middle) close() error {
err := m.server.Close()
close(m.listener.doneChan)
return err
}
func (m *middle) dial(req *http.Request) (net.Conn, error) {
pipeClientConn, pipeServerConn := newPipes(req) pipeClientConn, pipeServerConn := newPipes(req)
err := pipeServerConn.connContext.initServerTcpConn() err := pipeServerConn.connContext.initServerTcpConn()
if err != nil { if err != nil {

@ -24,7 +24,7 @@ type Proxy struct {
Addons []Addon Addons []Addon
server *http.Server server *http.Server
interceptor interceptor interceptor *middle
} }
func NewProxy(opts *Options) (*Proxy, error) { func NewProxy(opts *Options) (*Proxy, error) {
@ -65,33 +65,34 @@ func (proxy *Proxy) AddAddon(addon Addon) {
} }
func (proxy *Proxy) Start() error { func (proxy *Proxy) Start() error {
errChan := make(chan error)
go func() {
log.Infof("Proxy start listen at %v\n", proxy.server.Addr)
addr := proxy.server.Addr addr := proxy.server.Addr
if addr == "" { if addr == "" {
addr = ":http" addr = ":http"
} }
ln, err := net.Listen("tcp", addr) ln, err := net.Listen("tcp", addr)
if err != nil { if err != nil {
errChan <- err return err
return
} }
go proxy.interceptor.start()
log.Infof("Proxy start listen at %v\n", proxy.server.Addr)
pln := &wrapListener{ pln := &wrapListener{
Listener: ln, Listener: ln,
proxy: proxy, proxy: proxy,
} }
err = proxy.server.Serve(pln) return proxy.server.Serve(pln)
errChan <- err }
}()
go func() { func (proxy *Proxy) Close() error {
err := proxy.interceptor.Start() err := proxy.server.Close()
errChan <- err proxy.interceptor.close()
}() return err
}
err := <-errChan func (proxy *Proxy) Shutdown(ctx context.Context) error {
err := proxy.server.Shutdown(ctx)
proxy.interceptor.close()
return err return err
} }
@ -279,7 +280,7 @@ func (proxy *Proxy) handleConnect(res http.ResponseWriter, req *http.Request) {
"host": req.Host, "host": req.Host,
}) })
conn, err := proxy.interceptor.Dial(req) conn, err := proxy.interceptor.dial(req)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
res.WriteHeader(502) res.WriteHeader(502)

@ -552,3 +552,85 @@ func TestProxyWhenServerKeepAliveButCloseImmediately(t *testing.T) {
}) })
}) })
} }
func TestProxyClose(t *testing.T) {
helper := &testProxyHelper{
server: &http.Server{},
proxyAddr: ":29083",
}
helper.init(t)
httpEndpoint := helper.httpEndpoint
httpsEndpoint := helper.httpsEndpoint
testProxy := helper.testProxy
getProxyClient := helper.getProxyClient
defer helper.ln.Close()
go helper.server.Serve(helper.ln)
defer helper.tlsPlainLn.Close()
go helper.server.Serve(helper.tlsLn)
errCh := make(chan error)
go func() {
err := testProxy.Start()
errCh <- err
}()
time.Sleep(time.Millisecond * 10) // wait for test proxy startup
proxyClient := getProxyClient()
testSendRequest(t, httpEndpoint, proxyClient, "ok")
testSendRequest(t, httpsEndpoint, proxyClient, "ok")
if err := testProxy.Close(); err != nil {
t.Fatalf("close got error %v", err)
}
select {
case err := <-errCh:
if err != http.ErrServerClosed {
t.Fatalf("expected ErrServerClosed error, but got %v", err)
}
case <-time.After(time.Millisecond * 10):
t.Fatal("close timeout")
}
}
func TestProxyShutdown(t *testing.T) {
helper := &testProxyHelper{
server: &http.Server{},
proxyAddr: ":29084",
}
helper.init(t)
httpEndpoint := helper.httpEndpoint
httpsEndpoint := helper.httpsEndpoint
testProxy := helper.testProxy
getProxyClient := helper.getProxyClient
defer helper.ln.Close()
go helper.server.Serve(helper.ln)
defer helper.tlsPlainLn.Close()
go helper.server.Serve(helper.tlsLn)
errCh := make(chan error)
go func() {
err := testProxy.Start()
errCh <- err
}()
time.Sleep(time.Millisecond * 10) // wait for test proxy startup
proxyClient := getProxyClient()
testSendRequest(t, httpEndpoint, proxyClient, "ok")
testSendRequest(t, httpsEndpoint, proxyClient, "ok")
if err := testProxy.Shutdown(context.TODO()); err != nil {
t.Fatalf("shutdown got error %v", err)
}
select {
case err := <-errCh:
if err != http.ErrServerClosed {
t.Fatalf("expected ErrServerClosed error, but got %v", err)
}
case <-time.After(time.Millisecond * 10):
t.Fatal("shutdown timeout")
}
}

Loading…
Cancel
Save