From 6bd9a9c73584d3a9faee41f7f870718e9e28f94d Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Fri, 27 Sep 2024 14:03:26 +0200 Subject: [PATCH] Refactor mutex usage for connection safety This commit revises locking mechanism usage around connection operations to avoid potential deadlocks and improve code clarity. Specifically, defer statements were removed and explicit unlocks were added to ensure that mutexes are properly released after critical sections. This change affects several methods, including `Close`, `cmd`, `TLSConnectionState`, `UpdateDeadline`, and newly introduced locking for concurrent data writes and reads in `dataCloser`. --- smtp/smtp.go | 49 +++++++++++++++++++++++++++++++++++-------------- 1 file changed, 35 insertions(+), 14 deletions(-) diff --git a/smtp/smtp.go b/smtp/smtp.go index 1f1d603..379f5fe 100644 --- a/smtp/smtp.go +++ b/smtp/smtp.go @@ -118,9 +118,9 @@ func NewClient(conn net.Conn, host string) (*Client, error) { // Close closes the connection. func (c *Client) Close() error { c.mutex.Lock() - defer c.mutex.Unlock() - - return c.Text.Close() + err := c.Text.Close() + c.mutex.Unlock() + return err } // hello runs a hello exchange if needed. @@ -158,17 +158,18 @@ func (c *Client) Hello(localName string) error { // cmd is a convenience function that sends a command and returns the response func (c *Client) cmd(expectCode int, format string, args ...interface{}) (int, string, error) { c.mutex.Lock() - defer c.mutex.Unlock() c.debugLog(log.DirClientToServer, format, args...) id, err := c.Text.Cmd(format, args...) if err != nil { + c.mutex.Unlock() return 0, "", err } c.Text.StartResponse(id) - defer c.Text.EndResponse(id) code, msg, err := c.Text.ReadResponse(expectCode) c.debugLog(log.DirServerToClient, "%d %s", code, msg) + c.Text.EndResponse(id) + c.mutex.Unlock() return code, msg, err } @@ -208,13 +209,14 @@ func (c *Client) StartTLS(config *tls.Config) error { // not succeed. func (c *Client) TLSConnectionState() (state tls.ConnectionState, ok bool) { c.mutex.RLock() - defer c.mutex.RUnlock() tc, ok := c.conn.(*tls.Conn) if !ok { return } - return tc.ConnectionState(), true + state, ok = tc.ConnectionState(), true + c.mutex.RUnlock() + return } // Verify checks the validity of an email address on the server. @@ -274,9 +276,7 @@ func (c *Client) Auth(a Auth) error { // abort the AUTH. Not required for XOAUTH2 _, _, _ = c.cmd(501, "*") } - c.mutex.Lock() _ = c.Quit() - c.mutex.Unlock() break } if resp == nil { @@ -347,12 +347,23 @@ type dataCloser struct { io.WriteCloser } +// Close releases the lock, closes the WriteCloser, waits for a response, and then returns any error encountered. func (d *dataCloser) Close() error { + d.c.mutex.Lock() _ = d.WriteCloser.Close() _, _, err := d.c.Text.ReadResponse(250) + d.c.mutex.Unlock() return err } +// Write writes data to the underlying WriteCloser while ensuring thread-safety by locking and unlocking a mutex. +func (d *dataCloser) Write(p []byte) (n int, err error) { + d.c.mutex.Lock() + n, err = d.WriteCloser.Write(p) + d.c.mutex.Unlock() + return +} + // Data issues a DATA command to the server and returns a writer that // can be used to write the mail headers and body. The caller should // close the writer before calling any more methods on c. A call to @@ -362,7 +373,14 @@ func (c *Client) Data() (io.WriteCloser, error) { if err != nil { return nil, err } - return &dataCloser{c, c.Text.DotWriter()}, nil + datacloser := &dataCloser{} + + c.mutex.Lock() + datacloser.c = c + datacloser.WriteCloser = c.Text.DotWriter() + c.mutex.Unlock() + + return datacloser, nil } var testHookStartTLS func(*tls.Config) // nil, except for tests @@ -460,8 +478,8 @@ func (c *Client) Extension(ext string) (bool, string) { ext = strings.ToUpper(ext) c.mutex.RLock() - defer c.mutex.RUnlock() param, ok := c.ext[ext] + c.mutex.RUnlock() return ok, param } @@ -494,7 +512,11 @@ func (c *Client) Quit() error { if err != nil { return err } - return c.Text.Close() + c.mutex.Lock() + err = c.Text.Close() + c.mutex.Unlock() + + return err } // SetDebugLog enables the debug logging for incoming and outgoing SMTP messages @@ -536,11 +558,10 @@ func (c *Client) HasConnection() bool { func (c *Client) UpdateDeadline(timeout time.Duration) error { c.mutex.Lock() - defer c.mutex.Unlock() - if err := c.conn.SetDeadline(time.Now().Add(timeout)); err != nil { return fmt.Errorf("smtp: failed to update deadline: %w", err) } + c.mutex.Unlock() return nil }