Skip to content

Commit

Permalink
Close local connections, when remote SMTP server closes
Browse files Browse the repository at this point in the history
  • Loading branch information
kayrus committed Sep 12, 2022
1 parent 88735a2 commit e546646
Show file tree
Hide file tree
Showing 2 changed files with 202 additions and 31 deletions.
141 changes: 126 additions & 15 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package smtp

import (
"context"
"crypto/tls"
"encoding/base64"
"errors"
Expand All @@ -14,6 +15,7 @@ import (
"net/textproto"
"strconv"
"strings"
"sync"
"time"

"github.com/emersion/go-sasl"
Expand All @@ -27,7 +29,7 @@ type Client struct {

// keep a reference to the connection so it can be used to create a TLS
// connection later
conn net.Conn
conn *monitoredConn
// whether the Client is using TLS
tls bool
serverName string
Expand All @@ -50,9 +52,13 @@ type Client struct {
DebugWriter io.Writer
}

// 30 seconds was chosen as it's the
// same duration as http.DefaultTransport's timeout.
var defaultTimeout = 30 * time.Second
const (
// 30 seconds was chosen as it's the
// same duration as http.DefaultTransport's timeout.
defaultTimeout = 30 * time.Second
// Doubled maximum line length per RFC 5321 (Section 4.5.3.1.6)
maxLineLimit = 2000
)

// Dial returns a new Client connected to an SMTP server at addr.
// The addr must include a port, as in "mail.example.com:smtp".
Expand Down Expand Up @@ -113,17 +119,94 @@ func NewClientLMTP(conn net.Conn, host string) (*Client, error) {
return c, nil
}

type monitoredConn struct {
net.Conn
pr *io.PipeReader
pw *io.PipeWriter
ctx context.Context
cancel func()
}

func (c monitoredConn) Read(b []byte) (int, error) {
return c.pr.Read(b)
}

func (c monitoredConn) monitorConn() {
var n int
var err error
var wg sync.WaitGroup

rb := make([]byte, 4096)
wb := make([]byte, 4096)

defer func() {
wg.Wait()
c.pw.CloseWithError(err)
c.cancel()
}()

for {
select {
case <-c.ctx.Done():
err = c.ctx.Err()
if err == context.Canceled {
err = io.EOF
}
return
default:
n, err = c.Conn.Read(rb)
if err != nil {
if err == io.EOF {
c.Conn.Close()
}
if n == 0 {
return
}
}
wg.Wait()
wg.Add(1)
if n > len(wb) {
wb = make([]byte, n)
}
n = copy(wb, rb[:n])
go func(b []byte) {
if _, err := c.pw.Write(b); err != nil {
panic(err)
}
wg.Done()
}(wb[:n])

if err != nil {
return
}
}
}
}

func (c *Client) setMonConn(conn net.Conn) *monitoredConn {
pr, pw := io.Pipe()

// initialize monitor stop func
ctx, cancel := context.WithCancel(context.Background())

monConn := &monitoredConn{conn, pr, pw, ctx, cancel}

// monitor closed connection
go monConn.monitorConn()

return monConn
}

// setConn sets the underlying network connection for the client.
func (c *Client) setConn(conn net.Conn) {
c.conn = conn
c.conn = c.setMonConn(conn)

var r io.Reader = conn
var w io.Writer = conn
var r io.Reader = c.conn
var w io.Writer = c.conn

r = &lineLimitReader{
R: conn,
// Doubled maximum line length per RFC 5321 (Section 4.5.3.1.6)
LineLimit: 2000,
R: r,
LineLimit: maxLineLimit,
}

r = io.TeeReader(r, clientDebugWriter{c})
Expand All @@ -136,7 +219,7 @@ func (c *Client) setConn(conn net.Conn) {
}{
Reader: r,
Writer: w,
Closer: conn,
Closer: c.conn,
}
c.Text = textproto.NewConn(rwc)

Expand All @@ -159,10 +242,10 @@ func (c *Client) InitConn(conn net.Conn) error {

_, _, err := c.Text.ReadResponse(220)
if err != nil {
c.Text.Close()
if protoErr, ok := err.(*textproto.Error); ok {
return toSMTPErr(protoErr)
}
c.Close()
return err
}

Expand Down Expand Up @@ -214,10 +297,17 @@ func (c *Client) cmd(expectCode int, format string, args ...interface{}) (int, s

id, err := c.Text.Cmd(format, args...)
if err != nil {
if err == net.ErrClosed {
return c.readResponse(expectCode)
}
return 0, "", err
}
c.Text.StartResponse(id)
defer c.Text.EndResponse(id)
return c.readResponse(expectCode)
}

func (c *Client) readResponse(expectCode int) (int, string, error) {
code, msg, err := c.Text.ReadResponse(expectCode)
if err != nil {
if protoErr, ok := err.(*textproto.Error); ok {
Expand Down Expand Up @@ -279,8 +369,13 @@ func (c *Client) StartTLS(config *tls.Config) error {
if err := c.hello(); err != nil {
return err
}

// stop connection monitoring
c.conn.cancel()

_, _, err := c.cmd(220, "STARTTLS")
if err != nil {
c.setConn(c.conn.Conn)
return err
}
if config == nil {
Expand All @@ -294,15 +389,27 @@ func (c *Client) StartTLS(config *tls.Config) error {
if testHookStartTLS != nil {
testHookStartTLS(config)
}
c.setConn(tls.Client(c.conn, config))

conn := tls.Client(c.conn.Conn, config)
if c.CommandTimeout > 0 {
conn.SetDeadline(time.Now().Add(c.CommandTimeout))
defer conn.SetDeadline(time.Time{})
}

if err := conn.Handshake(); err != nil {
c.Close()
return err
}

c.setConn(conn)
return c.ehlo()
}

// TLSConnectionState returns the client's TLS connection state.
// The return values are their zero values if StartTLS did
// not succeed.
func (c *Client) TLSConnectionState() (state tls.ConnectionState, ok bool) {
tc, ok := c.conn.(*tls.Conn)
tc, ok := c.conn.Conn.(*tls.Conn)
if !ok {
return
}
Expand Down Expand Up @@ -689,7 +796,11 @@ func (c *Client) Quit() error {
if err != nil {
return err
}
return c.Text.Close()
err = c.Close()
if err == net.ErrClosed {
return nil
}
return err
}

func parseEnhancedCode(s string) (EnhancedCode, error) {
Expand Down
Loading

0 comments on commit e546646

Please sign in to comment.