From 9920cf686b6c8f291bc3de0df78e759bb8dac5be Mon Sep 17 00:00:00 2001 From: kayrus Date: Wed, 17 Aug 2022 13:35:28 +0200 Subject: [PATCH] Properly handle DATA read timeouts --- conn.go | 15 ++++++++++++--- data.go | 25 +++++++++++++++++++++---- 2 files changed, 33 insertions(+), 7 deletions(-) diff --git a/conn.go b/conn.go index 72a67d8..3f926e2 100644 --- a/conn.go +++ b/conn.go @@ -628,18 +628,27 @@ func (c *Conn) handleData(arg string) { // We have recipients, go to accept data c.writeResponse(354, EnhancedCode{2, 0, 0}, "Go ahead. End your data with .") - defer c.reset() - if c.server.LMTP { c.handleDataLMTP() + c.reset() return } r := newDataReader(c) - code, enhancedCode, msg := toSMTPStatus(c.Session().Data(r)) + err := c.Session().Data(r) + code, enhancedCode, msg := toSMTPStatus(err) + if err == ErrDataTimeout { + // don't copy the data, write response and close the connection + c.writeResponse(code, enhancedCode, msg) + c.reset() + c.Close() + return + } + r.limited = false io.Copy(ioutil.Discard, r) // Make sure all the data has been consumed c.writeResponse(code, enhancedCode, msg) + c.reset() } func (c *Conn) handleBdat(arg string) { diff --git a/data.go b/data.go index c338455..9d43fb7 100644 --- a/data.go +++ b/data.go @@ -1,8 +1,9 @@ package smtp import ( - "bufio" "io" + "net" + "time" ) type EnhancedCode [3]int @@ -42,8 +43,14 @@ var ErrDataTooLarge = &SMTPError{ Message: "Maximum message size exceeded", } +var ErrDataTimeout = &SMTPError{ + Code: 451, + EnhancedCode: EnhancedCode{4, 4, 2}, + Message: "Timeout waiting for data from client", +} + type dataReader struct { - r *bufio.Reader + c *Conn state int limited bool @@ -52,7 +59,7 @@ type dataReader struct { func newDataReader(c *Conn) *dataReader { dr := &dataReader{ - r: c.text.R, + c: c, } if c.server.MaxMessageBytes > 0 { @@ -87,8 +94,18 @@ func (r *dataReader) Read(b []byte) (n int, err error) { stateEOF // reached .\r\n end marker line ) for n < len(b) && r.state != stateEOF { + if r.c.server.ReadTimeout != 0 { + err = r.c.conn.SetReadDeadline(time.Now().Add(r.c.server.ReadTimeout)) + if err != nil { + break + } + if e, ok := err.(net.Error); ok && e.Timeout() { + r.c.server.ErrorLog.Printf("data read timeout: %w", err) + err = ErrDataTimeout + } + } var c byte - c, err = r.r.ReadByte() + c, err = r.c.text.R.ReadByte() if err != nil { if err == io.EOF { err = io.ErrUnexpectedEOF