From 17e0dbae352ac960eb92320af2cb90e8bcedd263 Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Wed, 21 Feb 2024 11:07:13 +0100 Subject: [PATCH] server: replace EnableAuth with AuthSession Closes: https://github.com/emersion/go-smtp/issues/170 --- backend.go | 15 +++++++++++++++ conn.go | 46 ++++++++++++++++++++++++++++++++++++++++------ server.go | 36 +++--------------------------------- 3 files changed, 58 insertions(+), 39 deletions(-) diff --git a/backend.go b/backend.go index 1241ff6..594b22d 100644 --- a/backend.go +++ b/backend.go @@ -2,6 +2,8 @@ package smtp import ( "io" + + "github.com/emersion/go-sasl" ) var ( @@ -20,6 +22,11 @@ var ( EnhancedCode: EnhancedCode{5, 7, 0}, Message: "Authentication not supported", } + ErrAuthUnknownMechanism = &SMTPError{ + Code: 504, + EnhancedCode: EnhancedCode{5, 7, 4}, + Message: "Unsupported authentication mechanism", + } ) // A SMTP server backend. @@ -76,3 +83,11 @@ type LMTPSession interface { type StatusCollector interface { SetStatus(rcptTo string, err error) } + +// AuthSession is a session supporting authentication. +type AuthSession interface { + Session + + AuthMechanisms() []string + Auth(mech string) (sasl.Server, error) +} diff --git a/conn.go b/conn.go index f4d87c0..5089dc3 100644 --- a/conn.go +++ b/conn.go @@ -15,6 +15,8 @@ import ( "strings" "sync" "time" + + "github.com/emersion/go-sasl" ) // Number of errors we'll tolerate per connection before closing. Defaults to 3. @@ -257,7 +259,7 @@ func (c *Conn) handleGreet(enhanced bool, arg string) { } if c.authAllowed() { authCap := "AUTH" - for name := range c.server.auths { + for _, name := range c.authMechanisms() { authCap += " " + name } @@ -786,14 +788,16 @@ func (c *Conn) handleAuth(arg string) { } } - newSasl, ok := c.server.auths[mechanism] - if !ok { - c.writeResponse(504, EnhancedCode{5, 7, 4}, "Unsupported authentication mechanism") + sasl, err := c.auth(mechanism) + if err != nil { + if smtpErr, ok := err.(*SMTPError); ok { + c.writeResponse(smtpErr.Code, smtpErr.EnhancedCode, smtpErr.Message) + } else { + c.writeResponse(454, EnhancedCode{4, 7, 0}, err.Error()) + } return } - sasl := newSasl(c) - response := ir for { challenge, done, err := sasl.Next(response) @@ -838,6 +842,36 @@ func (c *Conn) handleAuth(arg string) { c.didAuth = true } +func (c *Conn) authMechanisms() []string { + if authSession, ok := c.Session().(AuthSession); ok { + return authSession.AuthMechanisms() + } + return []string{sasl.Plain} +} + +func (c *Conn) auth(mech string) (sasl.Server, error) { + if authSession, ok := c.Session().(AuthSession); ok { + return authSession.Auth(mech) + } + + if mech != sasl.Plain { + return nil, ErrAuthUnknownMechanism + } + + return sasl.NewPlainServer(func(identity, username, password string) error { + if identity != "" && identity != username { + return errors.New("identities not supported") + } + + sess := c.Session() + if sess == nil { + panic("No session when AUTH is called") + } + + return sess.AuthPlain(username, password) + }), nil +} + func (c *Conn) handleStartTLS() { if _, isTLS := c.TLSConnectionState(); isTLS { c.writeResponse(502, EnhancedCode{5, 5, 1}, "Already running in TLS") diff --git a/server.go b/server.go index 23e4382..c2e5a52 100644 --- a/server.go +++ b/server.go @@ -10,17 +10,12 @@ import ( "os" "sync" "time" - - "github.com/emersion/go-sasl" ) var ( ErrServerClosed = errors.New("smtp: server already closed") ) -// A function that creates SASL servers. -type SaslServerFactory func(conn *Conn) sasl.Server - // Logger interface is used by Server to report unexpected internal errors. type Logger interface { Printf(format string, v ...interface{}) @@ -73,9 +68,8 @@ type Server struct { wg sync.WaitGroup - caps []string - auths map[string]SaslServerFactory - done chan struct{} + caps []string + done chan struct{} locker sync.Mutex listeners []net.Listener @@ -92,23 +86,7 @@ func NewServer(be Backend) *Server { done: make(chan struct{}, 1), ErrorLog: log.New(os.Stderr, "smtp/server ", log.LstdFlags), caps: []string{"PIPELINING", "8BITMIME", "ENHANCEDSTATUSCODES", "CHUNKING"}, - auths: map[string]SaslServerFactory{ - sasl.Plain: func(conn *Conn) sasl.Server { - return sasl.NewPlainServer(func(identity, username, password string) error { - if identity != "" && identity != username { - return errors.New("identities not supported") - } - - sess := conn.Session() - if sess == nil { - panic("No session when AUTH is called") - } - - return sess.AuthPlain(username, password) - }) - }, - }, - conns: make(map[*Conn]struct{}), + conns: make(map[*Conn]struct{}), } } @@ -329,11 +307,3 @@ func (s *Server) Shutdown(ctx context.Context) error { return err } } - -// EnableAuth enables an authentication mechanism on this server. -// -// This function should not be called directly, it must only be used by -// libraries implementing extensions of the SMTP protocol. -func (s *Server) EnableAuth(name string, f SaslServerFactory) { - s.auths[name] = f -}