Skip to content

Commit

Permalink
server: replace EnableAuth with AuthSession
Browse files Browse the repository at this point in the history
Closes: #170
  • Loading branch information
emersion committed Feb 21, 2024
1 parent 33fe6a6 commit 17e0dba
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 39 deletions.
15 changes: 15 additions & 0 deletions backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package smtp

import (
"io"

"github.com/emersion/go-sasl"
)

var (
Expand All @@ -20,6 +22,11 @@ var (
EnhancedCode: EnhancedCode{5, 7, 0},
Message: "Authentication not supported",
}
ErrAuthUnknownMechanism = &SMTPError{
Code: 504,
EnhancedCode: EnhancedCode{5, 7, 4},
Message: "Unsupported authentication mechanism",
}
)

// A SMTP server backend.
Expand Down Expand Up @@ -76,3 +83,11 @@ type LMTPSession interface {
type StatusCollector interface {
SetStatus(rcptTo string, err error)
}

// AuthSession is a session supporting authentication.
type AuthSession interface {
Session

AuthMechanisms() []string
Auth(mech string) (sasl.Server, error)
}
46 changes: 40 additions & 6 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import (
"strings"
"sync"
"time"

"github.com/emersion/go-sasl"
)

// Number of errors we'll tolerate per connection before closing. Defaults to 3.
Expand Down Expand Up @@ -257,7 +259,7 @@ func (c *Conn) handleGreet(enhanced bool, arg string) {
}
if c.authAllowed() {
authCap := "AUTH"
for name := range c.server.auths {
for _, name := range c.authMechanisms() {
authCap += " " + name
}

Expand Down Expand Up @@ -786,14 +788,16 @@ func (c *Conn) handleAuth(arg string) {
}
}

newSasl, ok := c.server.auths[mechanism]
if !ok {
c.writeResponse(504, EnhancedCode{5, 7, 4}, "Unsupported authentication mechanism")
sasl, err := c.auth(mechanism)
if err != nil {
if smtpErr, ok := err.(*SMTPError); ok {
c.writeResponse(smtpErr.Code, smtpErr.EnhancedCode, smtpErr.Message)
} else {
c.writeResponse(454, EnhancedCode{4, 7, 0}, err.Error())
}
return
}

sasl := newSasl(c)

response := ir
for {
challenge, done, err := sasl.Next(response)
Expand Down Expand Up @@ -838,6 +842,36 @@ func (c *Conn) handleAuth(arg string) {
c.didAuth = true
}

func (c *Conn) authMechanisms() []string {
if authSession, ok := c.Session().(AuthSession); ok {
return authSession.AuthMechanisms()
}
return []string{sasl.Plain}
}

func (c *Conn) auth(mech string) (sasl.Server, error) {
if authSession, ok := c.Session().(AuthSession); ok {
return authSession.Auth(mech)
}

if mech != sasl.Plain {
return nil, ErrAuthUnknownMechanism
}

return sasl.NewPlainServer(func(identity, username, password string) error {
if identity != "" && identity != username {
return errors.New("identities not supported")
}

sess := c.Session()
if sess == nil {
panic("No session when AUTH is called")
}

return sess.AuthPlain(username, password)
}), nil
}

func (c *Conn) handleStartTLS() {
if _, isTLS := c.TLSConnectionState(); isTLS {
c.writeResponse(502, EnhancedCode{5, 5, 1}, "Already running in TLS")
Expand Down
36 changes: 3 additions & 33 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,12 @@ import (
"os"
"sync"
"time"

"github.com/emersion/go-sasl"
)

var (
ErrServerClosed = errors.New("smtp: server already closed")
)

// A function that creates SASL servers.
type SaslServerFactory func(conn *Conn) sasl.Server

// Logger interface is used by Server to report unexpected internal errors.
type Logger interface {
Printf(format string, v ...interface{})
Expand Down Expand Up @@ -73,9 +68,8 @@ type Server struct {

wg sync.WaitGroup

caps []string
auths map[string]SaslServerFactory
done chan struct{}
caps []string
done chan struct{}

locker sync.Mutex
listeners []net.Listener
Expand All @@ -92,23 +86,7 @@ func NewServer(be Backend) *Server {
done: make(chan struct{}, 1),
ErrorLog: log.New(os.Stderr, "smtp/server ", log.LstdFlags),
caps: []string{"PIPELINING", "8BITMIME", "ENHANCEDSTATUSCODES", "CHUNKING"},
auths: map[string]SaslServerFactory{
sasl.Plain: func(conn *Conn) sasl.Server {
return sasl.NewPlainServer(func(identity, username, password string) error {
if identity != "" && identity != username {
return errors.New("identities not supported")
}

sess := conn.Session()
if sess == nil {
panic("No session when AUTH is called")
}

return sess.AuthPlain(username, password)
})
},
},
conns: make(map[*Conn]struct{}),
conns: make(map[*Conn]struct{}),
}
}

Expand Down Expand Up @@ -329,11 +307,3 @@ func (s *Server) Shutdown(ctx context.Context) error {
return err
}
}

// EnableAuth enables an authentication mechanism on this server.
//
// This function should not be called directly, it must only be used by
// libraries implementing extensions of the SMTP protocol.
func (s *Server) EnableAuth(name string, f SaslServerFactory) {
s.auths[name] = f
}

0 comments on commit 17e0dba

Please sign in to comment.