Skip to content

Commit

Permalink
server: Implement per-recipient status for LMTP
Browse files Browse the repository at this point in the history
A separate interface is used so backend implementation will be aware
of LMTP usage. This also avoids clutter in simple backend
implementations that don't support LMTP.

RFC 2033 (Section 4.2) says that we should send multiple status lines if
the same address was specified multiple times. We move responsibility of
doing this to the backend implementation and don't do any deduplication
themselves. This allows transparent forwarding for LMTP to be
implemented correctly.
  • Loading branch information
foxcpp authored and emersion committed Nov 20, 2019
1 parent a1945ce commit b42ac39
Show file tree
Hide file tree
Showing 4 changed files with 358 additions and 72 deletions.
20 changes: 20 additions & 0 deletions backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,23 @@ type Session interface {
// Set currently processed message contents and send it.
Data(r io.Reader) error
}

type LMTPSession interface {
// LMTPData is the LMTP-specific version of Data method.
// It can be optionally implemented by the backend to provide
// per-recipient status information when it is used over LMTP
// protocol.
//
// LMTPData implementation sets status information using passed
// StatusCollector by calling SetStatus once per each AddRcpt
// call, even if AddRcpt was called multiple times with
// the same argument.
//
// Return value of LMTPData itself is used as a status for
// recipients that got no status set before using StatusCollector.
LMTPData(r io.Reader, status StatusCollector) error
}

