Skip to content

Commit

Permalink
server: Implement CHUNKING extension (BDAT command)
Browse files Browse the repository at this point in the history
  • Loading branch information
foxcpp committed Jul 7, 2020
1 parent 451381b commit 1447748
Show file tree
Hide file tree
Showing 3 changed files with 489 additions and 22 deletions.
216 changes: 195 additions & 21 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ type Conn struct {
session Session
locker sync.Mutex

bdatPipe *io.PipeWriter
bdatStatus *statusCollector // used for BDAT on LMTP
dataResult chan error
bytesReceived int // counts total size of chunks when BDAT is used

fromReceived bool
recipients []string
}
Expand Down Expand Up @@ -133,6 +138,8 @@ func (c *Conn) handle(cmd string, arg string) {
case "RSET": // Reset session
c.reset()
c.WriteResponse(250, EnhancedCode{2, 0, 0}, "Session reset")
case "BDAT":
c.handleBdat(arg)
case "DATA":
c.handleData(arg)
case "QUIT":
Expand Down Expand Up @@ -169,9 +176,17 @@ func (c *Conn) SetSession(session Session) {
}

func (c *Conn) Close() error {
if session := c.Session(); session != nil {
session.Logout()
c.SetSession(nil)
c.locker.Lock()
defer c.locker.Unlock()

if c.bdatPipe != nil {
c.bdatPipe.CloseWithError(ErrDataReset)
c.bdatPipe = nil
}

if c.session != nil {
c.session.Logout()
c.session = nil
}

return c.conn.Close()
Expand Down Expand Up @@ -604,7 +619,175 @@ func (c *Conn) handleData(arg string) {
r.limited = false
io.Copy(ioutil.Discard, r) // Make sure all the data has been consumed
c.WriteResponse(code, enhancedCode, msg)
}

func (c *Conn) handleBdat(arg string) {
args := strings.Fields(arg)
if len(args) == 0 {
c.WriteResponse(501, EnhancedCode{5, 5, 4}, "Missing chunk size argument")
return
}
if len(args) > 2 {
c.WriteResponse(501, EnhancedCode{5, 5, 4}, "Too many arguments")
return
}

if !c.fromReceived || len(c.recipients) == 0 {
c.WriteResponse(502, EnhancedCode{5, 5, 1}, "Missing RCPT TO command.")
return
}

last := false
if len(args) == 2 {
if !strings.EqualFold(args[1], "LAST") {
c.WriteResponse(501, EnhancedCode{5, 5, 4}, "Unknown BDAT argument")
return
}
last = true
}

// ParseUint instead of Atoi so we will not accept negative values.
size, err := strconv.ParseUint(args[0], 10, 32)
if err != nil {
c.WriteResponse(501, EnhancedCode{5, 5, 4}, "Malformed size argument")
return
}

if c.server.MaxMessageBytes != 0 && c.bytesReceived+int(size) > c.server.MaxMessageBytes {
c.WriteResponse(552, EnhancedCode{5, 3, 4}, "Max message size exceeded")

// Discard chunk itself without passing it to backend.
io.Copy(ioutil.Discard, io.LimitReader(c.text.R, int64(size)))

c.reset()
return
}

if c.bdatStatus == nil && c.server.LMTP {
c.bdatStatus = c.createStatusCollector()
}

if c.bdatPipe == nil {
var r *io.PipeReader
r, c.bdatPipe = io.Pipe()

c.dataResult = make(chan error, 1)

go func() {
defer func() {
if err := recover(); err != nil {
c.handlePanic(err, c.bdatStatus)

c.dataResult <- errPanic
r.CloseWithError(errPanic)
}
}()

var err error
if !c.server.LMTP {
err = c.Session().Data(r)
} else {
lmtpSession, ok := c.Session().(LMTPSession)
if !ok {
err = c.Session().Data(r)
for _, rcpt := range c.recipients {
c.bdatStatus.SetStatus(rcpt, err)
}
} else {
err = lmtpSession.LMTPData(r, c.bdatStatus)
}
}

c.dataResult <- err
r.CloseWithError(err)
}()
}

chunk := io.LimitReader(c.text.R, int64(size))
_, err = io.Copy(c.bdatPipe, chunk)
if err != nil {
// Backend might return an error early using CloseWithError without consuming
// the whole chunk.
io.Copy(ioutil.Discard, chunk)

c.WriteResponse(toSMTPStatus(err))

if err == errPanic {
c.Close()
}

c.reset()
return
}

c.bytesReceived += int(size)

if last {
c.bdatPipe.Close()

err := <-c.dataResult

if c.server.LMTP {
c.bdatStatus.fillRemaining(err)
for i, rcpt := range c.recipients {
code, enchCode, msg := toSMTPStatus(<-c.bdatStatus.status[i])
c.WriteResponse(code, enchCode, "<"+rcpt+"> "+msg)
}
} else {
code, enhancedCode, msg := toSMTPStatus(err)
c.WriteResponse(code, enhancedCode, msg)
}

if err == errPanic {
c.Close()
return
}

c.reset()
} else {
c.WriteResponse(250, EnhancedCode{2, 0, 0}, "Continue")
}
}

// ErrDataReset is returned by Reader pased to Data function if client does not
// send another BDAT command and instead closes connection or issues RSET command.
var ErrDataReset = errors.New("smtp: message transmission aborted")

var errPanic = &SMTPError{
Code: 421,
EnhancedCode: EnhancedCode{4, 0, 0},
Message: "Internal server error",
}

func (c *Conn) handlePanic(err interface{}, status *statusCollector) {
if status != nil {
status.fillRemaining(errPanic)
}

stack := debug.Stack()
c.server.ErrorLog.Printf("panic serving %v: %v\n%s", c.State().RemoteAddr, err, stack)
}

func (c *Conn) createStatusCollector() *statusCollector {
rcptCounts := make(map[string]int, len(c.recipients))

status := &statusCollector{
statusMap: make(map[string]chan error, len(c.recipients)),
status: make([]chan error, 0, len(c.recipients)),
}
for _, rcpt := range c.recipients {
rcptCounts[rcpt]++
}
// Create channels with buffer sizes necessary to fit all
// statuses for a single recipient to avoid deadlocks.
for rcpt, count := range rcptCounts {
status.statusMap[rcpt] = make(chan error, count)
}
for _, rcpt := range c.recipients {
status.status = append(status.status, status.statusMap[rcpt])
}

return status
}

type statusCollector struct {
Expand Down Expand Up @@ -651,24 +834,7 @@ func (s *statusCollector) SetStatus(rcptTo string, err error) {

func (c *Conn) handleDataLMTP() {
r := newDataReader(c)

rcptCounts := make(map[string]int, len(c.recipients))

status := &statusCollector{
statusMap: make(map[string]chan error, len(c.recipients)),
status: make([]chan error, 0, len(c.recipients)),
}
for _, rcpt := range c.recipients {
rcptCounts[rcpt]++
}
// Create channels with buffer sizes necessary to fit all
// statuses for a single recipient to avoid deadlocks.
for rcpt, count := range rcptCounts {
status.statusMap[rcpt] = make(chan error, count)
}
for _, rcpt := range c.recipients {
status.status = append(status.status, status.statusMap[rcpt])
}
status := c.createStatusCollector()

done := make(chan bool, 1)

Expand Down Expand Up @@ -779,9 +945,17 @@ func (c *Conn) reset() {
c.locker.Lock()
defer c.locker.Unlock()

if c.bdatPipe != nil {
c.bdatPipe.CloseWithError(ErrDataReset)
c.bdatPipe = nil
}
c.bdatStatus = nil
c.bytesReceived = 0

if c.session != nil {
c.session.Reset()
}

c.fromReceived = false
c.recipients = nil
}
2 changes: 1 addition & 1 deletion server.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func NewServer(be Backend) *Server {
Backend: be,
done: make(chan struct{}, 1),
ErrorLog: log.New(os.Stderr, "smtp/server ", log.LstdFlags),
caps: []string{"PIPELINING", "8BITMIME", "ENHANCEDSTATUSCODES"},
caps: []string{"PIPELINING", "8BITMIME", "ENHANCEDSTATUSCODES", "CHUNKING"},
auths: map[string]SaslServerFactory{
sasl.Plain: func(conn *Conn) sasl.Server {
return sasl.NewPlainServer(func(identity, username, password string) error {
Expand Down
Loading

0 comments on commit 1447748

Please sign in to comment.