From c41bf7c881a582f13c5d319ce87f245616f5cda6 Mon Sep 17 00:00:00 2001 From: lqqyt2423 <974923609@qq.com> Date: Fri, 11 Dec 2020 18:00:42 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E5=86=85=E5=AD=98=E6=B3=84?= =?UTF-8?q?=E6=BC=8F=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- proxy/proxy.go | 36 ++++++++++++++++++++++++++---------- 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/proxy/proxy.go b/proxy/proxy.go index 1266e9c..0592ecc 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -39,19 +39,35 @@ var ignoreErr = func(log *_log.Entry, err error) bool { func transfer(log *_log.Entry, a, b io.ReadWriter) { done := make(chan struct{}) - go func() { - _, err := io.Copy(a, b) - if err != nil && !ignoreErr(log, err) { - log.Error(err) + defer close(done) + + forward := func(dst io.Writer, src io.Reader, ec chan<- error) { + _, err := io.Copy(dst, src) + + if v, ok := dst.(*conn); ok { + // 避免内存泄漏的关键 + _ = v.Writer.CloseWithError(nil) } - close(done) - }() - _, err := io.Copy(b, a) - if err != nil && !ignoreErr(log, err) { - log.Error(err) + select { + case <-done: + return + case ec <- err: + } + } + + errChan := make(chan error) + go forward(a, b, errChan) + go forward(b, a, errChan) + + for i := 0; i < 2; i++ { + if err := <-errChan; err != nil { + if !ignoreErr(log, err) { + log.Error(err) + } + return // 如果有错误,直接返回 + } } - <-done } type Options struct {