Add mutex to Client for thread-safe operations

This commit introduces a RWMutex to the Client struct in the smtp package to ensure thread-safe access to shared resources. Critical sections in methods like Close, StartTLS, and cmd are now protected with appropriate locking mechanisms. This change helps prevent potential race conditions, ensuring consistent and reliable behavior in concurrent environments.
This commit is contained in:
Winni Neessen 2024-09-27 11:10:23 +02:00
parent fec2f2075a
commit fdb80ad9dd
Signed by: wneessen
GPG key ID: 385AC9889632126E
3 changed files with 47 additions and 0 deletions

View file

@ -30,6 +30,7 @@ import (
"net/textproto" "net/textproto"
"os" "os"
"strings" "strings"
"sync"
"time" "time"
"github.com/wneessen/go-mail/log" "github.com/wneessen/go-mail/log"
@ -70,6 +71,10 @@ type Client struct {
// logger will be used for debug logging // logger will be used for debug logging
logger log.Logger logger log.Logger
// mutex is used to synchronize access to shared resources, ensuring that only one goroutine can access
// the resource at a time.
mutex sync.RWMutex
// tls indicates whether the Client is using TLS // tls indicates whether the Client is using TLS
tls bool tls bool
@ -112,6 +117,9 @@ func NewClient(conn net.Conn, host string) (*Client, error) {
// Close closes the connection. // Close closes the connection.
func (c *Client) Close() error { func (c *Client) Close() error {
c.mutex.Lock()
defer c.mutex.Unlock()
return c.Text.Close() return c.Text.Close()
} }
@ -139,12 +147,19 @@ func (c *Client) Hello(localName string) error {
if c.didHello { if c.didHello {
return errors.New("smtp: Hello called after other methods") return errors.New("smtp: Hello called after other methods")
} }
c.mutex.Lock()
c.localName = localName c.localName = localName
c.mutex.Unlock()
return c.hello() return c.hello()
} }
// cmd is a convenience function that sends a command and returns the response // cmd is a convenience function that sends a command and returns the response
func (c *Client) cmd(expectCode int, format string, args ...interface{}) (int, string, error) { func (c *Client) cmd(expectCode int, format string, args ...interface{}) (int, string, error) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.debugLog(log.DirClientToServer, format, args...) c.debugLog(log.DirClientToServer, format, args...)
id, err := c.Text.Cmd(format, args...) id, err := c.Text.Cmd(format, args...)
if err != nil { if err != nil {
@ -160,7 +175,10 @@ func (c *Client) cmd(expectCode int, format string, args ...interface{}) (int, s
// helo sends the HELO greeting to the server. It should be used only when the // helo sends the HELO greeting to the server. It should be used only when the
// server does not support ehlo. // server does not support ehlo.
func (c *Client) helo() error { func (c *Client) helo() error {
c.mutex.Lock()
c.ext = nil c.ext = nil
c.mutex.Unlock()
_, _, err := c.cmd(250, "HELO %s", c.localName) _, _, err := c.cmd(250, "HELO %s", c.localName)
return err return err
} }
@ -175,9 +193,13 @@ func (c *Client) StartTLS(config *tls.Config) error {
if err != nil { if err != nil {
return err return err
} }
c.mutex.Lock()
c.conn = tls.Client(c.conn, config) c.conn = tls.Client(c.conn, config)
c.Text = textproto.NewConn(c.conn) c.Text = textproto.NewConn(c.conn)
c.tls = true c.tls = true
c.mutex.Unlock()
return c.ehlo() return c.ehlo()
} }
@ -185,6 +207,9 @@ func (c *Client) StartTLS(config *tls.Config) error {
// The return values are their zero values if [Client.StartTLS] did // The return values are their zero values if [Client.StartTLS] did
// not succeed. // not succeed.
func (c *Client) TLSConnectionState() (state tls.ConnectionState, ok bool) { func (c *Client) TLSConnectionState() (state tls.ConnectionState, ok bool) {
c.mutex.RLock()
defer c.mutex.RUnlock()
tc, ok := c.conn.(*tls.Conn) tc, ok := c.conn.(*tls.Conn)
if !ok { if !ok {
return return
@ -249,7 +274,9 @@ func (c *Client) Auth(a Auth) error {
// abort the AUTH. Not required for XOAUTH2 // abort the AUTH. Not required for XOAUTH2
_, _, _ = c.cmd(501, "*") _, _, _ = c.cmd(501, "*")
} }
c.mutex.Lock()
_ = c.Quit() _ = c.Quit()
c.mutex.Unlock()
break break
} }
if resp == nil { if resp == nil {
@ -275,6 +302,8 @@ func (c *Client) Mail(from string) error {
return err return err
} }
cmdStr := "MAIL FROM:<%s>" cmdStr := "MAIL FROM:<%s>"
c.mutex.RLock()
if c.ext != nil { if c.ext != nil {
if _, ok := c.ext["8BITMIME"]; ok { if _, ok := c.ext["8BITMIME"]; ok {
cmdStr += " BODY=8BITMIME" cmdStr += " BODY=8BITMIME"
@ -287,6 +316,8 @@ func (c *Client) Mail(from string) error {
cmdStr += fmt.Sprintf(" RET=%s", c.dsnmrtype) cmdStr += fmt.Sprintf(" RET=%s", c.dsnmrtype)
} }
} }
c.mutex.RUnlock()
_, _, err := c.cmd(250, cmdStr, from) _, _, err := c.cmd(250, cmdStr, from)
return err return err
} }
@ -298,7 +329,11 @@ func (c *Client) Rcpt(to string) error {
if err := validateLine(to); err != nil { if err := validateLine(to); err != nil {
return err return err
} }
c.mutex.RLock()
_, ok := c.ext["DSN"] _, ok := c.ext["DSN"]
c.mutex.RUnlock()
if ok && c.dsnrntype != "" { if ok && c.dsnrntype != "" {
_, _, err := c.cmd(25, "RCPT TO:<%s> NOTIFY=%s", to, c.dsnrntype) _, _, err := c.cmd(25, "RCPT TO:<%s> NOTIFY=%s", to, c.dsnrntype)
return err return err
@ -423,6 +458,9 @@ func (c *Client) Extension(ext string) (bool, string) {
return false, "" return false, ""
} }
ext = strings.ToUpper(ext) ext = strings.ToUpper(ext)
c.mutex.RLock()
defer c.mutex.RUnlock()
param, ok := c.ext[ext] param, ok := c.ext[ext]
return ok, param return ok, param
} }
@ -497,6 +535,9 @@ func (c *Client) HasConnection() bool {
} }
func (c *Client) UpdateDeadline(timeout time.Duration) error { func (c *Client) UpdateDeadline(timeout time.Duration) error {
c.mutex.Lock()
defer c.mutex.Unlock()
if err := c.conn.SetDeadline(time.Now().Add(timeout)); err != nil { if err := c.conn.SetDeadline(time.Now().Add(timeout)); err != nil {
return fmt.Errorf("smtp: failed to update deadline: %w", err) return fmt.Errorf("smtp: failed to update deadline: %w", err)
} }

View file

@ -25,6 +25,9 @@ func (c *Client) ehlo() error {
if err != nil { if err != nil {
return err return err
} }
c.mutex.Lock()
defer c.mutex.Unlock()
ext := make(map[string]string) ext := make(map[string]string)
extList := strings.Split(msg, "\n") extList := strings.Split(msg, "\n")
if len(extList) > 1 { if len(extList) > 1 {

View file

@ -28,6 +28,9 @@ func (c *Client) ehlo() error {
if err != nil { if err != nil {
return err return err
} }
c.mutex.Lock()
defer c.mutex.Unlock()
ext := make(map[string]string) ext := make(map[string]string)
extList := strings.Split(msg, "\n") extList := strings.Split(msg, "\n")
if len(extList) > 1 { if len(extList) > 1 {