mirror of
https://github.com/wneessen/go-mail.git
synced 2024-11-14 18:02:55 +01:00
Add mutex to Client for thread-safe operations
This commit introduces a RWMutex to the Client struct in the smtp package to ensure thread-safe access to shared resources. Critical sections in methods like Close, StartTLS, and cmd are now protected with appropriate locking mechanisms. This change helps prevent potential race conditions, ensuring consistent and reliable behavior in concurrent environments.
This commit is contained in:
parent
fec2f2075a
commit
fdb80ad9dd
3 changed files with 47 additions and 0 deletions
41
smtp/smtp.go
41
smtp/smtp.go
|
@ -30,6 +30,7 @@ import (
|
|||
"net/textproto"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/wneessen/go-mail/log"
|
||||
|
@ -70,6 +71,10 @@ type Client struct {
|
|||
// logger will be used for debug logging
|
||||
logger log.Logger
|
||||
|
||||
// mutex is used to synchronize access to shared resources, ensuring that only one goroutine can access
|
||||
// the resource at a time.
|
||||
mutex sync.RWMutex
|
||||
|
||||
// tls indicates whether the Client is using TLS
|
||||
tls bool
|
||||
|
||||
|
@ -112,6 +117,9 @@ func NewClient(conn net.Conn, host string) (*Client, error) {
|
|||
|
||||
// Close closes the connection.
|
||||
func (c *Client) Close() error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
return c.Text.Close()
|
||||
}
|
||||
|
||||
|
@ -139,12 +147,19 @@ func (c *Client) Hello(localName string) error {
|
|||
if c.didHello {
|
||||
return errors.New("smtp: Hello called after other methods")
|
||||
}
|
||||
|
||||
c.mutex.Lock()
|
||||
c.localName = localName
|
||||
c.mutex.Unlock()
|
||||
|
||||
return c.hello()
|
||||
}
|
||||
|
||||
// cmd is a convenience function that sends a command and returns the response
|
||||
func (c *Client) cmd(expectCode int, format string, args ...interface{}) (int, string, error) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
c.debugLog(log.DirClientToServer, format, args...)
|
||||
id, err := c.Text.Cmd(format, args...)
|
||||
if err != nil {
|
||||
|
@ -160,7 +175,10 @@ func (c *Client) cmd(expectCode int, format string, args ...interface{}) (int, s
|
|||
// helo sends the HELO greeting to the server. It should be used only when the
|
||||
// server does not support ehlo.
|
||||
func (c *Client) helo() error {
|
||||
c.mutex.Lock()
|
||||
c.ext = nil
|
||||
c.mutex.Unlock()
|
||||
|
||||
_, _, err := c.cmd(250, "HELO %s", c.localName)
|
||||
return err
|
||||
}
|
||||
|
@ -175,9 +193,13 @@ func (c *Client) StartTLS(config *tls.Config) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.mutex.Lock()
|
||||
c.conn = tls.Client(c.conn, config)
|
||||
c.Text = textproto.NewConn(c.conn)
|
||||
c.tls = true
|
||||
c.mutex.Unlock()
|
||||
|
||||
return c.ehlo()
|
||||
}
|
||||
|
||||
|
@ -185,6 +207,9 @@ func (c *Client) StartTLS(config *tls.Config) error {
|
|||
// The return values are their zero values if [Client.StartTLS] did
|
||||
// not succeed.
|
||||
func (c *Client) TLSConnectionState() (state tls.ConnectionState, ok bool) {
|
||||
c.mutex.RLock()
|
||||
defer c.mutex.RUnlock()
|
||||
|
||||
tc, ok := c.conn.(*tls.Conn)
|
||||
if !ok {
|
||||
return
|
||||
|
@ -249,7 +274,9 @@ func (c *Client) Auth(a Auth) error {
|
|||
// abort the AUTH. Not required for XOAUTH2
|
||||
_, _, _ = c.cmd(501, "*")
|
||||
}
|
||||
c.mutex.Lock()
|
||||
_ = c.Quit()
|
||||
c.mutex.Unlock()
|
||||
break
|
||||
}
|
||||
if resp == nil {
|
||||
|
@ -275,6 +302,8 @@ func (c *Client) Mail(from string) error {
|
|||
return err
|
||||
}
|
||||
cmdStr := "MAIL FROM:<%s>"
|
||||
|
||||
c.mutex.RLock()
|
||||
if c.ext != nil {
|
||||
if _, ok := c.ext["8BITMIME"]; ok {
|
||||
cmdStr += " BODY=8BITMIME"
|
||||
|
@ -287,6 +316,8 @@ func (c *Client) Mail(from string) error {
|
|||
cmdStr += fmt.Sprintf(" RET=%s", c.dsnmrtype)
|
||||
}
|
||||
}
|
||||
c.mutex.RUnlock()
|
||||
|
||||
_, _, err := c.cmd(250, cmdStr, from)
|
||||
return err
|
||||
}
|
||||
|
@ -298,7 +329,11 @@ func (c *Client) Rcpt(to string) error {
|
|||
if err := validateLine(to); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.mutex.RLock()
|
||||
_, ok := c.ext["DSN"]
|
||||
c.mutex.RUnlock()
|
||||
|
||||
if ok && c.dsnrntype != "" {
|
||||
_, _, err := c.cmd(25, "RCPT TO:<%s> NOTIFY=%s", to, c.dsnrntype)
|
||||
return err
|
||||
|
@ -423,6 +458,9 @@ func (c *Client) Extension(ext string) (bool, string) {
|
|||
return false, ""
|
||||
}
|
||||
ext = strings.ToUpper(ext)
|
||||
|
||||
c.mutex.RLock()
|
||||
defer c.mutex.RUnlock()
|
||||
param, ok := c.ext[ext]
|
||||
return ok, param
|
||||
}
|
||||
|
@ -497,6 +535,9 @@ func (c *Client) HasConnection() bool {
|
|||
}
|
||||
|
||||
func (c *Client) UpdateDeadline(timeout time.Duration) error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
if err := c.conn.SetDeadline(time.Now().Add(timeout)); err != nil {
|
||||
return fmt.Errorf("smtp: failed to update deadline: %w", err)
|
||||
}
|
||||
|
|
|
@ -25,6 +25,9 @@ func (c *Client) ehlo() error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
ext := make(map[string]string)
|
||||
extList := strings.Split(msg, "\n")
|
||||
if len(extList) > 1 {
|
||||
|
|
|
@ -28,6 +28,9 @@ func (c *Client) ehlo() error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
ext := make(map[string]string)
|
||||
extList := strings.Split(msg, "\n")
|
||||
if len(extList) > 1 {
|
||||
|
|
Loading…
Reference in a new issue