mirror of
https://github.com/wneessen/go-mail.git
synced 2024-11-22 13:50:49 +01:00
Add mutex to Client for thread-safe operations
This commit introduces a RWMutex to the Client struct in the smtp package to ensure thread-safe access to shared resources. Critical sections in methods like Close, StartTLS, and cmd are now protected with appropriate locking mechanisms. This change helps prevent potential race conditions, ensuring consistent and reliable behavior in concurrent environments.
This commit is contained in:
parent
fec2f2075a
commit
fdb80ad9dd
3 changed files with 47 additions and 0 deletions
41
smtp/smtp.go
41
smtp/smtp.go
|
@ -30,6 +30,7 @@ import (
|
||||||
"net/textproto"
|
"net/textproto"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/wneessen/go-mail/log"
|
"github.com/wneessen/go-mail/log"
|
||||||
|
@ -70,6 +71,10 @@ type Client struct {
|
||||||
// logger will be used for debug logging
|
// logger will be used for debug logging
|
||||||
logger log.Logger
|
logger log.Logger
|
||||||
|
|
||||||
|
// mutex is used to synchronize access to shared resources, ensuring that only one goroutine can access
|
||||||
|
// the resource at a time.
|
||||||
|
mutex sync.RWMutex
|
||||||
|
|
||||||
// tls indicates whether the Client is using TLS
|
// tls indicates whether the Client is using TLS
|
||||||
tls bool
|
tls bool
|
||||||
|
|
||||||
|
@ -112,6 +117,9 @@ func NewClient(conn net.Conn, host string) (*Client, error) {
|
||||||
|
|
||||||
// Close closes the connection.
|
// Close closes the connection.
|
||||||
func (c *Client) Close() error {
|
func (c *Client) Close() error {
|
||||||
|
c.mutex.Lock()
|
||||||
|
defer c.mutex.Unlock()
|
||||||
|
|
||||||
return c.Text.Close()
|
return c.Text.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -139,12 +147,19 @@ func (c *Client) Hello(localName string) error {
|
||||||
if c.didHello {
|
if c.didHello {
|
||||||
return errors.New("smtp: Hello called after other methods")
|
return errors.New("smtp: Hello called after other methods")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.mutex.Lock()
|
||||||
c.localName = localName
|
c.localName = localName
|
||||||
|
c.mutex.Unlock()
|
||||||
|
|
||||||
return c.hello()
|
return c.hello()
|
||||||
}
|
}
|
||||||
|
|
||||||
// cmd is a convenience function that sends a command and returns the response
|
// cmd is a convenience function that sends a command and returns the response
|
||||||
func (c *Client) cmd(expectCode int, format string, args ...interface{}) (int, string, error) {
|
func (c *Client) cmd(expectCode int, format string, args ...interface{}) (int, string, error) {
|
||||||
|
c.mutex.Lock()
|
||||||
|
defer c.mutex.Unlock()
|
||||||
|
|
||||||
c.debugLog(log.DirClientToServer, format, args...)
|
c.debugLog(log.DirClientToServer, format, args...)
|
||||||
id, err := c.Text.Cmd(format, args...)
|
id, err := c.Text.Cmd(format, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -160,7 +175,10 @@ func (c *Client) cmd(expectCode int, format string, args ...interface{}) (int, s
|
||||||
// helo sends the HELO greeting to the server. It should be used only when the
|
// helo sends the HELO greeting to the server. It should be used only when the
|
||||||
// server does not support ehlo.
|
// server does not support ehlo.
|
||||||
func (c *Client) helo() error {
|
func (c *Client) helo() error {
|
||||||
|
c.mutex.Lock()
|
||||||
c.ext = nil
|
c.ext = nil
|
||||||
|
c.mutex.Unlock()
|
||||||
|
|
||||||
_, _, err := c.cmd(250, "HELO %s", c.localName)
|
_, _, err := c.cmd(250, "HELO %s", c.localName)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -175,9 +193,13 @@ func (c *Client) StartTLS(config *tls.Config) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.mutex.Lock()
|
||||||
c.conn = tls.Client(c.conn, config)
|
c.conn = tls.Client(c.conn, config)
|
||||||
c.Text = textproto.NewConn(c.conn)
|
c.Text = textproto.NewConn(c.conn)
|
||||||
c.tls = true
|
c.tls = true
|
||||||
|
c.mutex.Unlock()
|
||||||
|
|
||||||
return c.ehlo()
|
return c.ehlo()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -185,6 +207,9 @@ func (c *Client) StartTLS(config *tls.Config) error {
|
||||||
// The return values are their zero values if [Client.StartTLS] did
|
// The return values are their zero values if [Client.StartTLS] did
|
||||||
// not succeed.
|
// not succeed.
|
||||||
func (c *Client) TLSConnectionState() (state tls.ConnectionState, ok bool) {
|
func (c *Client) TLSConnectionState() (state tls.ConnectionState, ok bool) {
|
||||||
|
c.mutex.RLock()
|
||||||
|
defer c.mutex.RUnlock()
|
||||||
|
|
||||||
tc, ok := c.conn.(*tls.Conn)
|
tc, ok := c.conn.(*tls.Conn)
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
|
@ -249,7 +274,9 @@ func (c *Client) Auth(a Auth) error {
|
||||||
// abort the AUTH. Not required for XOAUTH2
|
// abort the AUTH. Not required for XOAUTH2
|
||||||
_, _, _ = c.cmd(501, "*")
|
_, _, _ = c.cmd(501, "*")
|
||||||
}
|
}
|
||||||
|
c.mutex.Lock()
|
||||||
_ = c.Quit()
|
_ = c.Quit()
|
||||||
|
c.mutex.Unlock()
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if resp == nil {
|
if resp == nil {
|
||||||
|
@ -275,6 +302,8 @@ func (c *Client) Mail(from string) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
cmdStr := "MAIL FROM:<%s>"
|
cmdStr := "MAIL FROM:<%s>"
|
||||||
|
|
||||||
|
c.mutex.RLock()
|
||||||
if c.ext != nil {
|
if c.ext != nil {
|
||||||
if _, ok := c.ext["8BITMIME"]; ok {
|
if _, ok := c.ext["8BITMIME"]; ok {
|
||||||
cmdStr += " BODY=8BITMIME"
|
cmdStr += " BODY=8BITMIME"
|
||||||
|
@ -287,6 +316,8 @@ func (c *Client) Mail(from string) error {
|
||||||
cmdStr += fmt.Sprintf(" RET=%s", c.dsnmrtype)
|
cmdStr += fmt.Sprintf(" RET=%s", c.dsnmrtype)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
c.mutex.RUnlock()
|
||||||
|
|
||||||
_, _, err := c.cmd(250, cmdStr, from)
|
_, _, err := c.cmd(250, cmdStr, from)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -298,7 +329,11 @@ func (c *Client) Rcpt(to string) error {
|
||||||
if err := validateLine(to); err != nil {
|
if err := validateLine(to); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.mutex.RLock()
|
||||||
_, ok := c.ext["DSN"]
|
_, ok := c.ext["DSN"]
|
||||||
|
c.mutex.RUnlock()
|
||||||
|
|
||||||
if ok && c.dsnrntype != "" {
|
if ok && c.dsnrntype != "" {
|
||||||
_, _, err := c.cmd(25, "RCPT TO:<%s> NOTIFY=%s", to, c.dsnrntype)
|
_, _, err := c.cmd(25, "RCPT TO:<%s> NOTIFY=%s", to, c.dsnrntype)
|
||||||
return err
|
return err
|
||||||
|
@ -423,6 +458,9 @@ func (c *Client) Extension(ext string) (bool, string) {
|
||||||
return false, ""
|
return false, ""
|
||||||
}
|
}
|
||||||
ext = strings.ToUpper(ext)
|
ext = strings.ToUpper(ext)
|
||||||
|
|
||||||
|
c.mutex.RLock()
|
||||||
|
defer c.mutex.RUnlock()
|
||||||
param, ok := c.ext[ext]
|
param, ok := c.ext[ext]
|
||||||
return ok, param
|
return ok, param
|
||||||
}
|
}
|
||||||
|
@ -497,6 +535,9 @@ func (c *Client) HasConnection() bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) UpdateDeadline(timeout time.Duration) error {
|
func (c *Client) UpdateDeadline(timeout time.Duration) error {
|
||||||
|
c.mutex.Lock()
|
||||||
|
defer c.mutex.Unlock()
|
||||||
|
|
||||||
if err := c.conn.SetDeadline(time.Now().Add(timeout)); err != nil {
|
if err := c.conn.SetDeadline(time.Now().Add(timeout)); err != nil {
|
||||||
return fmt.Errorf("smtp: failed to update deadline: %w", err)
|
return fmt.Errorf("smtp: failed to update deadline: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,6 +25,9 @@ func (c *Client) ehlo() error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.mutex.Lock()
|
||||||
|
defer c.mutex.Unlock()
|
||||||
ext := make(map[string]string)
|
ext := make(map[string]string)
|
||||||
extList := strings.Split(msg, "\n")
|
extList := strings.Split(msg, "\n")
|
||||||
if len(extList) > 1 {
|
if len(extList) > 1 {
|
||||||
|
|
|
@ -28,6 +28,9 @@ func (c *Client) ehlo() error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.mutex.Lock()
|
||||||
|
defer c.mutex.Unlock()
|
||||||
ext := make(map[string]string)
|
ext := make(map[string]string)
|
||||||
extList := strings.Split(msg, "\n")
|
extList := strings.Split(msg, "\n")
|
||||||
if len(extList) > 1 {
|
if len(extList) > 1 {
|
||||||
|
|
Loading…
Reference in a new issue