diff --git a/smtp/smtp.go b/smtp/smtp.go index 787006c..1f1d603 100644 --- a/smtp/smtp.go +++ b/smtp/smtp.go @@ -30,6 +30,7 @@ import ( "net/textproto" "os" "strings" + "sync" "time" "github.com/wneessen/go-mail/log" @@ -70,6 +71,10 @@ type Client struct { // logger will be used for debug logging logger log.Logger + // mutex is used to synchronize access to shared resources, ensuring that only one goroutine can access + // the resource at a time. + mutex sync.RWMutex + // tls indicates whether the Client is using TLS tls bool @@ -112,6 +117,9 @@ func NewClient(conn net.Conn, host string) (*Client, error) { // Close closes the connection. func (c *Client) Close() error { + c.mutex.Lock() + defer c.mutex.Unlock() + return c.Text.Close() } @@ -139,12 +147,19 @@ func (c *Client) Hello(localName string) error { if c.didHello { return errors.New("smtp: Hello called after other methods") } + + c.mutex.Lock() c.localName = localName + c.mutex.Unlock() + return c.hello() } // cmd is a convenience function that sends a command and returns the response func (c *Client) cmd(expectCode int, format string, args ...interface{}) (int, string, error) { + c.mutex.Lock() + defer c.mutex.Unlock() + c.debugLog(log.DirClientToServer, format, args...) id, err := c.Text.Cmd(format, args...) if err != nil { @@ -160,7 +175,10 @@ func (c *Client) cmd(expectCode int, format string, args ...interface{}) (int, s // helo sends the HELO greeting to the server. It should be used only when the // server does not support ehlo. func (c *Client) helo() error { + c.mutex.Lock() c.ext = nil + c.mutex.Unlock() + _, _, err := c.cmd(250, "HELO %s", c.localName) return err } @@ -175,9 +193,13 @@ func (c *Client) StartTLS(config *tls.Config) error { if err != nil { return err } + + c.mutex.Lock() c.conn = tls.Client(c.conn, config) c.Text = textproto.NewConn(c.conn) c.tls = true + c.mutex.Unlock() + return c.ehlo() } @@ -185,6 +207,9 @@ func (c *Client) StartTLS(config *tls.Config) error { // The return values are their zero values if [Client.StartTLS] did // not succeed. func (c *Client) TLSConnectionState() (state tls.ConnectionState, ok bool) { + c.mutex.RLock() + defer c.mutex.RUnlock() + tc, ok := c.conn.(*tls.Conn) if !ok { return @@ -249,7 +274,9 @@ func (c *Client) Auth(a Auth) error { // abort the AUTH. Not required for XOAUTH2 _, _, _ = c.cmd(501, "*") } + c.mutex.Lock() _ = c.Quit() + c.mutex.Unlock() break } if resp == nil { @@ -275,6 +302,8 @@ func (c *Client) Mail(from string) error { return err } cmdStr := "MAIL FROM:<%s>" + + c.mutex.RLock() if c.ext != nil { if _, ok := c.ext["8BITMIME"]; ok { cmdStr += " BODY=8BITMIME" @@ -287,6 +316,8 @@ func (c *Client) Mail(from string) error { cmdStr += fmt.Sprintf(" RET=%s", c.dsnmrtype) } } + c.mutex.RUnlock() + _, _, err := c.cmd(250, cmdStr, from) return err } @@ -298,7 +329,11 @@ func (c *Client) Rcpt(to string) error { if err := validateLine(to); err != nil { return err } + + c.mutex.RLock() _, ok := c.ext["DSN"] + c.mutex.RUnlock() + if ok && c.dsnrntype != "" { _, _, err := c.cmd(25, "RCPT TO:<%s> NOTIFY=%s", to, c.dsnrntype) return err @@ -423,6 +458,9 @@ func (c *Client) Extension(ext string) (bool, string) { return false, "" } ext = strings.ToUpper(ext) + + c.mutex.RLock() + defer c.mutex.RUnlock() param, ok := c.ext[ext] return ok, param } @@ -497,6 +535,9 @@ func (c *Client) HasConnection() bool { } func (c *Client) UpdateDeadline(timeout time.Duration) error { + c.mutex.Lock() + defer c.mutex.Unlock() + if err := c.conn.SetDeadline(time.Now().Add(timeout)); err != nil { return fmt.Errorf("smtp: failed to update deadline: %w", err) } diff --git a/smtp/smtp_ehlo.go b/smtp/smtp_ehlo.go index ae80a62..457be57 100644 --- a/smtp/smtp_ehlo.go +++ b/smtp/smtp_ehlo.go @@ -25,6 +25,9 @@ func (c *Client) ehlo() error { if err != nil { return err } + + c.mutex.Lock() + defer c.mutex.Unlock() ext := make(map[string]string) extList := strings.Split(msg, "\n") if len(extList) > 1 { diff --git a/smtp/smtp_ehlo_117.go b/smtp/smtp_ehlo_117.go index c516a36..c40297f 100644 --- a/smtp/smtp_ehlo_117.go +++ b/smtp/smtp_ehlo_117.go @@ -28,6 +28,9 @@ func (c *Client) ehlo() error { if err != nil { return err } + + c.mutex.Lock() + defer c.mutex.Unlock() ext := make(map[string]string) extList := strings.Split(msg, "\n") if len(extList) > 1 {