diff --git a/README.md b/README.md index 2afa14b..a63838f 100644 --- a/README.md +++ b/README.md @@ -65,22 +65,20 @@ import ( // The Backend implements SMTP server methods. type Backend struct{} -// Login handles a login command with username and password. -func (bkd *Backend) Login(state *smtp.ConnectionState, username, password string) (smtp.Session, error) { - if username != "username" || password != "password" { - return nil, errors.New("Invalid username or password") - } +func (bkd *Backend) NewSession(_ smtp.ConnectionState, _ string) (smtp.Session, error) { return &Session{}, nil } -// AnonymousLogin requires clients to authenticate using SMTP AUTH before sending emails -func (bkd *Backend) AnonymousLogin(state *smtp.ConnectionState) (smtp.Session, error) { - return nil, smtp.ErrAuthRequired -} - -// A Session is returned after successful login. +// A Session is returned after EHLO. type Session struct{} +func (s *Session) AuthPlain(username, password string) error { + if username != "username" || password != "password" { + return errors.New("Invalid username or password") + } + return nil +} + func (s *Session) Mail(from string, opts smtp.MailOptions) error { log.Println("Mail from:", from) return nil diff --git a/backend.go b/backend.go index 6a00873..e4de6af 100644 --- a/backend.go +++ b/backend.go @@ -1,24 +1,24 @@ package smtp import ( - "errors" "io" ) var ( - ErrAuthRequired = errors.New("Please authenticate first") - ErrAuthUnsupported = errors.New("Authentication not supported") -) + ErrAuthRequired = &SMTPError{ + Code: 502, + EnhancedCode: EnhancedCode{5, 7, 0}, + Message: "Please authenticate first", + } + ErrAuthUnsupported = &SMTPError{ + Code: 502, + EnhancedCode: EnhancedCode{5, 7, 0}, + Message: "Authentication not supported", + }) // A SMTP server backend. type Backend interface { - // Authenticate a user. Return smtp.ErrAuthUnsupported if you don't want to - // support this. - Login(state *ConnectionState, username, password string) (Session, error) - - // Called if the client attempts to send mail without logging in first. - // Return smtp.ErrAuthRequired if you don't want to support this. - AnonymousLogin(state *ConnectionState) (Session, error) + NewSession(c ConnectionState, hostname string) (Session, error) } type BodyType string @@ -68,6 +68,9 @@ type Session interface { // Free all resources associated with session. Logout() error + // Authenticate the user using SASL PLAIN. + AuthPlain(username, password string) error + // Set return path for currently processed message. Mail(from string, opts MailOptions) error // Add recipient for currently processed message. diff --git a/backendutil/transform.go b/backendutil/transform.go index a5ff666..55a7da2 100755 --- a/backendutil/transform.go +++ b/backendutil/transform.go @@ -15,22 +15,12 @@ type TransformBackend struct { TransformData func(r io.Reader) (io.Reader, error) } -// Login implements the smtp.Backend interface. -func (be *TransformBackend) Login(state *smtp.ConnectionState, username, password string) (smtp.Session, error) { - s, err := be.Backend.Login(state, username, password) +func (be *TransformBackend) NewSession(c smtp.ConnectionState, hostname string) (smtp.Session, error) { + sess, err := be.Backend.NewSession(c, hostname) if err != nil { return nil, err } - return &transformSession{s, be}, nil -} - -// AnonymousLogin implements the smtp.Backend interface. -func (be *TransformBackend) AnonymousLogin(state *smtp.ConnectionState) (smtp.Session, error) { - s, err := be.Backend.AnonymousLogin(state) - if err != nil { - return nil, err - } - return &transformSession{s, be}, nil + return &transformSession{Session: sess, be: be}, nil } type transformSession struct { @@ -43,6 +33,10 @@ func (s *transformSession) Reset() { s.Session.Reset() } +func (s *transformSession) AuthPlain(username, password string) error { + return s.Session.AuthPlain(username, password) +} + func (s *transformSession) Mail(from string, opts smtp.MailOptions) error { if s.be.TransformMail != nil { var err error diff --git a/backendutil/transform_test.go b/backendutil/transform_test.go index 3c77638..5a85865 100755 --- a/backendutil/transform_test.go +++ b/backendutil/transform_test.go @@ -29,22 +29,7 @@ type backend struct { userErr error } -func (be *backend) Login(_ *smtp.ConnectionState, username, password string) (smtp.Session, error) { - if be.userErr != nil { - return &session{}, be.userErr - } - - if username != "username" || password != "password" { - return nil, errors.New("Invalid username or password") - } - return &session{backend: be}, nil -} - -func (be *backend) AnonymousLogin(_ *smtp.ConnectionState) (smtp.Session, error) { - if be.userErr != nil { - return &session{}, be.userErr - } - +func (be *backend) NewSession(c smtp.ConnectionState, hostname string) (smtp.Session, error) { return &session{backend: be, anonymous: true}, nil } @@ -63,7 +48,18 @@ func (s *session) Logout() error { return nil } +func (s *session) AuthPlain(username, password string) error { + if username != "username" || password != "password" { + return errors.New("Invalid username or password") + } + s.anonymous = false + return nil +} + func (s *session) Mail(from string, opts smtp.MailOptions) error { + if s.backend.userErr != nil { + return s.backend.userErr + } s.Reset() s.msg.From = from return nil diff --git a/cmd/smtp-debug-server/main.go b/cmd/smtp-debug-server/main.go index 28d9843..8c4c65b 100644 --- a/cmd/smtp-debug-server/main.go +++ b/cmd/smtp-debug-server/main.go @@ -17,16 +17,16 @@ func init() { type backend struct{} -func (bkd *backend) Login(state *smtp.ConnectionState, username, password string) (smtp.Session, error) { - return &session{}, nil -} - -func (bkd *backend) AnonymousLogin(state *smtp.ConnectionState) (smtp.Session, error) { +func (bkd *backend) NewSession(c smtp.ConnectionState, hostname string) (smtp.Session, error) { return &session{}, nil } type session struct{} +func (s *session) AuthPlain(username, password string) error { + return nil +} + func (s *session) Mail(from string, opts smtp.MailOptions) error { return nil } diff --git a/conn.go b/conn.go index f033423..2e79465 100644 --- a/conn.go +++ b/conn.go @@ -235,56 +235,60 @@ func (c *Conn) protocolError(code int, ec EnhancedCode, msg string) { // GREET state -> waiting for HELO func (c *Conn) handleGreet(enhanced bool, arg string) { - if !enhanced { - domain, err := parseHelloArgument(arg) - if err != nil { - c.WriteResponse(501, EnhancedCode{5, 5, 2}, "Domain/address argument required for HELO") - return - } - c.helo = domain + domain, err := parseHelloArgument(arg) + if err != nil { + c.WriteResponse(501, EnhancedCode{5, 5, 2}, "Domain/address argument required for HELO") + return + } + c.helo = domain - c.WriteResponse(250, EnhancedCode{2, 0, 0}, fmt.Sprintf("Hello %s", domain)) - } else { - domain, err := parseHelloArgument(arg) - if err != nil { - c.WriteResponse(501, EnhancedCode{5, 5, 2}, "Domain/address argument required for EHLO") + sess, err := c.server.Backend.NewSession(c.State(), domain) + if err != nil { + if smtpErr, ok := err.(*SMTPError); ok { + c.WriteResponse(smtpErr.Code, smtpErr.EnhancedCode, smtpErr.Message) return } + c.WriteResponse(451, EnhancedCode{4, 0, 0}, err.Error()) + return + } + c.SetSession(sess) - c.helo = domain - - caps := []string{} - caps = append(caps, c.server.caps...) - if _, isTLS := c.TLSConnectionState(); c.server.TLSConfig != nil && !isTLS { - caps = append(caps, "STARTTLS") - } - if c.authAllowed() { - authCap := "AUTH" - for name := range c.server.auths { - authCap += " " + name - } + if !enhanced { + c.WriteResponse(250, EnhancedCode{2, 0, 0}, fmt.Sprintf("Hello %s", domain)) + return + } - caps = append(caps, authCap) - } - if c.server.EnableSMTPUTF8 { - caps = append(caps, "SMTPUTF8") - } - if _, isTLS := c.TLSConnectionState(); isTLS && c.server.EnableREQUIRETLS { - caps = append(caps, "REQUIRETLS") - } - if c.server.EnableBINARYMIME { - caps = append(caps, "BINARYMIME") - } - if c.server.MaxMessageBytes > 0 { - caps = append(caps, fmt.Sprintf("SIZE %v", c.server.MaxMessageBytes)) - } else { - caps = append(caps, "SIZE") + caps := []string{} + caps = append(caps, c.server.caps...) + if _, isTLS := c.TLSConnectionState(); c.server.TLSConfig != nil && !isTLS { + caps = append(caps, "STARTTLS") + } + if c.authAllowed() { + authCap := "AUTH" + for name := range c.server.auths { + authCap += " " + name } - args := []string{"Hello " + domain} - args = append(args, caps...) - c.WriteResponse(250, NoEnhancedCode, args...) + caps = append(caps, authCap) + } + if c.server.EnableSMTPUTF8 { + caps = append(caps, "SMTPUTF8") + } + if _, isTLS := c.TLSConnectionState(); isTLS && c.server.EnableREQUIRETLS { + caps = append(caps, "REQUIRETLS") + } + if c.server.EnableBINARYMIME { + caps = append(caps, "BINARYMIME") + } + if c.server.MaxMessageBytes > 0 { + caps = append(caps, fmt.Sprintf("SIZE %v", c.server.MaxMessageBytes)) + } else { + caps = append(caps, "SIZE") } + + args := []string{"Hello " + domain} + args = append(args, caps...) + c.WriteResponse(250, NoEnhancedCode, args...) } // READY state -> waiting for MAIL @@ -298,21 +302,6 @@ func (c *Conn) handleMail(arg string) { return } - if c.Session() == nil { - state := c.State() - session, err := c.server.Backend.AnonymousLogin(&state) - if err != nil { - if smtpErr, ok := err.(*SMTPError); ok { - c.WriteResponse(smtpErr.Code, smtpErr.EnhancedCode, smtpErr.Message) - } else { - c.WriteResponse(502, EnhancedCode{5, 7, 0}, err.Error()) - } - return - } - - c.SetSession(session) - } - if len(arg) < 6 || strings.ToUpper(arg[0:5]) != "FROM:" { c.WriteResponse(501, EnhancedCode{5, 5, 2}, "Was expecting MAIL arg syntax of FROM:
") return diff --git a/example_test.go b/example_test.go index 607c971..65a829b 100644 --- a/example_test.go +++ b/example_test.go @@ -92,22 +92,22 @@ func ExampleSendMail() { // The Backend implements SMTP server methods. type Backend struct{} -// Login handles a login command with username and password. -func (bkd *Backend) Login(state *smtp.ConnectionState, username, password string) (smtp.Session, error) { - if username != "username" || password != "password" { - return nil, errors.New("Invalid username or password") - } +// NewSession is called after client greeting (EHLO, HELO). +func (bkd *Backend) NewSession(c smtp.ConnectionState, hostname string) (smtp.Session, error) { return &Session{}, nil } -// AnonymousLogin requires clients to authenticate using SMTP AUTH before sending emails -func (bkd *Backend) AnonymousLogin(state *smtp.ConnectionState) (smtp.Session, error) { - return nil, smtp.ErrAuthRequired -} - // A Session is returned after successful login. type Session struct{} +// AuthPlain implements authentication using SASL PLAIN. +func (s *Session) AuthPlain(username, password string) error { + if username != "username" || password != "password" { + return errors.New("Invalid username or password") + } + return nil +} + func (s *Session) Mail(from string, opts smtp.MailOptions) error { log.Println("Mail from:", from) return nil diff --git a/server.go b/server.go index a7c37c9..a1cb677 100644 --- a/server.go +++ b/server.go @@ -90,14 +90,12 @@ func NewServer(be Backend) *Server { return errors.New("Identities not supported") } - state := conn.State() - session, err := be.Login(&state, username, password) - if err != nil { - return err + sess := conn.Session() + if sess == nil { + panic("No session when AUTH is called") } - conn.SetSession(session) - return nil + return sess.AuthPlain(username, password) }) }, }, diff --git a/server_test.go b/server_test.go index f2bfe9f..be90c01 100644 --- a/server_test.go +++ b/server_test.go @@ -44,27 +44,7 @@ type backend struct { userErr error } -func (be *backend) Login(_ *smtp.ConnectionState, username, password string) (smtp.Session, error) { - if be.userErr != nil { - return &session{}, be.userErr - } - - if username != "username" || password != "password" { - return nil, errors.New("Invalid username or password") - } - - if be.implementLMTPData { - return &lmtpSession{&session{backend: be}}, nil - } - - return &session{backend: be}, nil -} - -func (be *backend) AnonymousLogin(_ *smtp.ConnectionState) (smtp.Session, error) { - if be.userErr != nil { - return &session{}, be.userErr - } - +func (be *backend) NewSession(_ smtp.ConnectionState, _ string) (smtp.Session, error) { if be.implementLMTPData { return &lmtpSession{&session{backend: be, anonymous: true}}, nil } @@ -83,6 +63,14 @@ type session struct { msg *message } +func (s *session) AuthPlain(username, password string) error { + if username != "username" || password != "password" { + return errors.New("Invalid username or password") + } + s.anonymous = false + return nil +} + func (s *session) Reset() { s.msg = &message{} } @@ -92,6 +80,9 @@ func (s *session) Logout() error { } func (s *session) Mail(from string, opts smtp.MailOptions) error { + if s.backend.userErr != nil { + return s.backend.userErr + } if s.backend.panicOnMail { panic("Everything is on fire!") }