type StatusCollector interface {
SetStatus(rcptTo string, err error)
}
129 changes: 105 additions & 24 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -488,40 +488,121 @@ func (c *Conn) handleData(arg string) {
// We have recipients, go to accept data
c.WriteResponse(354, EnhancedCode{2, 0, 0}, "Go ahead. End your data with <CR><LF>.<CR><LF>")

var (
code int
enhancedCode EnhancedCode
msg string
)
if c.server.LMTP {
c.handleDataLMTP()
return
}

r := newDataReader(c)
err := c.Session().Data(r)
code, enhancedCode, msg := toSMTPStatus(c.Session().Data(r))
io.Copy(ioutil.Discard, r) // Make sure all the data has been consumed
if err != nil {
if smtperr, ok := err.(*SMTPError); ok {
code = smtperr.Code
enhancedCode = smtperr.EnhancedCode
msg = smtperr.Message
} else {
code = 554
enhancedCode = EnhancedCode{5, 0, 0}
msg = "Error: transaction failed, blame it on the weather: " + err.Error()
c.WriteResponse(code, enhancedCode, msg)

c.reset()
}

type statusCollector struct {
mapLock sync.Mutex

// Contains map from recipient to list of channels that are used for that
// recipient.
//
// First SetStatus call uses first channel, second - second, etc. Channels
// that are already used are set to nil (otherwise we can accidentally
// reuse channel that was consumed by handleDataLMTP already).
statusMap map[string][]chan error

// Contains channels from statusMap, in the same
// order as Conn.recipients.
status []chan error
}

// fillRemaining sets status for all recipients SetStatus was not called for before.
func (s statusCollector) fillRemaining(err error) {
s.mapLock.Lock()
defer s.mapLock.Unlock()

// Since used channels in statusMap are set to nil, we can simply send
// on all non-nil channels to fill statuses not set by LMTPData.
for _, chList := range s.statusMap {
for _, ch := range chList {
if ch == nil {
continue
}
ch <- err
}
} else {
code = 250
enhancedCode = EnhancedCode{2, 0, 0}
msg = "OK: queued"
}
}

if c.server.LMTP {
// TODO: support per-recipient responses
func (s statusCollector) SetStatus(rcptTo string, err error) {
s.mapLock.Lock()
defer s.mapLock.Unlock()

chList := s.statusMap[rcptTo]
if chList == nil {
panic("SetStatus is called for recipient that was not specified before")
}

// Pick the first non-nil channel from list.
var usedCh chan error
for i, ch := range chList {
if ch != nil {
usedCh = ch
chList[i] = nil
break
}
}
if usedCh == nil {
panic("SetStatus is called more times than particular recipient was specified")
}

usedCh <- err
}

func (c *Conn) handleDataLMTP() {
r := newDataReader(c)

status := statusCollector{
statusMap: make(map[string][]chan error, len(c.recipients)),
status: make([]chan error, 0, len(c.recipients)),
}
for _, rcpt := range c.recipients {
ch := make(chan error, 1)
status.status = append(status.status, ch)
status.statusMap[rcpt] = append(status.statusMap[rcpt], ch)
}

lmtpSession, ok := c.Session().(LMTPSession)
if !ok {
// Fallback to using a single status for all recipients.
err := c.Session().Data(r)
io.Copy(ioutil.Discard, r) // Make sure all the data has been consumed
for _, rcpt := range c.recipients {
c.WriteResponse(code, enhancedCode, "<"+rcpt+"> "+msg)
status.SetStatus(rcpt, err)
}
} else {
c.WriteResponse(code, enhancedCode, msg)
go func() {
status.fillRemaining(lmtpSession.LMTPData(r, status))
io.Copy(ioutil.Discard, r) // Make sure all the data has been consumed
}()
}

c.reset()
for i, rcpt := range c.recipients {
code, enchCode, msg := toSMTPStatus(<-status.status[i])
c.WriteResponse(code, enchCode, "<"+rcpt+"> "+msg)
}
}

func toSMTPStatus(err error) (code int, enchCode EnhancedCode, msg string) {
if err != nil {
if smtperr, ok := err.(*SMTPError); ok {
return smtperr.Code, smtperr.EnhancedCode, smtperr.Message
} else {
return 554, EnhancedCode{5, 0, 0}, "Error: transaction failed, blame it on the weather: " + err.Error()
}
}

return 250, EnhancedCode{2, 0, 0}, "OK: queued"
}

func (c *Conn) Reject() {
Expand Down
197 changes: 197 additions & 0 deletions lmtp_server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
package smtp_test

import (
"bufio"
"errors"
"io"
"strings"
"testing"

"github.com/emersion/go-smtp"
)

func sendDeliveryCmdsLMTP(t *testing.T, scanner *bufio.Scanner, c io.Writer) {
sendLHLO(t, scanner, c)

io.WriteString(c, "MAIL FROM:<[email protected]>\r\n")
scanner.Scan()
io.WriteString(c, "RCPT TO:<[email protected]>\r\n")
scanner.Scan()
io.WriteString(c, "RCPT TO:<[email protected]>\r\n")
scanner.Scan()
io.WriteString(c, "DATA\r\n")
scanner.Scan()
io.WriteString(c, "Hey <3\r\n")
io.WriteString(c, ".\r\n")
}

func sendLHLO(t *testing.T, scanner *bufio.Scanner, c io.Writer) {
io.WriteString(c, "LHLO localhost\r\n")
scanner.Scan()
if scanner.Text() != "250-Hello localhost" {
t.Fatal("Invalid LHLO response:", scanner.Text())
}
for scanner.Scan() {
s := scanner.Text()

if strings.HasPrefix(s, "250 ") {
break
} else if !strings.HasPrefix(s, "250-") {
t.Fatal("Invalid capability response:", s)
}
}
}

func TestServer_LMTP(t *testing.T) {
be, s, c, scanner := testServerGreeted(t, func(s *smtp.Server) {
s.LMTP = true
be := s.Backend.(*backend)
be.implementLMTPData = true
be.lmtpStatus = []struct {
addr string
err error
}{
{"[email protected]", errors.New("nah")},
{"[email protected]", nil},
}
})
defer s.Close()
defer c.Close()

sendDeliveryCmdsLMTP(t, scanner, c)

scanner.Scan()
if !strings.HasPrefix(scanner.Text(), "554 5.0.0 <[email protected]>") {
t.Fatal("Invalid DATA first response:", scanner.Text())
}
scanner.Scan()
if !strings.HasPrefix(scanner.Text(), "250 ") {
t.Fatal("Invalid DATA second response:", scanner.Text())
}

if len(be.messages) != 0 || len(be.anonmsgs) != 1 {
t.Fatal("Invalid number of sent messages:", be.messages, be.anonmsgs)
}
}

func TestServer_LMTP_Early(t *testing.T) {
// This test confirms responses are sent as early as possible
// e.g. right after SetStatus is called.

lmtpStatusSync := make(chan struct{})

be, s, c, scanner := testServerGreeted(t, func(s *smtp.Server) {
s.LMTP = true
be := s.Backend.(*backend)
be.implementLMTPData = true
be.lmtpStatusSync = lmtpStatusSync
be.lmtpStatus = []struct {
addr string
err error
}{
{"[email protected]", errors.New("nah")},
{"[email protected]", nil},
}
})
defer s.Close()
defer c.Close()

sendDeliveryCmdsLMTP(t, scanner, c)

// Test backend sends to sync channel after calling SetStatus.

scanner.Scan()
if !strings.HasPrefix(scanner.Text(), "554 5.0.0 <[email protected]>") {
t.Fatal("Invalid DATA first response:", scanner.Text())
}

<-be.lmtpStatusSync

scanner.Scan()
if !strings.HasPrefix(scanner.Text(), "250 ") {
t.Fatal("Invalid DATA second response:", scanner.Text())
}

<-be.lmtpStatusSync

if len(be.messages) != 0 || len(be.anonmsgs) != 1 {
t.Fatal("Invalid number of sent messages:", be.messages, be.anonmsgs)
}
}

func TestServer_LMTP_Expand(t *testing.T) {
// This checks whether handleDataLMTP
// correctly expands results if backend doesn't
// implement LMTPSession.

be, s, c, scanner := testServerGreeted(t, func(s *smtp.Server) {
s.LMTP = true
})
defer s.Close()
defer c.Close()

sendDeliveryCmdsLMTP(t, scanner, c)

scanner.Scan()
if !strings.HasPrefix(scanner.Text(), "250 ") {
t.Fatal("Invalid DATA first response:", scanner.Text())
}
scanner.Scan()
if !strings.HasPrefix(scanner.Text(), "250 ") {
t.Fatal("Invalid DATA second response:", scanner.Text())
}

if len(be.messages) != 0 || len(be.anonmsgs) != 1 {
t.Fatal("Invalid number of sent messages:", be.messages, be.anonmsgs)
}
}

func TestServer_LMTP_DuplicatedRcpt(t *testing.T) {
be, s, c, scanner := testServerGreeted(t, func(s *smtp.Server) {
s.LMTP = true
be := s.Backend.(*backend)
be.implementLMTPData = true
be.lmtpStatus = []struct {
addr string
err error
}{
{"[email protected]", &smtp.SMTPError{Code: 555}},
{"[email protected]", nil},
{"[email protected]", &smtp.SMTPError{Code: 556}},
}
})
defer s.Close()
defer c.Close()

sendLHLO(t, scanner, c)

io.WriteString(c, "MAIL FROM:<[email protected]>\r\n")
scanner.Scan()
io.WriteString(c, "RCPT TO:<[email protected]>\r\n")
scanner.Scan()
io.WriteString(c, "RCPT TO:<[email protected]>\r\n")
scanner.Scan()
io.WriteString(c, "RCPT TO:<[email protected]>\r\n")
scanner.Scan()
io.WriteString(c, "DATA\r\n")
scanner.Scan()
io.WriteString(c, "Hey <3\r\n")
io.WriteString(c, ".\r\n")

scanner.Scan()
if !strings.HasPrefix(scanner.Text(), "555 5.0.0 <[email protected]>") {
t.Fatal("Invalid DATA first response:", scanner.Text())
}
scanner.Scan()
if !strings.HasPrefix(scanner.Text(), "250 ") {
t.Fatal("Invalid DATA second response:", scanner.Text())
}
scanner.Scan()
if !strings.HasPrefix(scanner.Text(), "556 5.0.0 <[email protected]>") {
t.Fatal("Invalid DATA first response:", scanner.Text())
}

if len(be.messages) != 0 || len(be.anonmsgs) != 1 {
t.Fatal("Invalid number of sent messages:", be.messages, be.anonmsgs)
}
}
Loading

0 comments on commit b42ac39

Please sign in to comment.