From fdb80ad9ddc2dc267805fe3201645dc4ca0b72c2 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Fri, 27 Sep 2024 11:10:23 +0200 Subject: [PATCH] Add mutex to Client for thread-safe operations This commit introduces a RWMutex to the Client struct in the smtp package to ensure thread-safe access to shared resources. Critical sections in methods like Close, StartTLS, and cmd are now protected with appropriate locking mechanisms. This change helps prevent potential race conditions, ensuring consistent and reliable behavior in concurrent environments. --- smtp/smtp.go | 41 +++++++++++++++++++++++++++++++++++++++++ smtp/smtp_ehlo.go | 3 +++ smtp/smtp_ehlo_117.go | 3 +++ 3 files changed, 47 insertions(+) diff --git a/smtp/smtp.go b/smtp/smtp.go index 787006c..1f1d603 100644 --- a/smtp/smtp.go +++ b/smtp/smtp.go @@ -30,6 +30,7 @@ import ( "net/textproto" "os" "strings" + "sync" "time" "github.com/wneessen/go-mail/log" @@ -70,6 +71,10 @@ type Client struct { // logger will be used for debug logging logger log.Logger + // mutex is used to synchronize access to shared resources, ensuring that only one goroutine can access + // the resource at a time. + mutex sync.RWMutex + // tls indicates whether the Client is using TLS tls bool @@ -112,6 +117,9 @@ func NewClient(conn net.Conn, host string) (*Client, error) { // Close closes the connection. func (c *Client) Close() error { + c.mutex.Lock() + defer c.mutex.Unlock() + return c.Text.Close() } @@ -139,12 +147,19 @@ func (c *Client) Hello(localName string) error { if c.didHello { return errors.New("smtp: Hello called after other methods") } + + c.mutex.Lock() c.localName = localName + c.mutex.Unlock() + return c.hello() } // cmd is a convenience function that sends a command and returns the response func (c *Client) cmd(expectCode int, format string, args ...interface{}) (int, string, error) { + c.mutex.Lock() + defer c.mutex.Unlock() + c.debugLog(log.DirClientToServer, format, args...) id, err := c.Text.Cmd(format, args...) if err != nil { @@ -160,7 +175,10 @@ func (c *Client) cmd(expectCode int, format string, args ...interface{}) (int, s // helo sends the HELO greeting to the server. It should be used only when the // server does not support ehlo. func (c *Client) helo() error { + c.mutex.Lock() c.ext = nil + c.mutex.Unlock() + _, _, err := c.cmd(250, "HELO %s", c.localName) return err } @@ -175,9 +193,13 @@ func (c *Client) StartTLS(config *tls.Config) error { if err != nil { return err } + + c.mutex.Lock() c.conn = tls.Client(c.conn, config) c.Text = textproto.NewConn(c.conn) c.tls = true + c.mutex.Unlock() + return c.ehlo() } @@ -185,6 +207,9 @@ func (c *Client) StartTLS(config *tls.Config) error { // The return values are their zero values if [Client.StartTLS] did // not succeed. func (c *Client) TLSConnectionState() (state tls.ConnectionState, ok bool) { + c.mutex.RLock() + defer c.mutex.RUnlock() + tc, ok := c.conn.(*tls.Conn) if !ok { return @@ -249,7 +274,9 @@ func (c *Client) Auth(a Auth) error { // abort the AUTH. Not required for XOAUTH2 _, _, _ = c.cmd(501, "*") } + c.mutex.Lock() _ = c.Quit() + c.mutex.Unlock() break } if resp == nil { @@ -275,6 +302,8 @@ func (c *Client) Mail(from string) error { return err } cmdStr := "MAIL FROM:<%s>" + + c.mutex.RLock() if c.ext != nil { if _, ok := c.ext["8BITMIME"]; ok { cmdStr += " BODY=8BITMIME" @@ -287,6 +316,8 @@ func (c *Client) Mail(from string) error { cmdStr += fmt.Sprintf(" RET=%s", c.dsnmrtype) } } + c.mutex.RUnlock() + _, _, err := c.cmd(250, cmdStr, from) return err } @@ -298,7 +329,11 @@ func (c *Client) Rcpt(to string) error { if err := validateLine(to); err != nil { return err } + + c.mutex.RLock() _, ok := c.ext["DSN"] + c.mutex.RUnlock() + if ok && c.dsnrntype != "" { _, _, err := c.cmd(25, "RCPT TO:<%s> NOTIFY=%s", to, c.dsnrntype) return err @@ -423,6 +458,9 @@ func (c *Client) Extension(ext string) (bool, string) { return false, "" } ext = strings.ToUpper(ext) + + c.mutex.RLock() + defer c.mutex.RUnlock() param, ok := c.ext[ext] return ok, param } @@ -497,6 +535,9 @@ func (c *Client) HasConnection() bool { } func (c *Client) UpdateDeadline(timeout time.Duration) error { + c.mutex.Lock() + defer c.mutex.Unlock() + if err := c.conn.SetDeadline(time.Now().Add(timeout)); err != nil { return fmt.Errorf("smtp: failed to update deadline: %w", err) } diff --git a/smtp/smtp_ehlo.go b/smtp/smtp_ehlo.go index ae80a62..457be57 100644 --- a/smtp/smtp_ehlo.go +++ b/smtp/smtp_ehlo.go @@ -25,6 +25,9 @@ func (c *Client) ehlo() error { if err != nil { return err } + + c.mutex.Lock() + defer c.mutex.Unlock() ext := make(map[string]string) extList := strings.Split(msg, "\n") if len(extList) > 1 { diff --git a/smtp/smtp_ehlo_117.go b/smtp/smtp_ehlo_117.go index c516a36..c40297f 100644 --- a/smtp/smtp_ehlo_117.go +++ b/smtp/smtp_ehlo_117.go @@ -28,6 +28,9 @@ func (c *Client) ehlo() error { if err != nil { return err } + + c.mutex.Lock() + defer c.mutex.Unlock() ext := make(map[string]string) extList := strings.Split(msg, "\n") if len(extList) > 1 {