From 7f55903797a1dc87c241c967863a6382849edb0a Mon Sep 17 00:00:00 2001 From: lqqyt2423 <974923609@qq.com> Date: Fri, 17 Jun 2022 14:29:55 +0800 Subject: [PATCH] code refactoring --- cert/cert.go | 4 +- proxy/helper.go | 2 +- proxy/interceptor.go | 152 ++++++++++++++++++++++++++++++++++++++++ proxy/middle.go | 160 ------------------------------------------- proxy/proxy.go | 4 -- 5 files changed, 155 insertions(+), 167 deletions(-) delete mode 100644 proxy/middle.go diff --git a/cert/cert.go b/cert/cert.go index 9929abf..e097fdd 100644 --- a/cert/cert.go +++ b/cert/cert.go @@ -277,7 +277,7 @@ func (ca *CA) GetCert(commonName string) (*tls.Certificate, error) { ca.cacheMu.Lock() if val, ok := ca.cache.Get(commonName); ok { ca.cacheMu.Unlock() - log.WithField("commonName", commonName).Debug("GetCert") + log.Debugf("ca GetCert: %v", commonName) return val.(*tls.Certificate), nil } ca.cacheMu.Unlock() @@ -301,7 +301,7 @@ func (ca *CA) GetCert(commonName string) (*tls.Certificate, error) { // TODO: 是否应该支持多个 SubjectAltName func (ca *CA) DummyCert(commonName string) (*tls.Certificate, error) { - log.WithField("commonName", commonName).Debug("DummyCert") + log.Debugf("ca DummyCert: %v", commonName) template := &x509.Certificate{ SerialNumber: big.NewInt(time.Now().UnixNano() / 100000), Subject: pkix.Name{ diff --git a/proxy/helper.go b/proxy/helper.go index 793a1fb..7555fb7 100644 --- a/proxy/helper.go +++ b/proxy/helper.go @@ -104,7 +104,7 @@ func getTlsKeyLogWriter() io.Writer { writer, err := os.OpenFile(logfile, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0666) if err != nil { - log.WithField("in", "getTlsKeyLogWriter").Debug(err) + log.Debugf("getTlsKeyLogWriter OpenFile error: %v", err) return } diff --git a/proxy/interceptor.go b/proxy/interceptor.go index 620ccc0..4c8306d 100644 --- a/proxy/interceptor.go +++ b/proxy/interceptor.go @@ -1,8 +1,15 @@ package proxy import ( + "bufio" + "context" + "crypto/tls" "net" "net/http" + "strings" + + "github.com/lqqyt2423/go-mitmproxy/cert" + log "github.com/sirupsen/logrus" ) // 拦截 https 流量通用接口 @@ -23,3 +30,148 @@ func (i *forward) Start() error { func (i *forward) Dial(req *http.Request) (net.Conn, error) { return net.Dial("tcp", req.Host) } + +// 模拟了标准库中 server 运行,目的是仅通过当前进程内存转发 socket 数据,不需要经过 tcp 或 unix socket + +type pipeAddr struct { + remoteAddr string +} + +func (pipeAddr) Network() string { return "pipe" } +func (a *pipeAddr) String() string { return a.remoteAddr } + +// add Peek method for conn +type pipeConn struct { + net.Conn + r *bufio.Reader + host string + remoteAddr string + connContext *ConnContext +} + +func newPipeConn(c net.Conn, req *http.Request) *pipeConn { + return &pipeConn{ + Conn: c, + r: bufio.NewReader(c), + host: req.Host, + remoteAddr: req.RemoteAddr, + connContext: req.Context().Value(connContextKey).(*ConnContext), + } +} + +func (c *pipeConn) Peek(n int) ([]byte, error) { + return c.r.Peek(n) +} + +func (c *pipeConn) Read(data []byte) (int, error) { + return c.r.Read(data) +} + +func (c *pipeConn) RemoteAddr() net.Addr { + return &pipeAddr{remoteAddr: c.remoteAddr} +} + +// 建立客户端和服务端通信的通道 +func newPipes(req *http.Request) (net.Conn, *pipeConn) { + client, srv := net.Pipe() + server := newPipeConn(srv, req) + return client, server +} + +// mock net.Listener +type middleListener struct { + connChan chan net.Conn +} + +func (l *middleListener) Accept() (net.Conn, error) { return <-l.connChan, nil } +func (l *middleListener) Close() error { return nil } +func (l *middleListener) Addr() net.Addr { return nil } + +// middle: man-in-the-middle server +type middle struct { + proxy *Proxy + ca *cert.CA + listener *middleListener + server *http.Server +} + +func newMiddle(proxy *Proxy) (interceptor, error) { + ca, err := cert.NewCA(proxy.Opts.CaRootPath) + if err != nil { + return nil, err + } + + m := &middle{ + proxy: proxy, + ca: ca, + listener: &middleListener{ + connChan: make(chan net.Conn), + }, + } + + server := &http.Server{ + Handler: m, + ConnContext: func(ctx context.Context, c net.Conn) context.Context { + return context.WithValue(ctx, connContextKey, c.(*tls.Conn).NetConn().(*pipeConn).connContext) + }, + TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2 + TLSConfig: &tls.Config{ + GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) { + log.Debugf("middle GetCertificate ServerName: %v\n", chi.ServerName) + return ca.GetCert(chi.ServerName) + }, + }, + } + m.server = server + return m, nil +} + +func (m *middle) Start() error { + return m.server.ServeTLS(m.listener, "", "") +} + +// todo: should block until ServerConnected +func (m *middle) Dial(req *http.Request) (net.Conn, error) { + pipeClientConn, pipeServerConn := newPipes(req) + go m.intercept(pipeServerConn) + return pipeClientConn, nil +} + +func (m *middle) ServeHTTP(res http.ResponseWriter, req *http.Request) { + if strings.EqualFold(req.Header.Get("Connection"), "Upgrade") && strings.EqualFold(req.Header.Get("Upgrade"), "websocket") { + // wss + defaultWebSocket.wss(res, req) + return + } + + if req.URL.Scheme == "" { + req.URL.Scheme = "https" + } + if req.URL.Host == "" { + req.URL.Host = req.Host + } + m.proxy.ServeHTTP(res, req) +} + +// 解析 connect 流量 +// 如果是 tls 流量,则进入 listener.Accept => Middle.ServeHTTP +// 否则很可能是 ws 流量 +func (m *middle) intercept(pipeServerConn *pipeConn) { + buf, err := pipeServerConn.Peek(3) + if err != nil { + log.Errorf("Peek error: %v\n", err) + pipeServerConn.Close() + return + } + + // https://github.com/mitmproxy/mitmproxy/blob/main/mitmproxy/net/tls.py is_tls_record_magic + if buf[0] == 0x16 && buf[1] == 0x03 && buf[2] <= 0x03 { + // tls + pipeServerConn.connContext.ClientConn.Tls = true + pipeServerConn.connContext.initHttpsServerConn() + m.listener.connChan <- pipeServerConn + } else { + // ws + defaultWebSocket.ws(pipeServerConn, pipeServerConn.host) + } +} diff --git a/proxy/middle.go b/proxy/middle.go deleted file mode 100644 index 5dbc641..0000000 --- a/proxy/middle.go +++ /dev/null @@ -1,160 +0,0 @@ -package proxy - -import ( - "bufio" - "context" - "crypto/tls" - "net" - "net/http" - "strings" - - "github.com/lqqyt2423/go-mitmproxy/cert" - log "github.com/sirupsen/logrus" -) - -// 模拟了标准库中 server 运行,目的是仅通过当前进程内存转发 socket 数据,不需要经过 tcp 或 unix socket - -type pipeAddr struct { - remoteAddr string -} - -func (pipeAddr) Network() string { return "pipe" } -func (a *pipeAddr) String() string { return a.remoteAddr } - -// add Peek method for conn -type pipeConn struct { - net.Conn - r *bufio.Reader - host string - remoteAddr string - connContext *ConnContext -} - -func newPipeConn(c net.Conn, req *http.Request) *pipeConn { - return &pipeConn{ - Conn: c, - r: bufio.NewReader(c), - host: req.Host, - remoteAddr: req.RemoteAddr, - connContext: req.Context().Value(connContextKey).(*ConnContext), - } -} - -func (c *pipeConn) Peek(n int) ([]byte, error) { - return c.r.Peek(n) -} - -func (c *pipeConn) Read(data []byte) (int, error) { - return c.r.Read(data) -} - -func (c *pipeConn) RemoteAddr() net.Addr { - return &pipeAddr{remoteAddr: c.remoteAddr} -} - -// 建立客户端和服务端通信的通道 -func newPipes(req *http.Request) (net.Conn, *pipeConn) { - client, srv := net.Pipe() - server := newPipeConn(srv, req) - return client, server -} - -// mock net.Listener -type middleListener struct { - connChan chan net.Conn -} - -func (l *middleListener) Accept() (net.Conn, error) { return <-l.connChan, nil } -func (l *middleListener) Close() error { return nil } -func (l *middleListener) Addr() net.Addr { return nil } - -// middle: man-in-the-middle server -type middle struct { - proxy *Proxy - ca *cert.CA - listener *middleListener - server *http.Server -} - -func newMiddle(proxy *Proxy) (interceptor, error) { - ca, err := cert.NewCA(proxy.Opts.CaRootPath) - if err != nil { - return nil, err - } - - m := &middle{ - proxy: proxy, - ca: ca, - listener: &middleListener{ - connChan: make(chan net.Conn), - }, - } - - server := &http.Server{ - Handler: m, - ConnContext: func(ctx context.Context, c net.Conn) context.Context { - return context.WithValue(ctx, connContextKey, c.(*tls.Conn).NetConn().(*pipeConn).connContext) - }, - TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2 - TLSConfig: &tls.Config{ - GetCertificate: func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) { - log.Debugf("Middle GetCertificate ServerName: %v\n", chi.ServerName) - return ca.GetCert(chi.ServerName) - }, - }, - } - m.server = server - return m, nil -} - -func (m *middle) Start() error { - return m.server.ServeTLS(m.listener, "", "") -} - -// todo: should block until ServerConnected -func (m *middle) Dial(req *http.Request) (net.Conn, error) { - pipeClientConn, pipeServerConn := newPipes(req) - go m.intercept(pipeServerConn) - return pipeClientConn, nil -} - -func (m *middle) ServeHTTP(res http.ResponseWriter, req *http.Request) { - if strings.EqualFold(req.Header.Get("Connection"), "Upgrade") && strings.EqualFold(req.Header.Get("Upgrade"), "websocket") { - // wss - defaultWebSocket.wss(res, req) - return - } - - if req.URL.Scheme == "" { - req.URL.Scheme = "https" - } - if req.URL.Host == "" { - req.URL.Host = req.Host - } - m.proxy.ServeHTTP(res, req) -} - -// 解析 connect 流量 -// 如果是 tls 流量,则进入 listener.Accept => Middle.ServeHTTP -// 否则很可能是 ws 流量 -func (m *middle) intercept(pipeServerConn *pipeConn) { - log := log.WithField("in", "middle.intercept").WithField("host", pipeServerConn.host) - - buf, err := pipeServerConn.Peek(3) - if err != nil { - log.Errorf("Peek error: %v\n", err) - pipeServerConn.Close() - return - } - - // https://github.com/mitmproxy/mitmproxy/blob/main/mitmproxy/net/tls.py is_tls_record_magic - if buf[0] == 0x16 && buf[1] == 0x03 && buf[2] <= 0x03 { - // tls - pipeServerConn.connContext.ClientConn.Tls = true - pipeServerConn.connContext.initHttpsServerConn() - m.listener.connChan <- pipeServerConn - } else { - // ws - defaultWebSocket.ws(pipeServerConn, pipeServerConn.host) - } -} diff --git a/proxy/proxy.go b/proxy/proxy.go index 796e68e..ab0f641 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -107,8 +107,6 @@ func (proxy *Proxy) ServeHTTP(res http.ResponseWriter, req *http.Request) { "method": req.Method, }) - log.Debug("receive request") - if !req.URL.IsAbs() || req.URL.Host == "" { res.WriteHeader(400) _, err := io.WriteString(res, "此为代理服务器,不能直接发起请求") @@ -259,8 +257,6 @@ func (proxy *Proxy) handleConnect(res http.ResponseWriter, req *http.Request) { "host": req.Host, }) - log.Debug("receive connect") - conn, err := proxy.interceptor.Dial(req) if err != nil { log.Error(err)