From d95a8ccf2122901e4a838506cf5cc7f2d5e7be85 Mon Sep 17 00:00:00 2001 From: lqqyt2423 <974923609@qq.com> Date: Sat, 11 Jun 2022 13:20:12 +0800 Subject: [PATCH] add proxyListener --- connection/connection.go | 1 + flow/conncontext.go | 39 +++++++++++++--- proxy/middle.go | 12 ++--- proxy/proxy.go | 98 ++++++++++++++++++++++++---------------- 4 files changed, 98 insertions(+), 52 deletions(-) diff --git a/connection/connection.go b/connection/connection.go index ec29189..b50449a 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -23,6 +23,7 @@ func NewClient(c net.Conn) *Client { type Server struct { Id uuid.UUID + Conn net.Conn Client *http.Client } diff --git a/flow/conncontext.go b/flow/conncontext.go index b9f3819..1335864 100644 --- a/flow/conncontext.go +++ b/flow/conncontext.go @@ -1,10 +1,10 @@ package flow import ( + "context" "crypto/tls" "net" "net/http" - "time" "github.com/lqqyt2423/go-mitmproxy/connection" ) @@ -16,6 +16,22 @@ type ConnContext struct { Server *connection.Server } +func NewConnContext(c net.Conn) *ConnContext { + client := connection.NewClient(c) + return &ConnContext{ + Client: client, + } +} + +type serverConn struct { + net.Conn +} + +func (c *serverConn) Close() error { + log.Debugln("in http serverConn close") + return c.Conn.Close() +} + func (connCtx *ConnContext) InitHttpServer(SslInsecure bool) { if connCtx.Server != nil { return @@ -30,10 +46,19 @@ func (connCtx *ConnContext) InitHttpServer(SslInsecure bool) { Proxy: http.ProxyFromEnvironment, // todo: change here - DialContext: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).DialContext, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + c, err := (&net.Dialer{ + // Timeout: 30 * time.Second, + // KeepAlive: 30 * time.Second, + }).DialContext(ctx, network, addr) + if err != nil { + return nil, err + } + + cw := &serverConn{c} + server.Conn = cw + return cw, nil + }, ForceAttemptHTTP2: false, // disable http2 DisableCompression: true, // To get the original response from the server, set Transport.DisableCompression to true. @@ -65,8 +90,8 @@ func (connCtx *ConnContext) InitHttpsServer(SslInsecure bool) { // todo: change here DialContext: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, + // Timeout: 30 * time.Second, + // KeepAlive: 30 * time.Second, }).DialContext, ForceAttemptHTTP2: false, // disable http2 diff --git a/proxy/middle.go b/proxy/middle.go index c00d866..7f9cc02 100644 --- a/proxy/middle.go +++ b/proxy/middle.go @@ -16,13 +16,13 @@ import ( // 模拟了标准库中 server 运行,目的是仅通过当前进程内存转发 socket 数据,不需要经过 tcp 或 unix socket // mock net.Listener -type listener struct { +type middleListener struct { connChan chan net.Conn } -func (l *listener) Accept() (net.Conn, error) { return <-l.connChan, nil } -func (l *listener) Close() error { return nil } -func (l *listener) Addr() net.Addr { return nil } +func (l *middleListener) Accept() (net.Conn, error) { return <-l.connChan, nil } +func (l *middleListener) Close() error { return nil } +func (l *middleListener) Addr() net.Addr { return nil } type pipeAddr struct { remoteAddr string @@ -106,7 +106,7 @@ func NewMiddle(proxy *Proxy, caPath string) (Interceptor, error) { } m.Server = server - m.Listener = &listener{make(chan net.Conn)} + m.Listener = &middleListener{make(chan net.Conn)} return m, nil } @@ -155,7 +155,7 @@ func (m *Middle) intercept(serverConn *connBuf) { // tls serverConn.connContext.Client.Tls = true serverConn.connContext.InitHttpsServer(m.Proxy.Opts.SslInsecure) - m.Listener.(*listener).connChan <- serverConn + m.Listener.(*middleListener).connChan <- serverConn } else { // ws DefaultWebSocket.WS(serverConn, serverConn.host) diff --git a/proxy/proxy.go b/proxy/proxy.go index 92c3cf6..8cad148 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -6,10 +6,9 @@ import ( "io" "net" "net/http" - "time" + "sync" "github.com/lqqyt2423/go-mitmproxy/addon" - "github.com/lqqyt2423/go-mitmproxy/connection" "github.com/lqqyt2423/go-mitmproxy/flow" _log "github.com/sirupsen/logrus" ) @@ -30,8 +29,42 @@ type Proxy struct { Server *http.Server Interceptor Interceptor Addons []addon.Addon +} + +type proxyListener struct { + net.Listener + proxy *Proxy +} + +func (l *proxyListener) Accept() (net.Conn, error) { + c, err := l.Listener.Accept() + if err != nil { + return nil, err + } + + return &proxyConn{ + Conn: c, + proxy: l.proxy, + }, nil +} - activeConn map[net.Conn]*flow.ConnContext +type proxyConn struct { + net.Conn + proxy *Proxy + connCtx *flow.ConnContext + closeOnce sync.Once +} + +func (c *proxyConn) Close() error { + log.Debugln("in proxyConn close") + + c.closeOnce.Do(func() { + for _, addon := range c.proxy.Addons { + addon.ClientDisconnected(c.connCtx.Client) + } + }) + + return c.Conn.Close() } func NewProxy(opts *Options) (*Proxy, error) { @@ -40,29 +73,18 @@ func NewProxy(opts *Options) (*Proxy, error) { proxy.Version = "0.2.0" proxy.Server = &http.Server{ - Addr: opts.Addr, - Handler: proxy, - IdleTimeout: 5 * time.Second, + Addr: opts.Addr, + Handler: proxy, + // IdleTimeout: 5 * time.Second, ConnContext: func(ctx context.Context, c net.Conn) context.Context { - client := connection.NewClient(c) - connCtx := &flow.ConnContext{ - Client: client, + connCtx := flow.NewConnContext(c) + for _, addon := range proxy.Addons { + addon.ClientConnected(connCtx.Client) } - proxy.activeConn[c] = connCtx + c.(*proxyConn).connCtx = connCtx return context.WithValue(ctx, flow.ConnContextKey, connCtx) }, - - ConnState: func(c net.Conn, cs http.ConnState) { - if cs == http.StateNew { - client := proxy.activeConn[c].Client - for _, addon := range proxy.Addons { - addon.ClientConnected(client) - } - } else if cs == http.StateClosed { - proxy.whenClientConnClose(c) - } - }, } interceptor, err := NewMiddle(proxy, opts.CaRootPath) @@ -77,8 +99,6 @@ func NewProxy(opts *Options) (*Proxy, error) { proxy.Addons = make([]addon.Addon, 0) - proxy.activeConn = make(map[net.Conn]*flow.ConnContext) - return proxy, nil } @@ -91,7 +111,20 @@ func (proxy *Proxy) Start() error { go func() { log.Infof("Proxy start listen at %v\n", proxy.Server.Addr) - err := proxy.Server.ListenAndServe() + addr := proxy.Server.Addr + if addr == "" { + addr = ":http" + } + ln, err := net.Listen("tcp", addr) + if err != nil { + errChan <- err + return + } + pln := &proxyListener{ + Listener: ln, + proxy: proxy, + } + err = proxy.Server.Serve(pln) errChan <- err }() @@ -286,11 +319,10 @@ func (proxy *Proxy) handleConnect(res http.ResponseWriter, req *http.Request) { return } - cconn.(*net.TCPConn).SetLinger(0) // send RST other than FIN when finished, to avoid TIME_WAIT state - cconn.(*net.TCPConn).SetKeepAlive(false) + // cconn.(*net.TCPConn).SetLinger(0) // send RST other than FIN when finished, to avoid TIME_WAIT state + // cconn.(*net.TCPConn).SetKeepAlive(false) defer func() { cconn.Close() - proxy.whenClientConnClose(cconn) }() _, err = io.WriteString(cconn, "HTTP/1.1 200 Connection Established\r\n\r\n") @@ -301,15 +333,3 @@ func (proxy *Proxy) handleConnect(res http.ResponseWriter, req *http.Request) { Transfer(log, conn, cconn) } - -func (proxy *Proxy) whenClientConnClose(c net.Conn) { - connCtx := proxy.activeConn[c] - - for _, addon := range proxy.Addons { - addon.ClientDisconnected(connCtx.Client) - } - - connCtx.Server.Client.CloseIdleConnections() - - delete(proxy.activeConn, c) -}