|
|
|
@ -1,10 +1,13 @@
|
|
|
|
|
package proxy
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"bufio"
|
|
|
|
|
"crypto/tls"
|
|
|
|
|
"net"
|
|
|
|
|
"net/http"
|
|
|
|
|
"net/http/httputil"
|
|
|
|
|
"os"
|
|
|
|
|
"strings"
|
|
|
|
|
"time"
|
|
|
|
|
|
|
|
|
|
mock_conn "github.com/jordwest/mock-conn"
|
|
|
|
@ -37,12 +40,14 @@ type ioRes struct {
|
|
|
|
|
type conn struct {
|
|
|
|
|
*mock_conn.End
|
|
|
|
|
|
|
|
|
|
Host string // remote host
|
|
|
|
|
readErrChan chan error // Read 方法提前返回时的错误
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func newConn(end *mock_conn.End) *conn {
|
|
|
|
|
func newConn(end *mock_conn.End, host string) *conn {
|
|
|
|
|
return &conn{
|
|
|
|
|
End: end,
|
|
|
|
|
Host: host,
|
|
|
|
|
readErrChan: make(chan error),
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -84,7 +89,9 @@ func (c *conn) Read(data []byte) (int, error) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *conn) SetDeadline(t time.Time) error {
|
|
|
|
|
log.Warnf("SetDeadline %v\n", t)
|
|
|
|
|
if !t.Equal(time.Time{}) {
|
|
|
|
|
log.WithField("host", c.Host).Warnf("SetDeadline %v\n", t)
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -103,10 +110,28 @@ func (c *conn) SetReadDeadline(t time.Time) error {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *conn) SetWriteDeadline(t time.Time) error {
|
|
|
|
|
log.Warnf("SetWriteDeadline %v\n", t)
|
|
|
|
|
log.WithField("host", c.Host).Warnf("SetWriteDeadline %v\n", t)
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// wrap conn for peek
|
|
|
|
|
type connBuf struct {
|
|
|
|
|
*conn
|
|
|
|
|
r *bufio.Reader
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func newConnBuf(c *conn) *connBuf {
|
|
|
|
|
return &connBuf{c, bufio.NewReader(c)}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (b *connBuf) Peek(n int) ([]byte, error) {
|
|
|
|
|
return b.r.Peek(n)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (b *connBuf) Read(data []byte) (int, error) {
|
|
|
|
|
return b.r.Read(data)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
type MitmMemory struct {
|
|
|
|
|
Proxy *Proxy
|
|
|
|
|
CA *cert.CA
|
|
|
|
@ -153,12 +178,85 @@ func (m *MitmMemory) Start() error {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (m *MitmMemory) Dial(host string) (net.Conn, error) {
|
|
|
|
|
log := log.WithField("in", "MitmMemory.Dial").WithField("host", host)
|
|
|
|
|
pipes := mock_conn.NewConn()
|
|
|
|
|
m.Listener.(*listener).connChan <- newConn(pipes.Server)
|
|
|
|
|
return newConn(pipes.Client), nil
|
|
|
|
|
|
|
|
|
|
// 如果是 tls 流量,则进入 listener.Accept => MitmMemory.ServeHTTP
|
|
|
|
|
// 否则很可能是 ws 流量,直接转发流量
|
|
|
|
|
go func() {
|
|
|
|
|
conn := newConn(pipes.Server, host)
|
|
|
|
|
connb := newConnBuf(conn)
|
|
|
|
|
buf, err := connb.Peek(3)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Errorf("Peek error: %v\n", err)
|
|
|
|
|
connb.Close()
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// tls
|
|
|
|
|
if buf[0] == 0x16 && buf[1] == 0x03 && (buf[2] >= 0x0 || buf[2] <= 0x03) {
|
|
|
|
|
m.Listener.(*listener).connChan <- connb
|
|
|
|
|
} else {
|
|
|
|
|
// websocket ws://
|
|
|
|
|
log.Debug("begin websocket ws://")
|
|
|
|
|
defer connb.Close()
|
|
|
|
|
remoteConn, err := net.Dial("tcp", host)
|
|
|
|
|
if err != nil {
|
|
|
|
|
if !ignoreErr(log, err) {
|
|
|
|
|
log.Error(err)
|
|
|
|
|
}
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
defer remoteConn.Close()
|
|
|
|
|
transfer(log, connb, remoteConn)
|
|
|
|
|
}
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
return newConn(pipes.Client, host), nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (m *MitmMemory) ServeHTTP(res http.ResponseWriter, req *http.Request) {
|
|
|
|
|
log := log.WithField("in", "MitmMemory.ServeHTTP").WithField("host", req.Host)
|
|
|
|
|
|
|
|
|
|
// websocket wss://
|
|
|
|
|
if strings.EqualFold(req.Header.Get("Connection"), "Upgrade") && strings.EqualFold(req.Header.Get("Upgrade"), "websocket") {
|
|
|
|
|
log.Debug("begin websocket wss://")
|
|
|
|
|
|
|
|
|
|
upgradeBuf, err := httputil.DumpRequest(req, false)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Errorf("DumpRequest: %v\n", err)
|
|
|
|
|
res.WriteHeader(502)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
cconn, _, err := res.(http.Hijacker).Hijack()
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Errorf("Hijack: %v\n", err)
|
|
|
|
|
res.WriteHeader(502)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
defer cconn.Close()
|
|
|
|
|
|
|
|
|
|
host := req.Host
|
|
|
|
|
if !strings.Contains(host, ":") {
|
|
|
|
|
host = host + ":443"
|
|
|
|
|
}
|
|
|
|
|
conn, err := tls.Dial("tcp", host, nil)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Errorf("tls.Dial: %v\n", err)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
defer conn.Close()
|
|
|
|
|
|
|
|
|
|
_, err = conn.Write(upgradeBuf)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Errorf("wss upgrade: %v\n", err)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
transfer(log, conn, cconn)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if req.URL.Scheme == "" {
|
|
|
|
|
req.URL.Scheme = "https"
|
|
|
|
|
}
|
|
|
|
|