diff --git a/proxy/connection.go b/proxy/connection.go index 34e9c40..cc8255f 100644 --- a/proxy/connection.go +++ b/proxy/connection.go @@ -157,31 +157,7 @@ func (connCtx *ConnContext) initServerTcpConn(req *http.Request) error { connCtx.ServerConn = ServerConn ServerConn.Address = connCtx.pipeConn.host - // test is use proxy - clientReq := &http.Request{URL: &url.URL{Scheme: "https", Host: ServerConn.Address}} - - var proxyUrl *url.URL - var err error - - if len(connCtx.proxy.Opts.Upstream) > 0 { - upstreamUrl, _ := url.Parse(connCtx.proxy.Opts.Upstream) - proxyUrl, err = http.ProxyURL(upstreamUrl)(clientReq) - if err != nil { - return err - } - } else { - proxyUrl, err = http.ProxyFromEnvironment(clientReq) - if err != nil { - return err - } - } - - var plainConn net.Conn - if proxyUrl != nil { - plainConn, err = getProxyConn(proxyUrl, ServerConn.Address) - } else { - plainConn, err = (&net.Dialer{}).DialContext(context.Background(), "tcp", ServerConn.Address) - } + plainConn, err := getConnFrom(req.Host, connCtx.proxy.Opts.Upstream) if err != nil { return err } @@ -393,3 +369,31 @@ func getProxyConn(proxyUrl *url.URL, address string) (net.Conn, error) { } return conn, nil } + +func getConnFrom(address string, upstream string) (net.Conn, error) { + clientReq := &http.Request{URL: &url.URL{Scheme: "https", Host: address}} + + var proxyUrl *url.URL + var err error + + if len(upstream) > 0 { + upstreamUrl, _ := url.Parse(upstream) + proxyUrl, err = http.ProxyURL(upstreamUrl)(clientReq) + if err != nil { + return nil, err + } + } else { + proxyUrl, err = http.ProxyFromEnvironment(clientReq) + if err != nil { + return nil, err + } + } + + var conn net.Conn + if proxyUrl != nil { + conn, err = getProxyConn(proxyUrl, address) + } else { + conn, err = (&net.Dialer{}).DialContext(context.Background(), "tcp", address) + } + return conn, err +} diff --git a/proxy/proxy.go b/proxy/proxy.go index ec6ea75..2fb218c 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -25,8 +25,9 @@ type Proxy struct { Version string Addons []Addon - server *http.Server - interceptor *middle + server *http.Server + interceptor *middle + shouldIntercept func(address string) bool } func NewProxy(opts *Options) (*Proxy, error) { @@ -282,7 +283,15 @@ func (proxy *Proxy) handleConnect(res http.ResponseWriter, req *http.Request) { "host": req.Host, }) - conn, err := proxy.interceptor.dial(req) + var conn net.Conn + var err error + if proxy.shouldIntercept == nil || proxy.shouldIntercept(req.Host) { + log.Debugf("begin intercept %v", req.Host) + conn, err = proxy.interceptor.dial(req) + } else { + log.Debugf("begin transpond %v", req.Host) + conn, err = getConnFrom(req.Host, proxy.Opts.Upstream) + } if err != nil { log.Error(err) res.WriteHeader(502) @@ -313,3 +322,7 @@ func (proxy *Proxy) handleConnect(res http.ResponseWriter, req *http.Request) { func (proxy *Proxy) GetCertificate() x509.Certificate { return proxy.interceptor.ca.RootCert } + +func (proxy *Proxy) SetShouldInterceptRule(rule func(address string) bool) { + proxy.shouldIntercept = rule +}