From e546646e362d4469c18bba93e3a3712ae627086e Mon Sep 17 00:00:00 2001 From: kayrus Date: Mon, 29 Aug 2022 17:20:08 +0200 Subject: [PATCH] Close local connections, when remote SMTP server closes --- client.go | 141 +++++++++++++++++++++++++++++++++++++++++++------ client_test.go | 92 ++++++++++++++++++++++++++------ 2 files changed, 202 insertions(+), 31 deletions(-) diff --git a/client.go b/client.go index 6632148..95653f2 100644 --- a/client.go +++ b/client.go @@ -5,6 +5,7 @@ package smtp import ( + "context" "crypto/tls" "encoding/base64" "errors" @@ -14,6 +15,7 @@ import ( "net/textproto" "strconv" "strings" + "sync" "time" "github.com/emersion/go-sasl" @@ -27,7 +29,7 @@ type Client struct { // keep a reference to the connection so it can be used to create a TLS // connection later - conn net.Conn + conn *monitoredConn // whether the Client is using TLS tls bool serverName string @@ -50,9 +52,13 @@ type Client struct { DebugWriter io.Writer } -// 30 seconds was chosen as it's the -// same duration as http.DefaultTransport's timeout. -var defaultTimeout = 30 * time.Second +const ( + // 30 seconds was chosen as it's the + // same duration as http.DefaultTransport's timeout. + defaultTimeout = 30 * time.Second + // Doubled maximum line length per RFC 5321 (Section 4.5.3.1.6) + maxLineLimit = 2000 +) // Dial returns a new Client connected to an SMTP server at addr. // The addr must include a port, as in "mail.example.com:smtp". @@ -113,17 +119,94 @@ func NewClientLMTP(conn net.Conn, host string) (*Client, error) { return c, nil } +type monitoredConn struct { + net.Conn + pr *io.PipeReader + pw *io.PipeWriter + ctx context.Context + cancel func() +} + +func (c monitoredConn) Read(b []byte) (int, error) { + return c.pr.Read(b) +} + +func (c monitoredConn) monitorConn() { + var n int + var err error + var wg sync.WaitGroup + + rb := make([]byte, 4096) + wb := make([]byte, 4096) + + defer func() { + wg.Wait() + c.pw.CloseWithError(err) + c.cancel() + }() + + for { + select { + case <-c.ctx.Done(): + err = c.ctx.Err() + if err == context.Canceled { + err = io.EOF + } + return + default: + n, err = c.Conn.Read(rb) + if err != nil { + if err == io.EOF { + c.Conn.Close() + } + if n == 0 { + return + } + } + wg.Wait() + wg.Add(1) + if n > len(wb) { + wb = make([]byte, n) + } + n = copy(wb, rb[:n]) + go func(b []byte) { + if _, err := c.pw.Write(b); err != nil { + panic(err) + } + wg.Done() + }(wb[:n]) + + if err != nil { + return + } + } + } +} + +func (c *Client) setMonConn(conn net.Conn) *monitoredConn { + pr, pw := io.Pipe() + + // initialize monitor stop func + ctx, cancel := context.WithCancel(context.Background()) + + monConn := &monitoredConn{conn, pr, pw, ctx, cancel} + + // monitor closed connection + go monConn.monitorConn() + + return monConn +} + // setConn sets the underlying network connection for the client. func (c *Client) setConn(conn net.Conn) { - c.conn = conn + c.conn = c.setMonConn(conn) - var r io.Reader = conn - var w io.Writer = conn + var r io.Reader = c.conn + var w io.Writer = c.conn r = &lineLimitReader{ - R: conn, - // Doubled maximum line length per RFC 5321 (Section 4.5.3.1.6) - LineLimit: 2000, + R: r, + LineLimit: maxLineLimit, } r = io.TeeReader(r, clientDebugWriter{c}) @@ -136,7 +219,7 @@ func (c *Client) setConn(conn net.Conn) { }{ Reader: r, Writer: w, - Closer: conn, + Closer: c.conn, } c.Text = textproto.NewConn(rwc) @@ -159,10 +242,10 @@ func (c *Client) InitConn(conn net.Conn) error { _, _, err := c.Text.ReadResponse(220) if err != nil { - c.Text.Close() if protoErr, ok := err.(*textproto.Error); ok { return toSMTPErr(protoErr) } + c.Close() return err } @@ -214,10 +297,17 @@ func (c *Client) cmd(expectCode int, format string, args ...interface{}) (int, s id, err := c.Text.Cmd(format, args...) if err != nil { + if err == net.ErrClosed { + return c.readResponse(expectCode) + } return 0, "", err } c.Text.StartResponse(id) defer c.Text.EndResponse(id) + return c.readResponse(expectCode) +} + +func (c *Client) readResponse(expectCode int) (int, string, error) { code, msg, err := c.Text.ReadResponse(expectCode) if err != nil { if protoErr, ok := err.(*textproto.Error); ok { @@ -279,8 +369,13 @@ func (c *Client) StartTLS(config *tls.Config) error { if err := c.hello(); err != nil { return err } + + // stop connection monitoring + c.conn.cancel() + _, _, err := c.cmd(220, "STARTTLS") if err != nil { + c.setConn(c.conn.Conn) return err } if config == nil { @@ -294,7 +389,19 @@ func (c *Client) StartTLS(config *tls.Config) error { if testHookStartTLS != nil { testHookStartTLS(config) } - c.setConn(tls.Client(c.conn, config)) + + conn := tls.Client(c.conn.Conn, config) + if c.CommandTimeout > 0 { + conn.SetDeadline(time.Now().Add(c.CommandTimeout)) + defer conn.SetDeadline(time.Time{}) + } + + if err := conn.Handshake(); err != nil { + c.Close() + return err + } + + c.setConn(conn) return c.ehlo() } @@ -302,7 +409,7 @@ func (c *Client) StartTLS(config *tls.Config) error { // The return values are their zero values if StartTLS did // not succeed. func (c *Client) TLSConnectionState() (state tls.ConnectionState, ok bool) { - tc, ok := c.conn.(*tls.Conn) + tc, ok := c.conn.Conn.(*tls.Conn) if !ok { return } @@ -689,7 +796,11 @@ func (c *Client) Quit() error { if err != nil { return err } - return c.Text.Close() + err = c.Close() + if err == net.ErrClosed { + return nil + } + return err } func parseEnhancedCode(s string) (EnhancedCode, error) { diff --git a/client_test.go b/client_test.go index 4574eec..9151d25 100644 --- a/client_test.go +++ b/client_test.go @@ -9,9 +9,9 @@ import ( "bytes" "crypto/tls" "crypto/x509" + "errors" "io" "net" - "net/textproto" "reflect" "strings" "testing" @@ -79,7 +79,10 @@ func TestBasic(t *testing.T) { bcmdbuf := bufio.NewWriter(&cmdbuf) var fake faker fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf) - c := &Client{Text: textproto.NewConn(fake), conn: fake, localName: "localhost"} + c := &Client{localName: "localhost"} + if err := c.InitConn(fake); err != nil { + t.Fatalf("Init failed: %s", err) + } if err := c.helo(); err != nil { t.Fatalf("HELO failed: %s", err) @@ -251,7 +254,6 @@ func TestBasic_SMTPError(t *testing.T) { func TestClient_TooLongLine(t *testing.T) { faultyServer := []string{ - "220 mx.google.com at your service\r\n", "220 mx.google.com at your service\r\n", "500 5.0.0 nU6XC5JJUfiuIkC7NhrxZz36Rl/rXpkfx9QdeZJ+rno6W5J9k9HvniyWXBBi1gOZ/CUXEI6K7Uony70eiVGGGkdFhP1rEvMGny1dqIRo3NM2NifrvvLIKGeX6HrYmkc7NMn9BwHyAnt5oLe5eNVDI+grwIikVPNVFZi0Dg4Xatdg5Cs8rH1x9BWhqyDoxosJst4wRoX4AymYygUcftM3y16nVg/qcb1GJwxSNbah7VjOiSrk6MlTdGR/2AwIIcSw7pZVJjGbCorniOTvKBcyut1YdbrX/4a/dBhvLfZtdSccqyMZAdZno+tGrnu+N2ghFvz6cx6bBab9Z4JJQMlkK/g1y7xjEPr6nKwruAf71NzOclPK5wzs2hY3Ku9xEjU0Cd+g/OjAzVsmeJk2U0q+vmACZsFAiOlRynXKFPLqMAg8skM5lioRTm05K/u3aBaUq0RKloeBHZ/zNp/kfHNp6TmJKAzvsXD3Xdo+PRAgCZRTRAl3ydGdrOOjxTULCVlgOL6xSAJdj9zGkzQoEW4tRmp1OiIab4GSxCtkIo7XnAowJ7EPUfDGTV3hhl5Qn7jvZjPCPlruRTtzVTho7D3HBEouWv1qDsqdED23myw0Ma9ZlobSf9eHqsSv1MxjKG2D5DdFBACu6pXGz3ceGreOHYWnI74TkoHtQ5oNuF6VUkGjGN+f4fOaiypQ54GJ8skTNoSCHLK4XF8ZutSxWzMR+LKoJBWMb6bdAiFNt+vXZOUiTgmTqs6Sw79JXqDX9YFxryJMKjHMiFkm+RZbaK5sIOXqyq+RNmOJ+G0unrQHQMCES476c7uvOlYrNoJtq+uox1qFdisIE/8vfSoKBlTtw+r2m87djIQh4ip/hVmalvtiF5fnVTxigbtwLWv8rAOCXKoktU0c2ie0a5hGtvZT0SXxwX8K2CeYXb81AFD2IaLt/p8Q4WuZ82eOCeXP72qP9yWYj6mIZdgyimm8wjrDowt2yPJU28ZD6k3Ei6C31OKgMpCf8+MW504/VCwld7czAIwjJiZe3DxtUdfM7Q565OzLiWQgI8fxjsvlCKMiOY7q42IGGsVxXJAFMtDKdchgqQA1PJR1vrw+SbI3Mh4AGnn8vKn+WTsieB3qkloo7MZlpMz/bwPXg7XadOVkUaVeHrZ5OsqDWhsWOLtPZLi5XdNazPzn9uxWbpelXEBKAjZzfoawSUgGT5vCYACNfz/yIw1DB067N+HN1KvVddI6TNBA32lpqkQ6VwdWztq6pREE51sNl9p7MUzr+ef0331N5DqQsy+epmRDwebosCx15l/rpvBc91OnxmMMXDNtmxSzVxaZjyGDmJ7RDdTy/Su76AlaMP1zxivxg2MU/9zyTzM16coIAMOd/6Uo9ezKgbZEPeMROKTzAld9BhK9BBPWofoQ0mBkVc7btnahQe3u8HoD6SKCkr9xcTcC9ZKpLkc4svrmxT9e0858pjhis9BbWD/owa6552n2+KwUMRyB8ys7rPL86hh9lBTS+05cVL+BmJfNHOA6ZizdGc3lpwIVbFmzMR5BM0HRf3OCntkWojgsdsP8BGZWHiCGGqA7YGa5AOleR887r8Zhyp47DT3Cn3Rg/icYurIx7Yh0p696gxfANo4jEkE2BOroIscDnhauwck5CCJMcabpTrGwzK8NJ+xZnCUplXnZiIaj85Uh9+yI670B4bybWlZoVmALUxxuQ8bSMAp7CAzMcMWbYJHwBqLF8V2qMj3/g81S3KOptn8b7Idh7IMzAkV8VxE3qAguzwS0zEu8l894sOFUPiJq2/llFeiHNOcEQUGJ+8ATJSAFOMDXAeQS2FoIDOYdesO6yacL0zUkvDydWbA84VXHW8DvdHPli/8hmc++dn5CXSDeBJfC/yypvrpLgkSilZMuHEYHEYHEYEHYEHEYEHEYEHEYEYEYEYEYEYEYEYEYEYEYEYEYEYEYEYEYEYEYYEYEYEYEYEYEYEYYEYEYEYEYEYEYEYEY\r\n", "220 2.0.0 Kk\r\n", @@ -295,7 +297,8 @@ func TestClient_TooLongLine(t *testing.T) { } } -var basicServer = `250 mx.google.com at your service +var basicServer = `220 mx.google.com at your service +250 mx.google.com at your service 502 Unrecognized command. 250-mx.google.com at your service 250-SIZE 35651584 @@ -420,7 +423,6 @@ QUIT ` func TestHello(t *testing.T) { - if len(helloServer) != len(helloClient) { t.Fatalf("Hello server and client size mismatch") } @@ -610,7 +612,7 @@ func TestTLSClient(t *testing.T) { t.Fatalf("failed to accept connection: %v", err) } defer conn.Close() - if err := serverHandle(conn, t); err != nil { + if err := serverHandle(conn, t, false); err != nil { t.Fatalf("failed to handle connection: %v", err) } if err := <-errc; err != nil { @@ -631,7 +633,7 @@ func TestTLSConnState(t *testing.T) { return } defer c.Close() - if err := serverHandle(c, t); err != nil { + if err := serverHandle(c, t, false); err != nil { t.Errorf("server error: %v", err) } }() @@ -662,6 +664,26 @@ func TestTLSConnState(t *testing.T) { <-serverDone } +func TestClosedConn(t *testing.T) { + ln := newLocalListener(t) + defer ln.Close() + errc := make(chan error) + go func() { + errc <- smtpClientNoop(ln.Addr().String()) + }() + conn, err := ln.Accept() + if err != nil { + t.Fatalf("failed to accept connection: %v", err) + } + defer conn.Close() + if err := serverHandle(conn, t, true); err != nil { + t.Fatalf("failed to handle connection: %v", err) + } + if err := <-errc; err != net.ErrClosed { + t.Fatalf("client error: %v", err) + } +} + func newLocalListener(t *testing.T) net.Listener { ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { @@ -682,7 +704,7 @@ func (s smtpSender) send(f string) { } // smtp server, finely tailored to deal with our own client only! -func serverHandle(c net.Conn, t *testing.T) error { +func serverHandle(c net.Conn, t *testing.T, forceClose bool) error { send := smtpSender{c}.send send("220 127.0.0.1 ESMTP service ready") s := bufio.NewScanner(c) @@ -699,9 +721,13 @@ func serverHandle(c net.Conn, t *testing.T) error { return err } config := &tls.Config{Certificates: []tls.Certificate{keypair}} - c = tls.Server(c, config) - defer c.Close() - return serverHandleTLS(c, t) + tlsConn := tls.Server(c, config) + err = tlsConn.Handshake() + if err != nil { + return err + } + defer tlsConn.Close() + return serverHandleTLS(tlsConn, t, forceClose) default: t.Fatalf("unrecognized command: %q", s.Text()) } @@ -709,13 +735,17 @@ func serverHandle(c net.Conn, t *testing.T) error { return s.Err() } -func serverHandleTLS(c net.Conn, t *testing.T) error { +func serverHandleTLS(c net.Conn, t *testing.T, forceClose bool) error { send := smtpSender{c}.send s := bufio.NewScanner(c) for s.Scan() { switch s.Text() { case "EHLO localhost": send("250 Ok") + if forceClose { + send("221 127.0.0.1 Service closing transmission channel") + return nil + } case "MAIL FROM:": send("250 Ok") case "RCPT TO:": @@ -745,6 +775,28 @@ func init() { } } +func smtpClientNoop(addr string) error { + c, err := Dial(addr) + if err != nil { + return err + } + + if err = c.hello(); err != nil { + return err + } + if ok, _ := c.Extension("STARTTLS"); !ok { + return errors.New("smtp: server doesn't support STARTTLS") + } + if err = c.StartTLS(nil); err != nil { + return err + } + + // sleep and wait for the server to close the conn + time.Sleep(time.Second * 1) + + return c.Close() +} + func sendMail(hostPort string) error { from := "joe1@example.com" to := []string{"joe2@example.com"} @@ -796,7 +848,10 @@ func TestLMTP(t *testing.T) { bcmdbuf := bufio.NewWriter(&cmdbuf) var fake faker fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf) - c := &Client{Text: textproto.NewConn(fake), conn: fake, lmtp: true} + c := &Client{lmtp: true} + if err := c.InitConn(fake); err != nil { + t.Fatalf("Init failed: %s", err) + } if err := c.Hello("localhost"); err != nil { t.Fatalf("LHLO failed: %s", err) @@ -838,7 +893,8 @@ Goodbye.` } } -var lmtpServer = `250-localhost at your service +var lmtpServer = `220 localhost at your service +250-localhost at your service 250-SIZE 35651584 250 8BITMIME 250 Sender OK @@ -864,7 +920,8 @@ QUIT ` func TestLMTPData(t *testing.T) { - var lmtpServerPartial = `250-localhost at your service + var lmtpServerPartial = `220 localhost at your service +250-localhost at your service 250-SIZE 35651584 250 8BITMIME 250 Sender OK @@ -881,7 +938,10 @@ func TestLMTPData(t *testing.T) { bcmdbuf := bufio.NewWriter(&cmdbuf) var fake faker fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf) - c := &Client{Text: textproto.NewConn(fake), conn: fake, lmtp: true} + c := &Client{lmtp: true} + if err := c.InitConn(fake); err != nil { + t.Fatalf("Init failed: %s", err) + } if err := c.Hello("localhost"); err != nil { t.Fatalf("LHLO failed: %s", err)