https add server hook

addon-dailer
lqqyt2423 2 years ago
parent 46f165105d
commit 2e0c62a08b

@ -23,7 +23,7 @@ func NewConnContext(c net.Conn) *ConnContext {
}
}
func (connCtx *ConnContext) InitHttpServer(sslInsecure bool, connWrap func(net.Conn) net.Conn, whenServerConnected func()) {
func (connCtx *ConnContext) InitHttpServer(sslInsecure bool, connWrap func(net.Conn) net.Conn, whenConnected func()) {
if connCtx.Server != nil {
return
}
@ -48,7 +48,7 @@ func (connCtx *ConnContext) InitHttpServer(sslInsecure bool, connWrap func(net.C
cw := connWrap(c)
server.Conn = cw
defer whenServerConnected()
defer whenConnected()
return cw, nil
},
ForceAttemptHTTP2: false, // disable http2
@ -67,7 +67,7 @@ func (connCtx *ConnContext) InitHttpServer(sslInsecure bool, connWrap func(net.C
connCtx.Server = server
}
func (connCtx *ConnContext) InitHttpsServer(sslInsecure bool) {
func (connCtx *ConnContext) InitHttpsServer(sslInsecure bool, connWrap func(net.Conn) net.Conn, whenConnected func()) {
if connCtx.Server != nil {
return
}
@ -80,18 +80,33 @@ func (connCtx *ConnContext) InitHttpsServer(sslInsecure bool) {
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
// todo: change here
DialContext: (&net.Dialer{
// Timeout: 30 * time.Second,
// KeepAlive: 30 * time.Second,
}).DialContext,
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
log.Debugln("in https DialTLSContext")
plainConn, err := (&net.Dialer{}).DialContext(ctx, network, addr)
if err != nil {
return nil, err
}
cw := connWrap(plainConn)
server.Conn = cw
whenConnected()
firstTLSHost, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
cfg := &tls.Config{
InsecureSkipVerify: sslInsecure,
KeyLogWriter: GetTlsKeyLogWriter(),
ServerName: firstTLSHost,
}
tlsConn := tls.Client(cw, cfg)
return tlsConn, nil
},
ForceAttemptHTTP2: false, // disable http2
DisableCompression: true, // To get the original response from the server, set Transport.DisableCompression to true.
TLSClientConfig: &tls.Config{
InsecureSkipVerify: sslInsecure,
KeyLogWriter: GetTlsKeyLogWriter(),
},
},
CheckRedirect: func(req *http.Request, via []*http.Request) error {
// 禁止自动重定向

@ -32,14 +32,14 @@ func (pipeAddr) Network() string { return "pipe" }
func (a *pipeAddr) String() string { return a.remoteAddr }
// 建立客户端和服务端通信的通道
func newPipes(req *http.Request) (net.Conn, *connBuf) {
func newPipes(req *http.Request) (net.Conn, *pipeConn) {
client, srv := net.Pipe()
server := newConnBuf(srv, req)
server := newPipeConn(srv, req)
return client, server
}
// add Peek method for conn
type connBuf struct {
type pipeConn struct {
net.Conn
r *bufio.Reader
host string
@ -47,8 +47,8 @@ type connBuf struct {
connContext *flow.ConnContext
}
func newConnBuf(c net.Conn, req *http.Request) *connBuf {
return &connBuf{
func newPipeConn(c net.Conn, req *http.Request) *pipeConn {
return &pipeConn{
Conn: c,
r: bufio.NewReader(c),
host: req.Host,
@ -57,16 +57,16 @@ func newConnBuf(c net.Conn, req *http.Request) *connBuf {
}
}
func (b *connBuf) Peek(n int) ([]byte, error) {
return b.r.Peek(n)
func (c *pipeConn) Peek(n int) ([]byte, error) {
return c.r.Peek(n)
}
func (b *connBuf) Read(data []byte) (int, error) {
return b.r.Read(data)
func (c *pipeConn) Read(data []byte) (int, error) {
return c.r.Read(data)
}
func (b *connBuf) RemoteAddr() net.Addr {
return &pipeAddr{remoteAddr: b.remoteAddr}
func (c *pipeConn) RemoteAddr() net.Addr {
return &pipeAddr{remoteAddr: c.remoteAddr}
}
// Middle: man-in-the-middle
@ -93,7 +93,7 @@ func NewMiddle(proxy *Proxy, caPath string) (Interceptor, error) {
IdleTimeout: 5 * time.Second,
ConnContext: func(ctx context.Context, c net.Conn) context.Context {
return context.WithValue(ctx, flow.ConnContextKey, c.(*tls.Conn).NetConn().(*connBuf).connContext)
return context.WithValue(ctx, flow.ConnContextKey, c.(*tls.Conn).NetConn().(*pipeConn).connContext)
},
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), // disable http2
@ -116,9 +116,9 @@ func (m *Middle) Start() error {
}
func (m *Middle) Dial(req *http.Request) (net.Conn, error) {
clientConn, serverConn := newPipes(req)
go m.intercept(serverConn)
return clientConn, nil
pipeClientConn, pipeServerConn := newPipes(req)
go m.intercept(pipeServerConn)
return pipeClientConn, nil
}
func (m *Middle) ServeHTTP(res http.ResponseWriter, req *http.Request) {
@ -140,24 +140,38 @@ func (m *Middle) ServeHTTP(res http.ResponseWriter, req *http.Request) {
// 解析 connect 流量
// 如果是 tls 流量,则进入 listener.Accept => Middle.ServeHTTP
// 否则很可能是 ws 流量
func (m *Middle) intercept(serverConn *connBuf) {
log := log.WithField("in", "Middle.intercept").WithField("host", serverConn.host)
func (m *Middle) intercept(pipeServerConn *pipeConn) {
log := log.WithField("in", "Middle.intercept").WithField("host", pipeServerConn.host)
buf, err := serverConn.Peek(3)
buf, err := pipeServerConn.Peek(3)
if err != nil {
log.Errorf("Peek error: %v\n", err)
serverConn.Close()
pipeServerConn.Close()
return
}
// https://github.com/mitmproxy/mitmproxy/blob/main/mitmproxy/net/tls.py is_tls_record_magic
if buf[0] == 0x16 && buf[1] == 0x03 && buf[2] <= 0x03 {
// tls
serverConn.connContext.Client.Tls = true
serverConn.connContext.InitHttpsServer(m.Proxy.Opts.SslInsecure)
m.Listener.(*middleListener).connChan <- serverConn
pipeServerConn.connContext.Client.Tls = true
pipeServerConn.connContext.InitHttpsServer(
m.Proxy.Opts.SslInsecure,
func(c net.Conn) net.Conn {
return &serverConn{
Conn: c,
proxy: m.Proxy,
connCtx: pipeServerConn.connContext,
}
},
func() {
for _, addon := range m.Proxy.Addons {
addon.ServerConnected(pipeServerConn.connContext)
}
},
)
m.Listener.(*middleListener).connChan <- pipeServerConn
} else {
// ws
DefaultWebSocket.WS(serverConn, serverConn.host)
DefaultWebSocket.WS(pipeServerConn, pipeServerConn.host)
}
}

Loading…
Cancel
Save