Fix connection handling and improve thread-safety in SMTP client

Reset connections to nil after Close, add RLock in HasConnection, and refine Close logic to handle already closed connections gracefully. Enhanced DialWithContext documentation and added tests for double-close scenarios to ensure robustness.
This commit is contained in:
Winni Neessen 2024-10-04 23:15:01 +02:00
parent dfdadc5da2
commit adcb8ac41d
Signed by: wneessen
GPG key ID: 385AC9889632126E
3 changed files with 52 additions and 15 deletions

View file

@ -648,7 +648,17 @@ func (c *Client) SetSMTPAuthCustom(smtpAuth smtp.Auth) {
c.smtpAuthType = SMTPAuthCustom
}
// DialWithContext establishes a connection to the SMTP server with a given context.Context
// DialWithContext establishes a connection to the server using the provided context.Context.
//
// Before connecting to the server, the function will add a deadline of the Client's timeout
// to the provided context.Context.
//
// After dialing the DialContextFunc defined in the Client and successfully establishing the
// connection to the SMTP server, it will send the HELO/EHLO SMTP command followed by the
// optional STARTTLS and SMTP AUTH commands. It will also attach the log.Logger in case
// debug logging is enabled on the Client.
//
// From this point in time the Client has an active (cancelable) connection to the SMTP server.
func (c *Client) DialWithContext(dialCtx context.Context) error {
c.mutex.Lock()
defer c.mutex.Unlock()
@ -707,8 +717,9 @@ func (c *Client) DialWithContext(dialCtx context.Context) error {
// Close closes the Client connection
func (c *Client) Close() error {
if err := c.checkConn(); err != nil {
return err
// If the connection is already closed, we considered this a no-op and disregard any error.
if !c.smtpClient.HasConnection() {
return nil
}
if err := c.smtpClient.Quit(); err != nil {
return fmt.Errorf("failed to close SMTP client: %w", err)

View file

@ -617,6 +617,32 @@ func TestSetSMTPAuthCustom(t *testing.T) {
}
}
// TestClient_Close_double tests if a close on an already closed connection causes an error.
func TestClient_Close_double(t *testing.T) {
c, err := getTestConnection(true)
if err != nil {
t.Skipf("failed to create test client: %s. Skipping tests", err)
}
ctx := context.Background()
if err = c.DialWithContext(ctx); err != nil {
t.Errorf("failed to dial with context: %s", err)
return
}
if c.smtpClient == nil {
t.Errorf("DialWithContext didn't fail but no SMTP client found.")
return
}
if !c.smtpClient.HasConnection() {
t.Errorf("DialWithContext didn't fail but no connection found.")
}
if err = c.Close(); err != nil {
t.Errorf("failed to close connection: %s", err)
}
if err = c.Close(); err != nil {
t.Errorf("failed 2nd close connection: %s", err)
}
}
// TestClient_DialWithContext tests the DialWithContext method for the Client object
func TestClient_DialWithContext(t *testing.T) {
c, err := getTestConnection(true)
@ -2391,7 +2417,6 @@ func TestXOAuth2OK_faker(t *testing.T) {
"250 8BITMIME",
"250 OK",
"235 2.7.0 Accepted",
"250 OK",
"221 OK",
}
var wrote strings.Builder
@ -2412,10 +2437,10 @@ func TestXOAuth2OK_faker(t *testing.T) {
if err != nil {
t.Fatalf("unable to create new client: %v", err)
}
if err := c.DialWithContext(context.Background()); err != nil {
if err = c.DialWithContext(context.Background()); err != nil {
t.Fatalf("unexpected dial error: %v", err)
}
if err := c.Close(); err != nil {
if err = c.Close(); err != nil {
t.Fatalf("disconnect from test server failed: %v", err)
}
if !strings.Contains(wrote.String(), "AUTH XOAUTH2 dXNlcj11c2VyAWF1dGg9QmVhcmVyIHRva2VuAQE=\r\n") {
@ -2430,7 +2455,6 @@ func TestXOAuth2Unsupported_faker(t *testing.T) {
"250-AUTH LOGIN PLAIN",
"250 8BITMIME",
"250 OK",
"250 OK",
"221 OK",
}
var wrote strings.Builder
@ -2449,18 +2473,18 @@ func TestXOAuth2Unsupported_faker(t *testing.T) {
if err != nil {
t.Fatalf("unable to create new client: %v", err)
}
if err := c.DialWithContext(context.Background()); err == nil {
if err = c.DialWithContext(context.Background()); err == nil {
t.Fatal("expected dial error got nil")
} else {
if !errors.Is(err, ErrXOauth2AuthNotSupported) {
t.Fatalf("expected %v; got %v", ErrXOauth2AuthNotSupported, err)
}
}
if err := c.Close(); err != nil {
if err = c.Close(); err != nil {
t.Fatalf("disconnect from test server failed: %v", err)
}
client := strings.Split(wrote.String(), "\r\n")
if len(client) != 5 {
if len(client) != 4 {
t.Fatalf("unexpected number of client requests got %d; want 5", len(client))
}
if !strings.HasPrefix(client[0], "EHLO") {
@ -2469,10 +2493,7 @@ func TestXOAuth2Unsupported_faker(t *testing.T) {
if client[1] != "NOOP" {
t.Fatalf("expected NOOP, got %q", client[1])
}
if client[2] != "NOOP" {
t.Fatalf("expected NOOP, got %q", client[2])
}
if client[3] != "QUIT" {
if client[2] != "QUIT" {
t.Fatalf("expected QUIT, got %q", client[3])
}
}

View file

@ -516,6 +516,8 @@ func (c *Client) Quit() error {
}
c.mutex.Lock()
err = c.Text.Close()
c.Text = nil
c.conn = nil
c.mutex.Unlock()
return err
@ -555,7 +557,10 @@ func (c *Client) SetDSNRcptNotifyOption(d string) {
// HasConnection checks if the client has an active connection.
// Returns true if the `conn` field is not nil, indicating an active connection.
func (c *Client) HasConnection() bool {
return c.conn != nil
c.mutex.RLock()
conn := c.conn
c.mutex.RUnlock()
return conn != nil
}
func (c *Client) UpdateDeadline(timeout time.Duration) error {