From adcb8ac41dede831a53691e71167a0b867288a01 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Fri, 4 Oct 2024 23:15:01 +0200 Subject: [PATCH] Fix connection handling and improve thread-safety in SMTP client Reset connections to nil after Close, add RLock in HasConnection, and refine Close logic to handle already closed connections gracefully. Enhanced DialWithContext documentation and added tests for double-close scenarios to ensure robustness. --- client.go | 17 ++++++++++++++--- client_test.go | 43 ++++++++++++++++++++++++++++++++----------- smtp/smtp.go | 7 ++++++- 3 files changed, 52 insertions(+), 15 deletions(-) diff --git a/client.go b/client.go index 48d6731..787fafa 100644 --- a/client.go +++ b/client.go @@ -648,7 +648,17 @@ func (c *Client) SetSMTPAuthCustom(smtpAuth smtp.Auth) { c.smtpAuthType = SMTPAuthCustom } -// DialWithContext establishes a connection to the SMTP server with a given context.Context +// DialWithContext establishes a connection to the server using the provided context.Context. +// +// Before connecting to the server, the function will add a deadline of the Client's timeout +// to the provided context.Context. +// +// After dialing the DialContextFunc defined in the Client and successfully establishing the +// connection to the SMTP server, it will send the HELO/EHLO SMTP command followed by the +// optional STARTTLS and SMTP AUTH commands. It will also attach the log.Logger in case +// debug logging is enabled on the Client. +// +// From this point in time the Client has an active (cancelable) connection to the SMTP server. func (c *Client) DialWithContext(dialCtx context.Context) error { c.mutex.Lock() defer c.mutex.Unlock() @@ -707,8 +717,9 @@ func (c *Client) DialWithContext(dialCtx context.Context) error { // Close closes the Client connection func (c *Client) Close() error { - if err := c.checkConn(); err != nil { - return err + // If the connection is already closed, we considered this a no-op and disregard any error. + if !c.smtpClient.HasConnection() { + return nil } if err := c.smtpClient.Quit(); err != nil { return fmt.Errorf("failed to close SMTP client: %w", err) diff --git a/client_test.go b/client_test.go index d8dc87f..c767e75 100644 --- a/client_test.go +++ b/client_test.go @@ -617,6 +617,32 @@ func TestSetSMTPAuthCustom(t *testing.T) { } } +// TestClient_Close_double tests if a close on an already closed connection causes an error. +func TestClient_Close_double(t *testing.T) { + c, err := getTestConnection(true) + if err != nil { + t.Skipf("failed to create test client: %s. Skipping tests", err) + } + ctx := context.Background() + if err = c.DialWithContext(ctx); err != nil { + t.Errorf("failed to dial with context: %s", err) + return + } + if c.smtpClient == nil { + t.Errorf("DialWithContext didn't fail but no SMTP client found.") + return + } + if !c.smtpClient.HasConnection() { + t.Errorf("DialWithContext didn't fail but no connection found.") + } + if err = c.Close(); err != nil { + t.Errorf("failed to close connection: %s", err) + } + if err = c.Close(); err != nil { + t.Errorf("failed 2nd close connection: %s", err) + } +} + // TestClient_DialWithContext tests the DialWithContext method for the Client object func TestClient_DialWithContext(t *testing.T) { c, err := getTestConnection(true) @@ -2391,7 +2417,6 @@ func TestXOAuth2OK_faker(t *testing.T) { "250 8BITMIME", "250 OK", "235 2.7.0 Accepted", - "250 OK", "221 OK", } var wrote strings.Builder @@ -2412,10 +2437,10 @@ func TestXOAuth2OK_faker(t *testing.T) { if err != nil { t.Fatalf("unable to create new client: %v", err) } - if err := c.DialWithContext(context.Background()); err != nil { + if err = c.DialWithContext(context.Background()); err != nil { t.Fatalf("unexpected dial error: %v", err) } - if err := c.Close(); err != nil { + if err = c.Close(); err != nil { t.Fatalf("disconnect from test server failed: %v", err) } if !strings.Contains(wrote.String(), "AUTH XOAUTH2 dXNlcj11c2VyAWF1dGg9QmVhcmVyIHRva2VuAQE=\r\n") { @@ -2430,7 +2455,6 @@ func TestXOAuth2Unsupported_faker(t *testing.T) { "250-AUTH LOGIN PLAIN", "250 8BITMIME", "250 OK", - "250 OK", "221 OK", } var wrote strings.Builder @@ -2449,18 +2473,18 @@ func TestXOAuth2Unsupported_faker(t *testing.T) { if err != nil { t.Fatalf("unable to create new client: %v", err) } - if err := c.DialWithContext(context.Background()); err == nil { + if err = c.DialWithContext(context.Background()); err == nil { t.Fatal("expected dial error got nil") } else { if !errors.Is(err, ErrXOauth2AuthNotSupported) { t.Fatalf("expected %v; got %v", ErrXOauth2AuthNotSupported, err) } } - if err := c.Close(); err != nil { + if err = c.Close(); err != nil { t.Fatalf("disconnect from test server failed: %v", err) } client := strings.Split(wrote.String(), "\r\n") - if len(client) != 5 { + if len(client) != 4 { t.Fatalf("unexpected number of client requests got %d; want 5", len(client)) } if !strings.HasPrefix(client[0], "EHLO") { @@ -2469,10 +2493,7 @@ func TestXOAuth2Unsupported_faker(t *testing.T) { if client[1] != "NOOP" { t.Fatalf("expected NOOP, got %q", client[1]) } - if client[2] != "NOOP" { - t.Fatalf("expected NOOP, got %q", client[2]) - } - if client[3] != "QUIT" { + if client[2] != "QUIT" { t.Fatalf("expected QUIT, got %q", client[3]) } } diff --git a/smtp/smtp.go b/smtp/smtp.go index f9961c9..ce163a2 100644 --- a/smtp/smtp.go +++ b/smtp/smtp.go @@ -516,6 +516,8 @@ func (c *Client) Quit() error { } c.mutex.Lock() err = c.Text.Close() + c.Text = nil + c.conn = nil c.mutex.Unlock() return err @@ -555,7 +557,10 @@ func (c *Client) SetDSNRcptNotifyOption(d string) { // HasConnection checks if the client has an active connection. // Returns true if the `conn` field is not nil, indicating an active connection. func (c *Client) HasConnection() bool { - return c.conn != nil + c.mutex.RLock() + conn := c.conn + c.mutex.RUnlock() + return conn != nil } func (c *Client) UpdateDeadline(timeout time.Duration) error {