Add mutex to Client for thread-safe operations

This commit introduces a RWMutex to the Client struct in the smtp package to ensure thread-safe access to shared resources. Critical sections in methods like Close, StartTLS, and cmd are now protected with appropriate locking mechanisms. This change helps prevent potential race conditions, ensuring consistent and reliable behavior in concurrent environments.
This commit is contained in:
Winni Neessen 2024-09-27 11:10:23 +02:00
parent fec2f2075a
commit fdb80ad9dd
Signed by: wneessen
GPG key ID: 385AC9889632126E
3 changed files with 47 additions and 0 deletions

View file

@ -30,6 +30,7 @@ import (
"net/textproto"
"os"
"strings"
"sync"
"time"
"github.com/wneessen/go-mail/log"
@ -70,6 +71,10 @@ type Client struct {
// logger will be used for debug logging
logger log.Logger
// mutex is used to synchronize access to shared resources, ensuring that only one goroutine can access
// the resource at a time.
mutex sync.RWMutex
// tls indicates whether the Client is using TLS
tls bool
@ -112,6 +117,9 @@ func NewClient(conn net.Conn, host string) (*Client, error) {
// Close closes the connection.
func (c *Client) Close() error {
c.mutex.Lock()
defer c.mutex.Unlock()
return c.Text.Close()
}
@ -139,12 +147,19 @@ func (c *Client) Hello(localName string) error {
if c.didHello {
return errors.New("smtp: Hello called after other methods")
}
c.mutex.Lock()
c.localName = localName
c.mutex.Unlock()
return c.hello()
}
// cmd is a convenience function that sends a command and returns the response
func (c *Client) cmd(expectCode int, format string, args ...interface{}) (int, string, error) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.debugLog(log.DirClientToServer, format, args...)
id, err := c.Text.Cmd(format, args...)
if err != nil {
@ -160,7 +175,10 @@ func (c *Client) cmd(expectCode int, format string, args ...interface{}) (int, s
// helo sends the HELO greeting to the server. It should be used only when the
// server does not support ehlo.
func (c *Client) helo() error {
c.mutex.Lock()
c.ext = nil
c.mutex.Unlock()
_, _, err := c.cmd(250, "HELO %s", c.localName)
return err
}
@ -175,9 +193,13 @@ func (c *Client) StartTLS(config *tls.Config) error {
if err != nil {
return err
}
c.mutex.Lock()
c.conn = tls.Client(c.conn, config)
c.Text = textproto.NewConn(c.conn)
c.tls = true
c.mutex.Unlock()
return c.ehlo()
}
@ -185,6 +207,9 @@ func (c *Client) StartTLS(config *tls.Config) error {
// The return values are their zero values if [Client.StartTLS] did
// not succeed.
func (c *Client) TLSConnectionState() (state tls.ConnectionState, ok bool) {
c.mutex.RLock()
defer c.mutex.RUnlock()
tc, ok := c.conn.(*tls.Conn)
if !ok {
return
@ -249,7 +274,9 @@ func (c *Client) Auth(a Auth) error {
// abort the AUTH. Not required for XOAUTH2
_, _, _ = c.cmd(501, "*")
}
c.mutex.Lock()
_ = c.Quit()
c.mutex.Unlock()
break
}
if resp == nil {
@ -275,6 +302,8 @@ func (c *Client) Mail(from string) error {
return err
}
cmdStr := "MAIL FROM:<%s>"
c.mutex.RLock()
if c.ext != nil {
if _, ok := c.ext["8BITMIME"]; ok {
cmdStr += " BODY=8BITMIME"
@ -287,6 +316,8 @@ func (c *Client) Mail(from string) error {
cmdStr += fmt.Sprintf(" RET=%s", c.dsnmrtype)
}
}
c.mutex.RUnlock()
_, _, err := c.cmd(250, cmdStr, from)
return err
}
@ -298,7 +329,11 @@ func (c *Client) Rcpt(to string) error {
if err := validateLine(to); err != nil {
return err
}
c.mutex.RLock()
_, ok := c.ext["DSN"]
c.mutex.RUnlock()
if ok && c.dsnrntype != "" {
_, _, err := c.cmd(25, "RCPT TO:<%s> NOTIFY=%s", to, c.dsnrntype)
return err
@ -423,6 +458,9 @@ func (c *Client) Extension(ext string) (bool, string) {
return false, ""
}
ext = strings.ToUpper(ext)
c.mutex.RLock()
defer c.mutex.RUnlock()
param, ok := c.ext[ext]
return ok, param
}
@ -497,6 +535,9 @@ func (c *Client) HasConnection() bool {
}
func (c *Client) UpdateDeadline(timeout time.Duration) error {
c.mutex.Lock()
defer c.mutex.Unlock()
if err := c.conn.SetDeadline(time.Now().Add(timeout)); err != nil {
return fmt.Errorf("smtp: failed to update deadline: %w", err)
}

View file

@ -25,6 +25,9 @@ func (c *Client) ehlo() error {
if err != nil {
return err
}
c.mutex.Lock()
defer c.mutex.Unlock()
ext := make(map[string]string)
extList := strings.Split(msg, "\n")
if len(extList) > 1 {

View file

@ -28,6 +28,9 @@ func (c *Client) ehlo() error {
if err != nil {
return err
}
c.mutex.Lock()
defer c.mutex.Unlock()
ext := make(map[string]string)
extList := strings.Split(msg, "\n")
if len(extList) > 1 {