Skip to content

Commit

Permalink
server:Implement CHUNKING extension (BDAT command)
Browse files Browse the repository at this point in the history
  • Loading branch information
foxcpp committed Jun 23, 2020
1 parent 451381b commit d9804b3
Show file tree
Hide file tree
Showing 3 changed files with 553 additions and 22 deletions.
280 changes: 259 additions & 21 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ type Conn struct {
session Session
locker sync.Mutex

bdatPipe *io.PipeWriter
bdatStatus *statusCollector // used for BDAT on LMTP
dataResult chan error
bytesReceived int // counts total size of chunks when BDAT is used

fromReceived bool
recipients []string
}
Expand Down Expand Up @@ -133,6 +138,8 @@ func (c *Conn) handle(cmd string, arg string) {
case "RSET": // Reset session
c.reset()
c.WriteResponse(250, EnhancedCode{2, 0, 0}, "Session reset")
case "BDAT":
c.handleBdat(arg)
case "DATA":
c.handleData(arg)
case "QUIT":
Expand Down Expand Up @@ -169,9 +176,17 @@ func (c *Conn) SetSession(session Session) {
}

func (c *Conn) Close() error {
if session := c.Session(); session != nil {
session.Logout()
c.SetSession(nil)
c.locker.Lock()
defer c.locker.Unlock()

if c.bdatPipe != nil {
c.bdatPipe.CloseWithError(ErrDataReset)
c.bdatPipe = nil
}

if c.session != nil {
c.session.Logout()
c.session = nil
}

return c.conn.Close()
Expand Down Expand Up @@ -604,7 +619,239 @@ func (c *Conn) handleData(arg string) {
r.limited = false
io.Copy(ioutil.Discard, r) // Make sure all the data has been consumed
c.WriteResponse(code, enhancedCode, msg)
}

// This is a version of io.CopyN with a subtle difference that it returns the number
// of messages read from src and not the amount of written bytes.
//
// This subtle difference is important for use with io.PipeWriter as a destination
// in handleBdat.
func copyN(dst io.Writer, src io.Reader, n int64) (consumed int64, err error) {
src = io.LimitReader(src, n)

size := 32 * 1024
if l, ok := src.(*io.LimitedReader); ok && int64(size) > l.N {
if l.N < 1 {
size = 1
} else {
size = int(l.N)
}
}
buf := make([]byte, size)

for {
nr, er := src.Read(buf)
if nr > 0 {
consumed += int64(nr)
nw, ew := dst.Write(buf[0:nr])
if ew != nil {
err = ew
break
}
if nr != nw {
err = io.ErrShortWrite
break
}
}
if er != nil {
if er != io.EOF {
err = er
}
break
}
}

if consumed < n && err == nil {
// src stopped early; must have been EOF.
err = io.EOF
}

return consumed, err
}

func (c *Conn) handleBdat(arg string) {
args := strings.Split(arg, " ")
if len(args) == 0 {
c.WriteResponse(501, EnhancedCode{5, 5, 4}, "Missing chunk size argument")
return
}

if !c.fromReceived || len(c.recipients) == 0 {
c.WriteResponse(502, EnhancedCode{5, 5, 1}, "Missing RCPT TO command.")
return
}

last := false
if len(args) == 2 {
if !strings.EqualFold(args[1], "LAST") {
c.WriteResponse(501, EnhancedCode{5, 5, 4}, "Unknown BDAT argument")
return
}
last = true
}

// ParseUint instead of Atoi so we will not accept negative values.
size, err := strconv.ParseUint(args[0], 10, 32)
if err != nil {
c.WriteResponse(501, EnhancedCode{5, 5, 4}, "Malformed size argument")
return
}

if c.bdatStatus == nil && c.server.LMTP {
c.bdatStatus = c.createStatusCollector()
}

if c.bdatPipe == nil {
var r *io.PipeReader
r, c.bdatPipe = io.Pipe()

c.dataResult = make(chan error)

go func() {
defer func() {
if err := recover(); err != nil {
c.handlePanic(err, c.bdatStatus)

select {
case <-c.dataResult:
c.dataResult <- errPanic
default:
}

r.CloseWithError(errPanic)
}
}()

var err error
if !c.server.LMTP {
err = c.Session().Data(r)
} else {
lmtpSession, ok := c.Session().(LMTPSession)
if !ok {
err = c.Session().Data(r)
for _, rcpt := range c.recipients {
c.bdatStatus.SetStatus(rcpt, err)
}
} else {
err = lmtpSession.LMTPData(r, c.bdatStatus)
}
}

select {
case <-c.dataResult:
// handleBdat goroutine wants us to sent an error here.
c.dataResult <- err
default:
}

if err != nil {
r.CloseWithError(err)
} else {
r.Close()
}
}()
}

// Using copyN that returns amount of bytes read from c.text.R instead of
// amount of bytes written is important here because read may succeed and
// write not in case of early error returned by Session.Data. In this case
// io.CopyN call below needs to calculate correct amount of bytes to
// discard.
consumed, err := copyN(c.bdatPipe, c.text.R, int64(size))
if err != nil {
// Backend might return an error without consuming the whole chunk.
io.CopyN(ioutil.Discard, c.text.R, int64(size)-consumed)

code, enhancedCode, msg := toSMTPStatus(err)
c.WriteResponse(code, enhancedCode, msg)

if err == errPanic {
c.Close()
}

return
}

c.bytesReceived += int(consumed)
if c.server.MaxMessageBytes != 0 && c.bytesReceived > c.server.MaxMessageBytes {
c.WriteResponse(552, EnhancedCode{5, 3, 4}, "Max message size exceeded")
return
}

if last {
c.bdatPipe.Close()

// At this point Session.Data might still be running and processing the
// message we need to a) wait for it before killing the session b)
// obtain the error if any.
//
// io.Pipe CloseWithError cannot be used at this point since we are not
// writing anymore. Therefore we signal that we want an error sent to this
// channel and then wait for it.
c.dataResult <- nil
err := <-c.dataResult

if c.server.LMTP {
c.bdatStatus.fillRemaining(err)
for i, rcpt := range c.recipients {
code, enchCode, msg := toSMTPStatus(<-c.bdatStatus.status[i])
c.WriteResponse(code, enchCode, "<"+rcpt+"> "+msg)
}
} else {
code, enhancedCode, msg := toSMTPStatus(err)
c.WriteResponse(code, enhancedCode, msg)
}

if err == errPanic {
c.Close()
return
}

c.reset()
} else {
c.WriteResponse(250, EnhancedCode{2, 0, 0}, "Continue")
}
}

// ErrDataReset is returned by Reader pased to Data function if client does not
// send another BDAT command and instead closes connection or issues RSET command.
var ErrDataReset = errors.New("smtp: message transmission aborted")

var errPanic = &SMTPError{
Code: 421,
EnhancedCode: EnhancedCode{4, 0, 0},
Message: "Internal server error",
}

func (c *Conn) handlePanic(err interface{}, status *statusCollector) {
if status != nil {
status.fillRemaining(errPanic)
}

stack := debug.Stack()
c.server.ErrorLog.Printf("panic serving %v: %v\n%s", c.State().RemoteAddr, err, stack)
}

func (c *Conn) createStatusCollector() *statusCollector {
rcptCounts := make(map[string]int, len(c.recipients))

status := &statusCollector{
statusMap: make(map[string]chan error, len(c.recipients)),
status: make([]chan error, 0, len(c.recipients)),
}
for _, rcpt := range c.recipients {
rcptCounts[rcpt]++
}
// Create channels with buffer sizes necessary to fit all
// statuses for a single recipient to avoid deadlocks.
for rcpt, count := range rcptCounts {
status.statusMap[rcpt] = make(chan error, count)
}
for _, rcpt := range c.recipients {
status.status = append(status.status, status.statusMap[rcpt])
}

return status
}

type statusCollector struct {
Expand Down Expand Up @@ -651,24 +898,7 @@ func (s *statusCollector) SetStatus(rcptTo string, err error) {

func (c *Conn) handleDataLMTP() {
r := newDataReader(c)

rcptCounts := make(map[string]int, len(c.recipients))

status := &statusCollector{
statusMap: make(map[string]chan error, len(c.recipients)),
status: make([]chan error, 0, len(c.recipients)),
}
for _, rcpt := range c.recipients {
rcptCounts[rcpt]++
}
// Create channels with buffer sizes necessary to fit all
// statuses for a single recipient to avoid deadlocks.
for rcpt, count := range rcptCounts {
status.statusMap[rcpt] = make(chan error, count)
}
for _, rcpt := range c.recipients {
status.status = append(status.status, status.statusMap[rcpt])
}
status := c.createStatusCollector()

done := make(chan bool, 1)

Expand Down Expand Up @@ -779,9 +1009,17 @@ func (c *Conn) reset() {
c.locker.Lock()
defer c.locker.Unlock()

if c.bdatPipe != nil {
c.bdatPipe.CloseWithError(ErrDataReset)
c.bdatPipe = nil
}
c.bdatStatus = nil
c.bytesReceived = 0

if c.session != nil {
c.session.Reset()
}

c.fromReceived = false
c.recipients = nil
}
2 changes: 1 addition & 1 deletion server.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func NewServer(be Backend) *Server {
Backend: be,
done: make(chan struct{}, 1),
ErrorLog: log.New(os.Stderr, "smtp/server ", log.LstdFlags),
caps: []string{"PIPELINING", "8BITMIME", "ENHANCEDSTATUSCODES"},
caps: []string{"PIPELINING", "8BITMIME", "ENHANCEDSTATUSCODES", "CHUNKING"},
auths: map[string]SaslServerFactory{
sasl.Plain: func(conn *Conn) sasl.Server {
return sasl.NewPlainServer(func(identity, username, password string) error {
Expand Down
Loading

0 comments on commit d9804b3

Please sign in to comment.