use net.Pipe()

addon-dailer
lqqyt2423 4 years ago
parent 8ac3abcd39
commit 4d7e8aca25

@ -6,7 +6,6 @@ require (
github.com/andybalholm/brotli v1.0.1
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e
github.com/gorilla/websocket v1.4.2
github.com/jordwest/mock-conn v0.0.0-20180617021051-4896c6bd1641
github.com/satori/go.uuid v1.2.0
github.com/sirupsen/logrus v1.7.0
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect

@ -6,8 +6,6 @@ github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18h
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/jordwest/mock-conn v0.0.0-20180617021051-4896c6bd1641 h1:ChkB2s4mFDekyUUmbNE7qNhennP0rfqF2YZUOGxbhFk=
github.com/jordwest/mock-conn v0.0.0-20180617021051-4896c6bd1641/go.mod h1:AJFEOPtj5Z5z3MAy+0uvjQAH02iRnQr6fnvuHYp/Jek=
github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=

@ -39,22 +39,27 @@ func LogErr(log *_log.Entry, err error) (loged bool) {
// 转发流量
// Read a => Write b
// Read b => Write a
func Transfer(log *_log.Entry, a, b io.ReadWriter) {
func Transfer(log *_log.Entry, a, b io.ReadWriteCloser) {
done := make(chan struct{})
defer close(done)
forward := func(dst io.Writer, src io.Reader, ec chan<- error) {
forward := func(dst io.WriteCloser, src io.Reader, ec chan<- error) {
_, err := io.Copy(dst, src)
if v, ok := dst.(*conn); ok {
// 避免内存泄漏
_ = v.Writer.CloseWithError(nil)
if err != nil {
select {
case <-done:
return
case ec <- err:
return
}
}
err = dst.Close()
select {
case <-done:
return
case ec <- err:
return
}
}

@ -3,15 +3,10 @@ package proxy
import (
"bufio"
"crypto/tls"
"errors"
"net"
"net/http"
"os"
"strings"
"sync"
"time"
mock_conn "github.com/jordwest/mock-conn"
"github.com/lqqyt2423/go-mitmproxy/cert"
)
@ -26,123 +21,26 @@ func (l *listener) Accept() (net.Conn, error) { return <-l.connChan, nil }
func (l *listener) Close() error { return nil }
func (l *listener) Addr() net.Addr { return nil }
// mock net.Conn
type conn struct {
mock_conn.End
host string // remote host
// 以下为实现 SetReadDeadline 所需字段:需要确保 Read 方法可以提前返回
// connection: keep-alive 相关
readCanCancel bool // 是否可取消 Read
firstRead bool // 首次调用 Read 初始化
readErrChan chan error // Read 方法提前返回时的错误,总是 os.ErrDeadlineExceeded
readErr error // 底层 End 返回的错误
readDeadline time.Time // SetReadDeadline 设置的时间
chunk chan []byte // Read 和 beginRead 的交互 channel
readDeadlineMu sync.RWMutex
}
var connUnexpected = errors.New("unexpected read error")
// 建立客户端和服务端通信的通道
func newPipes(host string) (client *conn, server *connBuf) {
pipes := mock_conn.NewConn()
client = &conn{
End: *pipes.Client,
host: host,
}
serverConn := &conn{
End: *pipes.Server,
host: host,
readCanCancel: true,
readErrChan: make(chan error),
chunk: make(chan []byte),
}
server = newConnBuf(serverConn)
func newPipes(host string) (net.Conn, *connBuf) {
client, srv := net.Pipe()
server := newConnBuf(srv, host)
return client, server
}
func (c *conn) beginRead(size int) {
buf := make([]byte, size)
for {
n, err := c.End.Read(buf)
if err != nil {
c.readErr = err
close(c.chunk)
return
}
chunk := make([]byte, n)
copy(chunk, buf[:n])
c.chunk <- chunk
}
}
func (c *conn) Read(data []byte) (int, error) {
if !c.readCanCancel {
return c.End.Read(data)
}
if !c.firstRead {
go c.beginRead(len(data))
}
c.firstRead = true
c.readDeadlineMu.RLock()
if !c.readDeadline.Equal(time.Time{}) {
if !c.readDeadline.After(time.Now()) {
c.readDeadlineMu.RUnlock()
return 0, os.ErrDeadlineExceeded
} else {
c.readDeadlineMu.RUnlock()
log.WithField("host", c.host).Warnf("c.readDeadline is future %v\n", c.readDeadline)
return 0, connUnexpected
}
}
c.readDeadlineMu.RUnlock()
select {
case err := <-c.readErrChan:
return 0, err
case chunk, ok := <-c.chunk:
if !ok {
return 0, c.readErr
}
copy(data, chunk)
return len(chunk), nil
}
}
func (c *conn) SetDeadline(t time.Time) error {
log.WithField("host", c.host).Warnf("SetDeadline %v\n", t)
return connUnexpected
}
// http server 标准库实现时,当多个 http 复用底层 socket 时,会调用此方法
func (c *conn) SetReadDeadline(t time.Time) error {
c.readDeadlineMu.Lock()
c.readDeadline = t
c.readDeadlineMu.Unlock()
if !t.Equal(time.Time{}) && !t.After(time.Now()) {
c.readErrChan <- os.ErrDeadlineExceeded
}
return nil
}
func (c *conn) SetWriteDeadline(t time.Time) error {
log.WithField("host", c.host).Warnf("SetWriteDeadline %v\n", t)
return connUnexpected
}
// add Peek method for conn
type connBuf struct {
*conn
r *bufio.Reader
net.Conn
r *bufio.Reader
host string
}
func newConnBuf(c *conn) *connBuf {
return &connBuf{c, bufio.NewReader(c)}
func newConnBuf(c net.Conn, host string) *connBuf {
return &connBuf{
Conn: c,
r: bufio.NewReader(c),
host: host,
}
}
func (b *connBuf) Peek(n int) ([]byte, error) {

Loading…
Cancel
Save