diff --git a/addon/web/conn.go b/addon/web/conn.go new file mode 100644 index 0000000..62b9d69 --- /dev/null +++ b/addon/web/conn.go @@ -0,0 +1,104 @@ +package web + +import ( + "encoding/json" + "sync" + + "github.com/gorilla/websocket" + "github.com/lqqyt2423/go-mitmproxy/flow" +) + +type concurrentConn struct { + conn *websocket.Conn + mu sync.Mutex + + waitChans map[string]chan interface{} + waitChansMu sync.Mutex +} + +func newConn(c *websocket.Conn) *concurrentConn { + return &concurrentConn{ + conn: c, + waitChans: make(map[string]chan interface{}), + } +} + +func (c *concurrentConn) writeMessage(msg *message, f *flow.Flow) { + c.mu.Lock() + err := c.conn.WriteMessage(websocket.BinaryMessage, msg.bytes()) + c.mu.Unlock() + if err != nil { + log.Error(err) + return + } + + c.waitIntercept(f, msg) +} + +func (c *concurrentConn) readloop() { + for { + mt, data, err := c.conn.ReadMessage() + if err != nil { + log.Error(err) + break + } + + if mt != websocket.BinaryMessage { + log.Warn("not BinaryMessage, skip") + continue + } + + msg := parseMessage(data) + if msg == nil { + log.Warn("parseMessage error, skip") + continue + } + + if msg.mType == messageTypeChangeRequest { + req := new(flow.Request) + err := json.Unmarshal(msg.content, req) + if err != nil { + log.Error(err) + continue + } + + ch := c.initWaitChan(msg.id.String()) + go func(req *flow.Request, ch chan<- interface{}) { + ch <- req + }(req, ch) + } + } +} + +func (c *concurrentConn) initWaitChan(key string) chan interface{} { + c.waitChansMu.Lock() + defer c.waitChansMu.Unlock() + + if ch, ok := c.waitChans[key]; ok { + return ch + } + ch := make(chan interface{}) + c.waitChans[key] = ch + return ch +} + +// 是否拦截 +func (c *concurrentConn) isIntercpt(f *flow.Flow, after *message) bool { + return false +} + +// 拦截 +func (c *concurrentConn) waitIntercept(f *flow.Flow, after *message) { + if !c.isIntercpt(f, after) { + return + } + + log.Infof("waiting Intercept: %s\n", f.Request.URL) + ch := c.initWaitChan(f.Id.String()) + req := (<-ch).(*flow.Request) + log.Infof("waited Intercept: %s\n", f.Request.URL) + + f.Request.Method = req.Method + f.Request.URL = req.URL + f.Request.Header = req.Header +} diff --git a/addon/web/message.go b/addon/web/message.go index 50f2ca1..b34a459 100644 --- a/addon/web/message.go +++ b/addon/web/message.go @@ -10,24 +10,54 @@ import ( const messageVersion = 1 +type messageType int + const ( - messageTypeRequest = 1 - messageTypeResponse = 2 - messageTypeResponseBody = 3 + messageTypeRequest messageType = 1 + messageTypeResponse messageType = 2 + messageTypeResponseBody messageType = 3 + + messageTypeChangeRequest messageType = 11 ) +func validMessageType(t byte) bool { + if t == byte(messageTypeRequest) || t == byte(messageTypeResponse) || t == byte(messageTypeResponseBody) || t == byte(messageTypeChangeRequest) { + return true + } + return false +} + type message struct { - messageType int - id uuid.UUID - content []byte + mType messageType + id uuid.UUID + content []byte } -func newMessage(messageType int, id uuid.UUID, content []byte) *message { +func newMessage(mType messageType, id uuid.UUID, content []byte) *message { return &message{ - messageType: messageType, - id: id, - content: content, + mType: mType, + id: id, + content: content, + } +} + +func parseMessage(data []byte) *message { + if len(data) < 38 { + return nil + } + if data[0] != messageVersion { + return nil } + if !validMessageType(data[1]) { + return nil + } + + id, err := uuid.FromString(string(data[2:38])) + if err != nil { + return nil + } + + return newMessage(messageType(data[1]), id, data[38:]) } func newMessageRequest(f *flow.Flow) *message { @@ -53,7 +83,7 @@ func newMessageResponseBody(f *flow.Flow) *message { func (m *message) bytes() []byte { buf := bytes.NewBuffer(make([]byte, 0)) buf.WriteByte(byte(messageVersion)) - buf.WriteByte(byte(m.messageType)) + buf.WriteByte(byte(m.mType)) buf.WriteString(m.id.String()) // len: 36 buf.Write(m.content) return buf.Bytes() diff --git a/addon/web/web.go b/addon/web/web.go index fec9c94..392d541 100644 --- a/addon/web/web.go +++ b/addon/web/web.go @@ -19,30 +19,14 @@ func (web *WebAddon) echo(w http.ResponseWriter, r *http.Request) { return } - web.addConn(c) + conn := newConn(c) + web.addConn(conn) defer func() { - web.removeConn(c) + web.removeConn(conn) c.Close() }() - for { - mt, message, err := c.ReadMessage() - if err != nil { - log.Println("read:", err) - break - } - log.Printf("recv: %s", message) - err = c.WriteMessage(mt, message) - if err != nil { - log.Println("write:", err) - break - } - } -} - -type concurrentConn struct { - conn *websocket.Conn - mu sync.Mutex + conn.readloop() } type WebAddon struct { @@ -82,19 +66,19 @@ func NewWebAddon() *WebAddon { return web } -func (web *WebAddon) addConn(c *websocket.Conn) { +func (web *WebAddon) addConn(c *concurrentConn) { web.connsMu.Lock() - web.conns = append(web.conns, &concurrentConn{conn: c}) + web.conns = append(web.conns, c) web.connsMu.Unlock() } -func (web *WebAddon) removeConn(conn *websocket.Conn) { +func (web *WebAddon) removeConn(conn *concurrentConn) { web.connsMu.Lock() defer web.connsMu.Unlock() index := -1 for i, c := range web.conns { - if conn == c.conn { + if conn == c { index = i break } @@ -106,37 +90,37 @@ func (web *WebAddon) removeConn(conn *websocket.Conn) { web.conns = append(web.conns[:index], web.conns[index+1:]...) } -func (web *WebAddon) sendFlow(msgFn func() *message) { +func (web *WebAddon) sendFlow(f *flow.Flow, msgFn func() *message) bool { web.connsMu.RLock() conns := web.conns web.connsMu.RUnlock() if len(conns) == 0 { - return + return false } msg := msgFn() for _, c := range conns { - c.mu.Lock() - c.conn.WriteMessage(websocket.BinaryMessage, msg.bytes()) - c.mu.Unlock() + c.writeMessage(msg, f) } + + return true } func (web *WebAddon) Request(f *flow.Flow) { - web.sendFlow(func() *message { + web.sendFlow(f, func() *message { return newMessageRequest(f) }) } func (web *WebAddon) Responseheaders(f *flow.Flow) { - web.sendFlow(func() *message { + web.sendFlow(f, func() *message { return newMessageResponse(f) }) } func (web *WebAddon) Response(f *flow.Flow) { - web.sendFlow(func() *message { + web.sendFlow(f, func() *message { return newMessageResponseBody(f) }) } diff --git a/flow/flow.go b/flow/flow.go index 9c2f6aa..7903f20 100644 --- a/flow/flow.go +++ b/flow/flow.go @@ -2,6 +2,7 @@ package flow import ( "encoding/json" + "errors" "net/http" "net/url" @@ -31,6 +32,54 @@ func (req *Request) MarshalJSON() ([]byte, error) { return json.Marshal(r) } +func (req *Request) UnmarshalJSON(data []byte) error { + r := make(map[string]interface{}) + err := json.Unmarshal(data, &r) + if err != nil { + return err + } + + rawurl, ok := r["url"].(string) + if !ok { + return errors.New("url parse error") + } + u, err := url.Parse(rawurl) + if err != nil { + return err + } + + rawheader, ok := r["header"].(map[string]interface{}) + if !ok { + return errors.New("rawheader parse error") + } + + header := make(map[string][]string) + for k, v := range rawheader { + vals, ok := v.([]interface{}) + if !ok { + return errors.New("header parse error") + } + + svals := make([]string, 0) + for _, val := range vals { + sval, ok := val.(string) + if !ok { + return errors.New("header parse error") + } + svals = append(svals, sval) + } + header[k] = svals + } + + *req = Request{ + Method: r["method"].(string), + URL: u, + Proto: r["proto"].(string), + Header: header, + } + return nil +} + func NewRequest(req *http.Request) *Request { return &Request{ Method: req.Method,