Skip to content

Commit

Permalink
Refine BDAT implementation
Browse files Browse the repository at this point in the history
- Get rid of copyN rip-off and replace with io.LimitReader.
- Get rid of double-synchronization around dataResult.
- Check chunk size before copying bytes to backend.
- Reset connection if max. message size is exceeded
  • Loading branch information
foxcpp committed Jun 30, 2020
1 parent af56296 commit 998678d
Showing 1 changed file with 20 additions and 77 deletions.
97 changes: 20 additions & 77 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -649,60 +649,16 @@ func (c *Conn) handleData(arg string) {
c.WriteResponse(code, enhancedCode, msg)
}

// This is a version of io.CopyN with a subtle difference that it returns the number
// of messages read from src and not the amount of written bytes.
//
// This subtle difference is important for use with io.PipeWriter as a destination
// in handleBdat.
func copyN(dst io.Writer, src io.Reader, n int64) (consumed int64, err error) {
src = io.LimitReader(src, n)

size := 32 * 1024
if l, ok := src.(*io.LimitedReader); ok && int64(size) > l.N {
if l.N < 1 {
size = 1
} else {
size = int(l.N)
}
}
buf := make([]byte, size)

for {
nr, er := src.Read(buf)
if nr > 0 {
consumed += int64(nr)
nw, ew := dst.Write(buf[0:nr])
if ew != nil {
err = ew
break
}
if nr != nw {
err = io.ErrShortWrite
break
}
}
if er != nil {
if er != io.EOF {
err = er
}
break
}
}

if consumed < n && err == nil {
// src stopped early; must have been EOF.
err = io.EOF
}

return consumed, err
}

func (c *Conn) handleBdat(arg string) {
args := strings.Split(arg, " ")
args := strings.Fields(arg)
if len(args) == 0 {
c.WriteResponse(501, EnhancedCode{5, 5, 4}, "Missing chunk size argument")
return
}
if len(args) > 2 {
c.WriteResponse(501, EnhancedCode{5, 5, 4}, "Too many arguments")
return
}

if !c.fromReceived || len(c.recipients) == 0 {
c.WriteResponse(502, EnhancedCode{5, 5, 1}, "Missing RCPT TO command.")
Expand All @@ -725,6 +681,12 @@ func (c *Conn) handleBdat(arg string) {
return
}

if c.server.MaxMessageBytes != 0 && c.bytesReceived+int(size) > c.server.MaxMessageBytes {
c.WriteResponse(552, EnhancedCode{5, 3, 4}, "Max message size exceeded")
c.reset()
return
}

if c.bdatStatus == nil && c.server.LMTP {
c.bdatStatus = c.createStatusCollector()
}
Expand All @@ -733,7 +695,7 @@ func (c *Conn) handleBdat(arg string) {
var r *io.PipeReader
r, c.bdatPipe = io.Pipe()

c.dataResult = make(chan error)
c.dataResult = make(chan error, 1)

go func() {
defer func() {
Expand Down Expand Up @@ -765,30 +727,16 @@ func (c *Conn) handleBdat(arg string) {
}
}

select {
case <-c.dataResult:
// handleBdat goroutine wants us to sent an error here.
c.dataResult <- err
default:
}

if err != nil {
r.CloseWithError(err)
} else {
r.Close()
}
c.dataResult <- err
r.CloseWithError(err)
}()
}

// Using copyN that returns amount of bytes read from c.text.R instead of
// amount of bytes written is important here because read may succeed and
// write not in case of early error returned by Session.Data. In this case
// io.CopyN call below needs to calculate correct amount of bytes to
// discard.
consumed, err := copyN(c.bdatPipe, c.text.R, int64(size))
chunk := io.LimitedReader{R: c.text.R, N: int64(size)}
_, err = io.Copy(c.bdatPipe, &chunk)
if err != nil {
// Backend might return an error without consuming the whole chunk.
io.CopyN(ioutil.Discard, c.text.R, int64(size)-consumed)
io.Copy(ioutil.Discard, &chunk)

code, enhancedCode, msg := toSMTPStatus(err)
c.WriteResponse(code, enhancedCode, msg)
Expand All @@ -797,14 +745,11 @@ func (c *Conn) handleBdat(arg string) {
c.Close()
}

c.reset()
return
}

c.bytesReceived += int(consumed)
if c.server.MaxMessageBytes != 0 && c.bytesReceived > c.server.MaxMessageBytes {
c.WriteResponse(552, EnhancedCode{5, 3, 4}, "Max message size exceeded")
return
}
c.bytesReceived += int(size)

if last {
c.bdatPipe.Close()
Expand All @@ -814,9 +759,7 @@ func (c *Conn) handleBdat(arg string) {
// obtain the error if any.
//
// io.Pipe CloseWithError cannot be used at this point since we are not
// writing anymore. Therefore we signal that we want an error sent to this
// channel and then wait for it.
c.dataResult <- nil
// writing anymore.
err := <-c.dataResult

if c.server.LMTP {
Expand Down

0 comments on commit 998678d

Please sign in to comment.