Lock client connections and update deadline handling

Add mutex locking for client connections to ensure thread safety. Introduce `HasConnection` method to check active connections and `UpdateDeadline` method to handle timeout updates. Refactor connection handling in `checkConn` and `tls` methods accordingly.
This commit is contained in:
Winni Neessen 2024-09-26 11:51:30 +02:00
parent 8683917c3d
commit 3871b2be44
Signed by: wneessen
GPG key ID: 385AC9889632126E
3 changed files with 27 additions and 7 deletions

View file

@ -12,6 +12,7 @@ import (
"net"
"os"
"strings"
"sync"
"time"
"github.com/wneessen/go-mail/log"
@ -87,6 +88,7 @@ type DialContextFunc func(ctx context.Context, network, address string) (net.Con
// Client is the SMTP client struct
type Client struct {
mutex sync.RWMutex
// connection is the net.Conn that the smtp.Client is based on
connection net.Conn
@ -589,6 +591,9 @@ func (c *Client) setDefaultHelo() error {
// DialWithContext establishes a connection to the SMTP server with a given context.Context
func (c *Client) DialWithContext(dialCtx context.Context) error {
c.mutex.Lock()
defer c.mutex.Unlock()
ctx, cancel := context.WithDeadline(dialCtx, time.Now().Add(c.connTimeout))
defer cancel()
@ -602,17 +607,16 @@ func (c *Client) DialWithContext(dialCtx context.Context) error {
c.dialContextFunc = tlsDialer.DialContext
}
}
var err error
c.connection, err = c.dialContextFunc(ctx, "tcp", c.ServerAddr())
connection, err := c.dialContextFunc(ctx, "tcp", c.ServerAddr())
if err != nil && c.fallbackPort != 0 {
// TODO: should we somehow log or append the previous error?
c.connection, err = c.dialContextFunc(ctx, "tcp", c.serverFallbackAddr())
connection, err = c.dialContextFunc(ctx, "tcp", c.serverFallbackAddr())
}
if err != nil {
return err
}
client, err := smtp.NewClient(c.connection, c.host)
client, err := smtp.NewClient(connection, c.host)
if err != nil {
return err
}
@ -691,7 +695,7 @@ func (c *Client) DialAndSendWithContext(ctx context.Context, messages ...*Msg) e
// checkConn makes sure that a required server connection is available and extends the
// connection deadline
func (c *Client) checkConn() error {
if c.connection == nil {
if !c.smtpClient.HasConnection() {
return ErrNoActiveConnection
}
@ -701,7 +705,7 @@ func (c *Client) checkConn() error {
}
}
if err := c.connection.SetDeadline(time.Now().Add(c.connTimeout)); err != nil {
if err := c.smtpClient.UpdateDeadline(c.connTimeout); err != nil {
return ErrDeadlineExtendFailed
}
return nil
@ -715,7 +719,7 @@ func (c *Client) serverFallbackAddr() string {
// tls tries to make sure that the STARTTLS requirements are satisfied
func (c *Client) tls() error {
if c.connection == nil {
if !c.smtpClient.HasConnection() {
return ErrNoActiveConnection
}
if !c.useSSL && c.tlspolicy != NoTLS {

View file

@ -13,6 +13,8 @@ import (
// Send sends out the mail message
func (c *Client) Send(messages ...*Msg) (returnErr error) {
c.mutex.Lock()
defer c.mutex.Unlock()
if err := c.checkConn(); err != nil {
returnErr = &SendError{Reason: ErrConnCheck, errlist: []error{err}, isTemp: isTempError(err)}
return

View file

@ -30,6 +30,7 @@ import (
"net/textproto"
"os"
"strings"
"time"
"github.com/wneessen/go-mail/log"
)
@ -472,6 +473,19 @@ func (c *Client) SetDSNRcptNotifyOption(d string) {
c.dsnrntype = d
}
// HasConnection checks if the client has an active connection.
// Returns true if the `conn` field is not nil, indicating an active connection.
func (c *Client) HasConnection() bool {
return c.conn != nil
}
func (c *Client) UpdateDeadline(timeout time.Duration) error {
if err := c.conn.SetDeadline(time.Now().Add(timeout)); err != nil {
return fmt.Errorf("smtp: failed to update deadline: %w", err)
}
return nil
}
// debugLog checks if the debug flag is set and if so logs the provided message to
// the log.Logger interface
func (c *Client) debugLog(d log.Direction, f string, a ...interface{}) {