diff --git a/README.md b/README.md index b673390..f5a5171 100644 --- a/README.md +++ b/README.md @@ -3,4 +3,5 @@ ## TODO - [x] http handler +- [x] http connect - [ ] https handler diff --git a/bin/main.go b/bin/main.go index b086a7a..8700a72 100644 --- a/bin/main.go +++ b/bin/main.go @@ -7,5 +7,8 @@ import ( ) func main() { - log.Fatal(proxy.NewProxy().Start()) + opts := &proxy.Options{ + Addr: ":8080", + } + log.Fatal(proxy.NewProxy(opts).Start()) } diff --git a/proxy.go b/proxy.go index b9da61b..bf80065 100644 --- a/proxy.go +++ b/proxy.go @@ -3,10 +3,15 @@ package proxy import ( "io" "log" + "net" "net/http" "time" ) +type Options struct { + Addr string +} + type Proxy struct { Server *http.Server } @@ -17,6 +22,11 @@ func (proxy *Proxy) Start() error { } func (proxy *Proxy) ServeHTTP(res http.ResponseWriter, req *http.Request) { + if req.Method == "CONNECT" { + proxy.handleConnect(res, req) + return + } + if !req.URL.IsAbs() || req.URL.Host == "" { res.WriteHeader(400) _, err := io.WriteString(res, "此为代理服务器,不能直接发起请求") @@ -28,37 +38,84 @@ func (proxy *Proxy) ServeHTTP(res http.ResponseWriter, req *http.Request) { start := time.Now() - proxyReq, _ := http.NewRequest(req.Method, req.URL.String(), req.Body) + proxyReq, err := http.NewRequest(req.Method, req.URL.String(), req.Body) + if err != nil { + log.Printf("error: %v", err) + res.WriteHeader(502) + return + } // TODO: handle Proxy- header for key, value := range req.Header { proxyReq.Header[key] = value } - proxyRes, _ := http.DefaultClient.Do(proxyReq) + proxyRes, err := http.DefaultClient.Do(proxyReq) + if err != nil { + log.Printf("error: %v", err) + res.WriteHeader(502) + return + } + defer proxyRes.Body.Close() for key, value := range proxyRes.Header { res.Header()[key] = value } res.WriteHeader(proxyRes.StatusCode) - _, err := io.Copy(res, proxyRes.Body) + _, err = io.Copy(res, proxyRes.Body) if err != nil { log.Printf("error: %v", err) return } - err = proxyRes.Body.Close() + log.Printf("%v %v %v - %v ms", req.Method, req.URL.String(), proxyRes.StatusCode, time.Since(start).Milliseconds()) +} + +func (proxy *Proxy) handleConnect(res http.ResponseWriter, req *http.Request) { + log.Printf("CONNECT: %v\n", req.Host) + + conn, err := net.Dial("tcp", req.Host) if err != nil { log.Printf("error: %v", err) + res.WriteHeader(502) return } + defer conn.Close() - log.Printf("%v %v %v - %v ms", req.Method, req.URL.String(), proxyRes.StatusCode, time.Since(start).Milliseconds()) + cconn, _, err := res.(http.Hijacker).Hijack() + if err != nil { + log.Printf("error: %v", err) + res.WriteHeader(502) + return + } + defer cconn.Close() + + _, err = io.WriteString(cconn, "HTTP/1.1 200 Connection Established\r\n\r\n") + if err != nil { + log.Printf("error: %v", err) + return + } + + ch := make(chan bool) + go func() { + _, err := io.Copy(conn, cconn) + if err != nil { + log.Printf("error: %v", err) + } + ch <- true + }() + + _, err = io.Copy(cconn, conn) + if err != nil { + log.Printf("error: %v", err) + } + + <-ch } -func NewProxy() *Proxy { +func NewProxy(opts *Options) *Proxy { proxy := new(Proxy) proxy.Server = &http.Server{ - Addr: ":8080", + Addr: opts.Addr, Handler: proxy, } return proxy