diff --git a/proxy/interceptor.go b/proxy/interceptor.go index bb546e1..7fef80d 100644 --- a/proxy/interceptor.go +++ b/proxy/interceptor.go @@ -2,14 +2,15 @@ package proxy import ( "net" + "net/http" ) // 拦截 https 流量通用接口 type Interceptor interface { // 初始化 Start() error - // 针对每个 host 连接 - Dial(host string) (net.Conn, error) + // 传入当前客户端 req + Dial(req *http.Request) (net.Conn, error) } // 直接转发 https 流量 @@ -19,6 +20,6 @@ func (i *Forward) Start() error { return nil } -func (i *Forward) Dial(host string) (net.Conn, error) { - return net.Dial("tcp", host) +func (i *Forward) Dial(req *http.Request) (net.Conn, error) { + return net.Dial("tcp", req.Host) } diff --git a/proxy/middle.go b/proxy/middle.go index 55984ef..26f656b 100644 --- a/proxy/middle.go +++ b/proxy/middle.go @@ -21,25 +21,34 @@ func (l *listener) Accept() (net.Conn, error) { return <-l.connChan, nil } func (l *listener) Close() error { return nil } func (l *listener) Addr() net.Addr { return nil } +type pipeAddr struct { + remoteAddr string +} + +func (pipeAddr) Network() string { return "pipe" } +func (a *pipeAddr) String() string { return a.remoteAddr } + // 建立客户端和服务端通信的通道 -func newPipes(host string) (net.Conn, *connBuf) { +func newPipes(req *http.Request) (net.Conn, *connBuf) { client, srv := net.Pipe() - server := newConnBuf(srv, host) + server := newConnBuf(srv, req) return client, server } // add Peek method for conn type connBuf struct { net.Conn - r *bufio.Reader - host string + r *bufio.Reader + host string + remoteAddr string } -func newConnBuf(c net.Conn, host string) *connBuf { +func newConnBuf(c net.Conn, req *http.Request) *connBuf { return &connBuf{ - Conn: c, - r: bufio.NewReader(c), - host: host, + Conn: c, + r: bufio.NewReader(c), + host: req.Host, + remoteAddr: req.RemoteAddr, } } @@ -51,6 +60,10 @@ func (b *connBuf) Read(data []byte) (int, error) { return b.r.Read(data) } +func (b *connBuf) RemoteAddr() net.Addr { + return &pipeAddr{remoteAddr: b.remoteAddr} +} + // Middle: man-in-the-middle type Middle struct { Proxy *Proxy @@ -91,8 +104,8 @@ func (m *Middle) Start() error { return m.Server.ServeTLS(m.Listener, "", "") } -func (m *Middle) Dial(host string) (net.Conn, error) { - clientConn, serverConn := newPipes(host) +func (m *Middle) Dial(req *http.Request) (net.Conn, error) { + clientConn, serverConn := newPipes(req) go m.intercept(serverConn) return clientConn, nil } diff --git a/proxy/proxy.go b/proxy/proxy.go index bb9a8a1..56761f6 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -262,7 +262,7 @@ func (proxy *Proxy) handleConnect(res http.ResponseWriter, req *http.Request) { log.Debug("receive connect") - conn, err := proxy.Interceptor.Dial(req.Host) + conn, err := proxy.Interceptor.Dial(req) if err != nil { log.Error(err) res.WriteHeader(502)