From 4e0e40027ef7151851185281114896c56582463c Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Mon, 2 May 2022 18:46:19 +0200 Subject: [PATCH] client: delay greeting This makes it so NewClient never blocks, much like tls.Client. This allows callers to have better control over timeouts. --- client.go | 58 +++++++++++++++++++++++-------------------------- client_test.go | 59 +++++++++++++++----------------------------------- 2 files changed, 44 insertions(+), 73 deletions(-) diff --git a/client.go b/client.go index 30fa4e6..4149240 100644 --- a/client.go +++ b/client.go @@ -62,10 +62,7 @@ func Dial(addr string) (*Client, error) { if err != nil { return nil, err } - client, err := NewClient(conn) - if err != nil { - return nil, err - } + client := NewClient(conn) client.serverName, _, _ = net.SplitHostPort(addr) return client, nil } @@ -83,17 +80,14 @@ func DialTLS(addr string, tlsConfig *tls.Config) (*Client, error) { if err != nil { return nil, err } - client, err := NewClient(conn) - if err != nil { - return nil, err - } + client := NewClient(conn) client.serverName, _, _ = net.SplitHostPort(addr) return client, nil } // NewClient returns a new Client using an existing connection and host as a // server name to be used when authenticating. -func NewClient(conn net.Conn) (*Client, error) { +func NewClient(conn net.Conn) *Client { c := &Client{ localName: "localhost", // As recommended by RFC 5321. For DATA command reply (3xx one) RFC @@ -107,31 +101,15 @@ func NewClient(conn net.Conn) (*Client, error) { c.setConn(conn) - // Initial greeting timeout. RFC 5321 recommends 5 minutes. - c.conn.SetDeadline(time.Now().Add(5 * time.Minute)) - defer c.conn.SetDeadline(time.Time{}) - - _, _, err := c.text.ReadResponse(220) - if err != nil { - c.text.Close() - if protoErr, ok := err.(*textproto.Error); ok { - return nil, toSMTPErr(protoErr) - } - return nil, err - } - - return c, nil + return c } // NewClientLMTP returns a new LMTP Client (as defined in RFC 2033) using an // existing connection and host as a server name to be used when authenticating. -func NewClientLMTP(conn net.Conn) (*Client, error) { - c, err := NewClient(conn) - if err != nil { - return nil, err - } +func NewClientLMTP(conn net.Conn) *Client { + c := NewClient(conn) c.lmtp = true - return c, nil + return c } // setConn sets the underlying network connection for the client. @@ -170,12 +148,30 @@ func (c *Client) Close() error { return c.text.Close() } +func (c *Client) greet() error { + // Initial greeting timeout. RFC 5321 recommends 5 minutes. + c.conn.SetDeadline(time.Now().Add(c.CommandTimeout)) + defer c.conn.SetDeadline(time.Time{}) + + _, _, err := c.text.ReadResponse(220) + if err != nil { + c.text.Close() + if protoErr, ok := err.(*textproto.Error); ok { + return toSMTPErr(protoErr) + } + return err + } + + return nil +} + // hello runs a hello exchange if needed. func (c *Client) hello() error { if !c.didHello { c.didHello = true - err := c.ehlo() - if err != nil { + if err := c.greet(); err != nil { + c.helloError = err + } else if err := c.ehlo(); err != nil { c.helloError = c.helo() } } diff --git a/client_test.go b/client_test.go index 6756091..b6693c5 100644 --- a/client_test.go +++ b/client_test.go @@ -33,10 +33,7 @@ func TestClientAuthTrimSpace(t *testing.T) { strings.NewReader(server), &wrote, } - c, err := NewClient(fake) - if err != nil { - t.Fatalf("NewClient: %v", err) - } + c := NewClient(fake) c.tls = true c.didHello = true c.Auth(toServerEmptyAuth{}) @@ -185,12 +182,9 @@ func TestBasic_SMTPError(t *testing.T) { strings.NewReader(faultyServer), &wrote, } - c, err := NewClient(fake) - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } + c := NewClient(fake) - err = c.Mail("whatever", nil) + err := c.Mail("whatever", nil) if err == nil { t.Fatal("MAIL succeeded") } @@ -267,12 +261,9 @@ func TestClient_TooLongLine(t *testing.T) { pr, &wrote, } - c, err := NewClient(fake) - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } + c := NewClient(fake) - err = c.Mail("whatever", nil) + err := c.Mail("whatever", nil) if err != ErrTooLongLine { t.Fatal("MAIL succeeded or returned a different error:", err) } @@ -335,10 +326,7 @@ func TestNewClient(t *testing.T) { } var fake faker fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf) - c, err := NewClient(fake) - if err != nil { - t.Fatalf("NewClient: %v\n(after %v)", err, out()) - } + c := NewClient(fake) defer c.Close() if ok, args := c.Extension("aUtH"); !ok || args != "LOGIN PLAIN" { t.Fatalf("Expected AUTH supported") @@ -376,10 +364,7 @@ func TestNewClient2(t *testing.T) { bcmdbuf := bufio.NewWriter(&cmdbuf) var fake faker fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf) - c, err := NewClient(fake) - if err != nil { - t.Fatalf("NewClient: %v", err) - } + c := NewClient(fake) defer c.Close() if ok, _ := c.Extension("DSN"); ok { t.Fatalf("Shouldn't support DSN") @@ -422,15 +407,12 @@ func TestHello(t *testing.T) { bcmdbuf := bufio.NewWriter(&cmdbuf) var fake faker fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf) - c, err := NewClient(fake) - if err != nil { - t.Fatalf("NewClient: %v", err) - } + c := NewClient(fake) defer c.Close() c.serverName = "fake.host" c.localName = "customhost" - err = nil + var err error switch i { case 0: err = c.Hello("hostinjection>\n\rDATA\r\nInjected message body\r\n.\r\nQUIT\r\n") @@ -553,15 +535,12 @@ func TestAuthFailed(t *testing.T) { bcmdbuf := bufio.NewWriter(&cmdbuf) var fake faker fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf) - c, err := NewClient(fake) - if err != nil { - t.Fatalf("NewClient: %v", err) - } + c := NewClient(fake) defer c.Close() c.tls = true c.serverName = "smtp.google.com" - err = c.Auth(sasl.NewPlainClient("", "user", "pass")) + err := c.Auth(sasl.NewPlainClient("", "user", "pass")) if err == nil { t.Error("Auth: expected error; got none") @@ -830,7 +809,8 @@ Goodbye.` } } -var lmtpServer = `250-localhost at your service +var lmtpServer = `220 localhost Simple Mail Transfer Service Ready +250-localhost at your service 250-SIZE 35651584 250 8BITMIME 250 Sender OK @@ -856,7 +836,8 @@ QUIT ` func TestLMTPData(t *testing.T) { - var lmtpServerPartial = `250-localhost at your service + var lmtpServerPartial = `220 localhost Simple Mail Transfer Service Ready +250-localhost at your service 250-SIZE 35651584 250 8BITMIME 250 Sender OK @@ -951,10 +932,7 @@ func TestClientXtext(t *testing.T) { strings.NewReader(server), &wrote, } - c, err := NewClient(fake) - if err != nil { - t.Fatalf("NewClient: %v", err) - } + c := NewClient(fake) c.didHello = true c.ext = map[string]string{"AUTH": "PLAIN", "DSN": ""} email := "e=mc2@example.com" @@ -1001,10 +979,7 @@ func TestClientDSN(t *testing.T) { strings.NewReader(server), &wrote, } - c, err := NewClient(fake) - if err != nil { - t.Fatalf("NewClient: %v", err) - } + c := NewClient(fake) c.didHello = true c.ext = map[string]string{"DSN": ""} c.Mail(dsnEmailRFC822, &MailOptions{