diff --git a/backend.go b/backend.go index 3368f2d..fa85ddf 100644 --- a/backend.go +++ b/backend.go @@ -52,3 +52,23 @@ type Session interface { // Set currently processed message contents and send it. Data(r io.Reader) error } + +type LMTPSession interface { + // LMTPData is the LMTP-specific version of Data method. + // It can be optionally implemented by the backend to provide + // per-recipient status information when it is used over LMTP + // protocol. + // + // LMTPData implementation sets status information using passed + // StatusCollector by calling SetStatus once per each AddRcpt + // call, even if AddRcpt was called multiple times with + // the same argument. + // + // Return value of LMTPData itself is used as a status for + // recipients that got no status set before using StatusCollector. + LMTPData(r io.Reader, status StatusCollector) error +} + +type StatusCollector interface { + SetStatus(rcptTo string, err error) +} diff --git a/conn.go b/conn.go index 0c55e06..f404cae 100644 --- a/conn.go +++ b/conn.go @@ -488,40 +488,121 @@ func (c *Conn) handleData(arg string) { // We have recipients, go to accept data c.WriteResponse(354, EnhancedCode{2, 0, 0}, "Go ahead. End your data with .") - var ( - code int - enhancedCode EnhancedCode - msg string - ) + if c.server.LMTP { + c.handleDataLMTP() + return + } + r := newDataReader(c) - err := c.Session().Data(r) + code, enhancedCode, msg := toSMTPStatus(c.Session().Data(r)) io.Copy(ioutil.Discard, r) // Make sure all the data has been consumed - if err != nil { - if smtperr, ok := err.(*SMTPError); ok { - code = smtperr.Code - enhancedCode = smtperr.EnhancedCode - msg = smtperr.Message - } else { - code = 554 - enhancedCode = EnhancedCode{5, 0, 0} - msg = "Error: transaction failed, blame it on the weather: " + err.Error() + c.WriteResponse(code, enhancedCode, msg) + + c.reset() +} + +type statusCollector struct { + mapLock sync.Mutex + + // Contains map from recipient to list of channels that are used for that + // recipient. + // + // First SetStatus call uses first channel, second - second, etc. Channels + // that are already used are set to nil (otherwise we can accidentally + // reuse channel that was consumed by handleDataLMTP already). + statusMap map[string][]chan error + + // Contains channels from statusMap, in the same + // order as Conn.recipients. + status []chan error +} + +// fillRemaining sets status for all recipients SetStatus was not called for before. +func (s statusCollector) fillRemaining(err error) { + s.mapLock.Lock() + defer s.mapLock.Unlock() + + // Since used channels in statusMap are set to nil, we can simply send + // on all non-nil channels to fill statuses not set by LMTPData. + for _, chList := range s.statusMap { + for _, ch := range chList { + if ch == nil { + continue + } + ch <- err } - } else { - code = 250 - enhancedCode = EnhancedCode{2, 0, 0} - msg = "OK: queued" } +} - if c.server.LMTP { - // TODO: support per-recipient responses +func (s statusCollector) SetStatus(rcptTo string, err error) { + s.mapLock.Lock() + defer s.mapLock.Unlock() + + chList := s.statusMap[rcptTo] + if chList == nil { + panic("SetStatus is called for recipient that was not specified before") + } + + // Pick the first non-nil channel from list. + var usedCh chan error + for i, ch := range chList { + if ch != nil { + usedCh = ch + chList[i] = nil + break + } + } + if usedCh == nil { + panic("SetStatus is called more times than particular recipient was specified") + } + + usedCh <- err +} + +func (c *Conn) handleDataLMTP() { + r := newDataReader(c) + + status := statusCollector{ + statusMap: make(map[string][]chan error, len(c.recipients)), + status: make([]chan error, 0, len(c.recipients)), + } + for _, rcpt := range c.recipients { + ch := make(chan error, 1) + status.status = append(status.status, ch) + status.statusMap[rcpt] = append(status.statusMap[rcpt], ch) + } + + lmtpSession, ok := c.Session().(LMTPSession) + if !ok { + // Fallback to using a single status for all recipients. + err := c.Session().Data(r) + io.Copy(ioutil.Discard, r) // Make sure all the data has been consumed for _, rcpt := range c.recipients { - c.WriteResponse(code, enhancedCode, "<"+rcpt+"> "+msg) + status.SetStatus(rcpt, err) } } else { - c.WriteResponse(code, enhancedCode, msg) + go func() { + status.fillRemaining(lmtpSession.LMTPData(r, status)) + io.Copy(ioutil.Discard, r) // Make sure all the data has been consumed + }() } - c.reset() + for i, rcpt := range c.recipients { + code, enchCode, msg := toSMTPStatus(<-status.status[i]) + c.WriteResponse(code, enchCode, "<"+rcpt+"> "+msg) + } +} + +func toSMTPStatus(err error) (code int, enchCode EnhancedCode, msg string) { + if err != nil { + if smtperr, ok := err.(*SMTPError); ok { + return smtperr.Code, smtperr.EnhancedCode, smtperr.Message + } else { + return 554, EnhancedCode{5, 0, 0}, "Error: transaction failed, blame it on the weather: " + err.Error() + } + } + + return 250, EnhancedCode{2, 0, 0}, "OK: queued" } func (c *Conn) Reject() { diff --git a/lmtp_server_test.go b/lmtp_server_test.go new file mode 100644 index 0000000..9cd98c2 --- /dev/null +++ b/lmtp_server_test.go @@ -0,0 +1,197 @@ +package smtp_test + +import ( + "bufio" + "errors" + "io" + "strings" + "testing" + + "github.com/emersion/go-smtp" +) + +func sendDeliveryCmdsLMTP(t *testing.T, scanner *bufio.Scanner, c io.Writer) { + sendLHLO(t, scanner, c) + + io.WriteString(c, "MAIL FROM:\r\n") + scanner.Scan() + io.WriteString(c, "RCPT TO:\r\n") + scanner.Scan() + io.WriteString(c, "RCPT TO:\r\n") + scanner.Scan() + io.WriteString(c, "DATA\r\n") + scanner.Scan() + io.WriteString(c, "Hey <3\r\n") + io.WriteString(c, ".\r\n") +} + +func sendLHLO(t *testing.T, scanner *bufio.Scanner, c io.Writer) { + io.WriteString(c, "LHLO localhost\r\n") + scanner.Scan() + if scanner.Text() != "250-Hello localhost" { + t.Fatal("Invalid LHLO response:", scanner.Text()) + } + for scanner.Scan() { + s := scanner.Text() + + if strings.HasPrefix(s, "250 ") { + break + } else if !strings.HasPrefix(s, "250-") { + t.Fatal("Invalid capability response:", s) + } + } +} + +func TestServer_LMTP(t *testing.T) { + be, s, c, scanner := testServerGreeted(t, func(s *smtp.Server) { + s.LMTP = true + be := s.Backend.(*backend) + be.implementLMTPData = true + be.lmtpStatus = []struct { + addr string + err error + }{ + {"root@gchq.gov.uk", errors.New("nah")}, + {"root@bnd.bund.de", nil}, + } + }) + defer s.Close() + defer c.Close() + + sendDeliveryCmdsLMTP(t, scanner, c) + + scanner.Scan() + if !strings.HasPrefix(scanner.Text(), "554 5.0.0 ") { + t.Fatal("Invalid DATA first response:", scanner.Text()) + } + scanner.Scan() + if !strings.HasPrefix(scanner.Text(), "250 ") { + t.Fatal("Invalid DATA second response:", scanner.Text()) + } + + if len(be.messages) != 0 || len(be.anonmsgs) != 1 { + t.Fatal("Invalid number of sent messages:", be.messages, be.anonmsgs) + } +} + +func TestServer_LMTP_Early(t *testing.T) { + // This test confirms responses are sent as early as possible + // e.g. right after SetStatus is called. + + lmtpStatusSync := make(chan struct{}) + + be, s, c, scanner := testServerGreeted(t, func(s *smtp.Server) { + s.LMTP = true + be := s.Backend.(*backend) + be.implementLMTPData = true + be.lmtpStatusSync = lmtpStatusSync + be.lmtpStatus = []struct { + addr string + err error + }{ + {"root@gchq.gov.uk", errors.New("nah")}, + {"root@bnd.bund.de", nil}, + } + }) + defer s.Close() + defer c.Close() + + sendDeliveryCmdsLMTP(t, scanner, c) + + // Test backend sends to sync channel after calling SetStatus. + + scanner.Scan() + if !strings.HasPrefix(scanner.Text(), "554 5.0.0 ") { + t.Fatal("Invalid DATA first response:", scanner.Text()) + } + + <-be.lmtpStatusSync + + scanner.Scan() + if !strings.HasPrefix(scanner.Text(), "250 ") { + t.Fatal("Invalid DATA second response:", scanner.Text()) + } + + <-be.lmtpStatusSync + + if len(be.messages) != 0 || len(be.anonmsgs) != 1 { + t.Fatal("Invalid number of sent messages:", be.messages, be.anonmsgs) + } +} + +func TestServer_LMTP_Expand(t *testing.T) { + // This checks whether handleDataLMTP + // correctly expands results if backend doesn't + // implement LMTPSession. + + be, s, c, scanner := testServerGreeted(t, func(s *smtp.Server) { + s.LMTP = true + }) + defer s.Close() + defer c.Close() + + sendDeliveryCmdsLMTP(t, scanner, c) + + scanner.Scan() + if !strings.HasPrefix(scanner.Text(), "250 ") { + t.Fatal("Invalid DATA first response:", scanner.Text()) + } + scanner.Scan() + if !strings.HasPrefix(scanner.Text(), "250 ") { + t.Fatal("Invalid DATA second response:", scanner.Text()) + } + + if len(be.messages) != 0 || len(be.anonmsgs) != 1 { + t.Fatal("Invalid number of sent messages:", be.messages, be.anonmsgs) + } +} + +func TestServer_LMTP_DuplicatedRcpt(t *testing.T) { + be, s, c, scanner := testServerGreeted(t, func(s *smtp.Server) { + s.LMTP = true + be := s.Backend.(*backend) + be.implementLMTPData = true + be.lmtpStatus = []struct { + addr string + err error + }{ + {"root@gchq.gov.uk", &smtp.SMTPError{Code: 555}}, + {"root@bnd.bund.de", nil}, + {"root@gchq.gov.uk", &smtp.SMTPError{Code: 556}}, + } + }) + defer s.Close() + defer c.Close() + + sendLHLO(t, scanner, c) + + io.WriteString(c, "MAIL FROM:\r\n") + scanner.Scan() + io.WriteString(c, "RCPT TO:\r\n") + scanner.Scan() + io.WriteString(c, "RCPT TO:\r\n") + scanner.Scan() + io.WriteString(c, "RCPT TO:\r\n") + scanner.Scan() + io.WriteString(c, "DATA\r\n") + scanner.Scan() + io.WriteString(c, "Hey <3\r\n") + io.WriteString(c, ".\r\n") + + scanner.Scan() + if !strings.HasPrefix(scanner.Text(), "555 5.0.0 ") { + t.Fatal("Invalid DATA first response:", scanner.Text()) + } + scanner.Scan() + if !strings.HasPrefix(scanner.Text(), "250 ") { + t.Fatal("Invalid DATA second response:", scanner.Text()) + } + scanner.Scan() + if !strings.HasPrefix(scanner.Text(), "556 5.0.0 ") { + t.Fatal("Invalid DATA first response:", scanner.Text()) + } + + if len(be.messages) != 0 || len(be.anonmsgs) != 1 { + t.Fatal("Invalid number of sent messages:", be.messages, be.anonmsgs) + } +} diff --git a/server_test.go b/server_test.go index 8a9e57b..b684f48 100644 --- a/server_test.go +++ b/server_test.go @@ -23,6 +23,13 @@ type backend struct { messages []*message anonmsgs []*message + implementLMTPData bool + lmtpStatus []struct { + addr string + err error + } + lmtpStatusSync chan struct{} + panicOnMail bool userErr error } @@ -35,6 +42,11 @@ func (be *backend) Login(_ *smtp.ConnectionState, username, password string) (sm if username != "username" || password != "password" { return nil, errors.New("Invalid username or password") } + + if be.implementLMTPData { + return &lmtpSession{&session{backend: be}}, nil + } + return &session{backend: be}, nil } @@ -43,9 +55,17 @@ func (be *backend) AnonymousLogin(_ *smtp.ConnectionState) (smtp.Session, error) return &session{}, be.userErr } + if be.implementLMTPData { + return &lmtpSession{&session{backend: be, anonymous: true}}, nil + } + return &session{backend: be, anonymous: true}, nil } +type lmtpSession struct { + *session +} + type session struct { backend *backend anonymous bool @@ -89,6 +109,22 @@ func (s *session) Data(r io.Reader) error { return nil } +func (s *session) LMTPData(r io.Reader, collector smtp.StatusCollector) error { + if err := s.Data(r); err != nil { + return err + } + + for _, val := range s.backend.lmtpStatus { + collector.SetStatus(val.addr, val.err) + + if s.backend.lmtpStatusSync != nil { + s.backend.lmtpStatusSync <- struct{}{} + } + } + + return nil +} + type serverConfigureFunc func(*smtp.Server) var ( @@ -585,51 +621,3 @@ func TestStrictServerBad(t *testing.T) { t.Fatal("Invalid MAIL response:", scanner.Text()) } } - -func TestServer_lmtpOK(t *testing.T) { - be, s, c, scanner := testServerGreeted(t, func(s *smtp.Server) { - s.LMTP = true - }) - defer s.Close() - defer c.Close() - - io.WriteString(c, "LHLO localhost\r\n") - - scanner.Scan() - if scanner.Text() != "250-Hello localhost" { - t.Fatal("Invalid LHLO response:", scanner.Text()) - } - - for scanner.Scan() { - s := scanner.Text() - - if strings.HasPrefix(s, "250 ") { - break - } else if !strings.HasPrefix(s, "250-") { - t.Fatal("Invalid capability response:", s) - } - } - - io.WriteString(c, "MAIL FROM:\r\n") - scanner.Scan() - io.WriteString(c, "RCPT TO:\r\n") - scanner.Scan() - io.WriteString(c, "RCPT TO:\r\n") - scanner.Scan() - io.WriteString(c, "DATA\r\n") - scanner.Scan() - io.WriteString(c, "Hey <3\r\n") - io.WriteString(c, ".\r\n") - scanner.Scan() - - if !strings.HasPrefix(scanner.Text(), "250 ") { - t.Fatal("Invalid DATA first response:", scanner.Text()) - } - if !strings.HasPrefix(scanner.Text(), "250 ") { - t.Fatal("Invalid DATA second response:", scanner.Text()) - } - - if len(be.messages) != 0 || len(be.anonmsgs) != 1 { - t.Fatal("Invalid number of sent messages:", be.messages, be.anonmsgs) - } -}