Skip to content

Commit

Permalink
Properly handle DATA read timeouts
Browse files Browse the repository at this point in the history
  • Loading branch information
kayrus committed Aug 17, 2022
1 parent 42be6af commit 358ef5a
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 7 deletions.
15 changes: 12 additions & 3 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -673,18 +673,27 @@ func (c *Conn) handleData(arg string) {
// We have recipients, go to accept data
c.WriteResponse(354, EnhancedCode{2, 0, 0}, "Go ahead. End your data with <CR><LF>.<CR><LF>")

defer c.reset()

if c.server.LMTP {
c.handleDataLMTP()
c.reset()
return
}

r := newDataReader(c)
code, enhancedCode, msg := toSMTPStatus(c.Session().Data(r))
err := c.Session().Data(r)
code, enhancedCode, msg := toSMTPStatus(err)
if err == ErrDataTimeout {
// don't copy the data, write response and close the connection
c.WriteResponse(code, enhancedCode, msg)
c.reset()
c.Close()
return
}

r.limited = false
io.Copy(ioutil.Discard, r) // Make sure all the data has been consumed
c.WriteResponse(code, enhancedCode, msg)
c.reset()
}

func (c *Conn) handleBdat(arg string) {
Expand Down
25 changes: 21 additions & 4 deletions data.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package smtp

import (
"bufio"
"io"
"net"
"time"
)

type EnhancedCode [3]int
Expand Down Expand Up @@ -42,8 +43,14 @@ var ErrDataTooLarge = &SMTPError{
Message: "Maximum message size exceeded",
}

var ErrDataTimeout = &SMTPError{
Code: 451,
EnhancedCode: EnhancedCode{4, 4, 2},
Message: "Timeout waiting for data from client",
}

type dataReader struct {
r *bufio.Reader
c *Conn
state int

limited bool
Expand All @@ -52,7 +59,7 @@ type dataReader struct {

func newDataReader(c *Conn) *dataReader {
dr := &dataReader{
r: c.text.R,
c: c,
}

if c.server.MaxMessageBytes > 0 {
Expand Down Expand Up @@ -87,12 +94,22 @@ func (r *dataReader) Read(b []byte) (n int, err error) {
stateEOF // reached .\r\n end marker line
)
for n < len(b) && r.state != stateEOF {
if r.c.server.ReadTimeout != 0 {
err = r.c.conn.SetReadDeadline(time.Now().Add(r.c.server.ReadTimeout))
if err != nil {
break
}
}
var c byte
c, err = r.r.ReadByte()
c, err = r.c.text.R.ReadByte()
if err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
if e, ok := err.(net.Error); ok && e.Timeout() {
r.c.server.ErrorLog.Printf(r.c, "data read timeout: %w", err)
err = ErrDataTimeout
}
break
}
switch r.state {
Expand Down

0 comments on commit 358ef5a

Please sign in to comment.