From 998678df76df56e8cfbbe155067c3e275981d7f3 Mon Sep 17 00:00:00 2001 From: "fox.cpp" Date: Tue, 30 Jun 2020 21:21:46 +0300 Subject: [PATCH] Refine BDAT implementation - Get rid of copyN rip-off and replace with io.LimitReader. - Get rid of double-synchronization around dataResult. - Check chunk size before copying bytes to backend. - Reset connection if max. message size is exceeded --- conn.go | 97 ++++++++++++--------------------------------------------- 1 file changed, 20 insertions(+), 77 deletions(-) diff --git a/conn.go b/conn.go index 9a8041c..524c680 100644 --- a/conn.go +++ b/conn.go @@ -649,60 +649,16 @@ func (c *Conn) handleData(arg string) { c.WriteResponse(code, enhancedCode, msg) } -// This is a version of io.CopyN with a subtle difference that it returns the number -// of messages read from src and not the amount of written bytes. -// -// This subtle difference is important for use with io.PipeWriter as a destination -// in handleBdat. -func copyN(dst io.Writer, src io.Reader, n int64) (consumed int64, err error) { - src = io.LimitReader(src, n) - - size := 32 * 1024 - if l, ok := src.(*io.LimitedReader); ok && int64(size) > l.N { - if l.N < 1 { - size = 1 - } else { - size = int(l.N) - } - } - buf := make([]byte, size) - - for { - nr, er := src.Read(buf) - if nr > 0 { - consumed += int64(nr) - nw, ew := dst.Write(buf[0:nr]) - if ew != nil { - err = ew - break - } - if nr != nw { - err = io.ErrShortWrite - break - } - } - if er != nil { - if er != io.EOF { - err = er - } - break - } - } - - if consumed < n && err == nil { - // src stopped early; must have been EOF. - err = io.EOF - } - - return consumed, err -} - func (c *Conn) handleBdat(arg string) { - args := strings.Split(arg, " ") + args := strings.Fields(arg) if len(args) == 0 { c.WriteResponse(501, EnhancedCode{5, 5, 4}, "Missing chunk size argument") return } + if len(args) > 2 { + c.WriteResponse(501, EnhancedCode{5, 5, 4}, "Too many arguments") + return + } if !c.fromReceived || len(c.recipients) == 0 { c.WriteResponse(502, EnhancedCode{5, 5, 1}, "Missing RCPT TO command.") @@ -725,6 +681,12 @@ func (c *Conn) handleBdat(arg string) { return } + if c.server.MaxMessageBytes != 0 && c.bytesReceived+int(size) > c.server.MaxMessageBytes { + c.WriteResponse(552, EnhancedCode{5, 3, 4}, "Max message size exceeded") + c.reset() + return + } + if c.bdatStatus == nil && c.server.LMTP { c.bdatStatus = c.createStatusCollector() } @@ -733,7 +695,7 @@ func (c *Conn) handleBdat(arg string) { var r *io.PipeReader r, c.bdatPipe = io.Pipe() - c.dataResult = make(chan error) + c.dataResult = make(chan error, 1) go func() { defer func() { @@ -765,30 +727,16 @@ func (c *Conn) handleBdat(arg string) { } } - select { - case <-c.dataResult: - // handleBdat goroutine wants us to sent an error here. - c.dataResult <- err - default: - } - - if err != nil { - r.CloseWithError(err) - } else { - r.Close() - } + c.dataResult <- err + r.CloseWithError(err) }() } - // Using copyN that returns amount of bytes read from c.text.R instead of - // amount of bytes written is important here because read may succeed and - // write not in case of early error returned by Session.Data. In this case - // io.CopyN call below needs to calculate correct amount of bytes to - // discard. - consumed, err := copyN(c.bdatPipe, c.text.R, int64(size)) + chunk := io.LimitedReader{R: c.text.R, N: int64(size)} + _, err = io.Copy(c.bdatPipe, &chunk) if err != nil { // Backend might return an error without consuming the whole chunk. - io.CopyN(ioutil.Discard, c.text.R, int64(size)-consumed) + io.Copy(ioutil.Discard, &chunk) code, enhancedCode, msg := toSMTPStatus(err) c.WriteResponse(code, enhancedCode, msg) @@ -797,14 +745,11 @@ func (c *Conn) handleBdat(arg string) { c.Close() } + c.reset() return } - c.bytesReceived += int(consumed) - if c.server.MaxMessageBytes != 0 && c.bytesReceived > c.server.MaxMessageBytes { - c.WriteResponse(552, EnhancedCode{5, 3, 4}, "Max message size exceeded") - return - } + c.bytesReceived += int(size) if last { c.bdatPipe.Close() @@ -814,9 +759,7 @@ func (c *Conn) handleBdat(arg string) { // obtain the error if any. // // io.Pipe CloseWithError cannot be used at this point since we are not - // writing anymore. Therefore we signal that we want an error sent to this - // channel and then wait for it. - c.dataResult <- nil + // writing anymore. err := <-c.dataResult if c.server.LMTP {