From 3871b2be44124484903057d4973ddf05488be071 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Thu, 26 Sep 2024 11:51:30 +0200 Subject: [PATCH] Lock client connections and update deadline handling Add mutex locking for client connections to ensure thread safety. Introduce `HasConnection` method to check active connections and `UpdateDeadline` method to handle timeout updates. Refactor connection handling in `checkConn` and `tls` methods accordingly. --- client.go | 18 +++++++++++------- client_120.go | 2 ++ smtp/smtp.go | 14 ++++++++++++++ 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/client.go b/client.go index fac9a34..77f9751 100644 --- a/client.go +++ b/client.go @@ -12,6 +12,7 @@ import ( "net" "os" "strings" + "sync" "time" "github.com/wneessen/go-mail/log" @@ -87,6 +88,7 @@ type DialContextFunc func(ctx context.Context, network, address string) (net.Con // Client is the SMTP client struct type Client struct { + mutex sync.RWMutex // connection is the net.Conn that the smtp.Client is based on connection net.Conn @@ -589,6 +591,9 @@ func (c *Client) setDefaultHelo() error { // DialWithContext establishes a connection to the SMTP server with a given context.Context func (c *Client) DialWithContext(dialCtx context.Context) error { + c.mutex.Lock() + defer c.mutex.Unlock() + ctx, cancel := context.WithDeadline(dialCtx, time.Now().Add(c.connTimeout)) defer cancel() @@ -602,17 +607,16 @@ func (c *Client) DialWithContext(dialCtx context.Context) error { c.dialContextFunc = tlsDialer.DialContext } } - var err error - c.connection, err = c.dialContextFunc(ctx, "tcp", c.ServerAddr()) + connection, err := c.dialContextFunc(ctx, "tcp", c.ServerAddr()) if err != nil && c.fallbackPort != 0 { // TODO: should we somehow log or append the previous error? - c.connection, err = c.dialContextFunc(ctx, "tcp", c.serverFallbackAddr()) + connection, err = c.dialContextFunc(ctx, "tcp", c.serverFallbackAddr()) } if err != nil { return err } - client, err := smtp.NewClient(c.connection, c.host) + client, err := smtp.NewClient(connection, c.host) if err != nil { return err } @@ -691,7 +695,7 @@ func (c *Client) DialAndSendWithContext(ctx context.Context, messages ...*Msg) e // checkConn makes sure that a required server connection is available and extends the // connection deadline func (c *Client) checkConn() error { - if c.connection == nil { + if !c.smtpClient.HasConnection() { return ErrNoActiveConnection } @@ -701,7 +705,7 @@ func (c *Client) checkConn() error { } } - if err := c.connection.SetDeadline(time.Now().Add(c.connTimeout)); err != nil { + if err := c.smtpClient.UpdateDeadline(c.connTimeout); err != nil { return ErrDeadlineExtendFailed } return nil @@ -715,7 +719,7 @@ func (c *Client) serverFallbackAddr() string { // tls tries to make sure that the STARTTLS requirements are satisfied func (c *Client) tls() error { - if c.connection == nil { + if !c.smtpClient.HasConnection() { return ErrNoActiveConnection } if !c.useSSL && c.tlspolicy != NoTLS { diff --git a/client_120.go b/client_120.go index 4f82aa7..729069b 100644 --- a/client_120.go +++ b/client_120.go @@ -13,6 +13,8 @@ import ( // Send sends out the mail message func (c *Client) Send(messages ...*Msg) (returnErr error) { + c.mutex.Lock() + defer c.mutex.Unlock() if err := c.checkConn(); err != nil { returnErr = &SendError{Reason: ErrConnCheck, errlist: []error{err}, isTemp: isTempError(err)} return diff --git a/smtp/smtp.go b/smtp/smtp.go index d2a0e64..5f5484a 100644 --- a/smtp/smtp.go +++ b/smtp/smtp.go @@ -30,6 +30,7 @@ import ( "net/textproto" "os" "strings" + "time" "github.com/wneessen/go-mail/log" ) @@ -472,6 +473,19 @@ func (c *Client) SetDSNRcptNotifyOption(d string) { c.dsnrntype = d } +// HasConnection checks if the client has an active connection. +// Returns true if the `conn` field is not nil, indicating an active connection. +func (c *Client) HasConnection() bool { + return c.conn != nil +} + +func (c *Client) UpdateDeadline(timeout time.Duration) error { + if err := c.conn.SetDeadline(time.Now().Add(timeout)); err != nil { + return fmt.Errorf("smtp: failed to update deadline: %w", err) + } + return nil +} + // debugLog checks if the debug flag is set and if so logs the provided message to // the log.Logger interface func (c *Client) debugLog(d log.Direction, f string, a ...interface{}) {