diff --git a/client.go b/client.go index 48d6731..787fafa 100644 --- a/client.go +++ b/client.go @@ -648,7 +648,17 @@ func (c *Client) SetSMTPAuthCustom(smtpAuth smtp.Auth) { c.smtpAuthType = SMTPAuthCustom } -// DialWithContext establishes a connection to the SMTP server with a given context.Context +// DialWithContext establishes a connection to the server using the provided context.Context. +// +// Before connecting to the server, the function will add a deadline of the Client's timeout +// to the provided context.Context. +// +// After dialing the DialContextFunc defined in the Client and successfully establishing the +// connection to the SMTP server, it will send the HELO/EHLO SMTP command followed by the +// optional STARTTLS and SMTP AUTH commands. It will also attach the log.Logger in case +// debug logging is enabled on the Client. +// +// From this point in time the Client has an active (cancelable) connection to the SMTP server. func (c *Client) DialWithContext(dialCtx context.Context) error { c.mutex.Lock() defer c.mutex.Unlock() @@ -707,8 +717,9 @@ func (c *Client) DialWithContext(dialCtx context.Context) error { // Close closes the Client connection func (c *Client) Close() error { - if err := c.checkConn(); err != nil { - return err + // If the connection is already closed, we considered this a no-op and disregard any error. + if !c.smtpClient.HasConnection() { + return nil } if err := c.smtpClient.Quit(); err != nil { return fmt.Errorf("failed to close SMTP client: %w", err) diff --git a/client_test.go b/client_test.go index d8dc87f..c767e75 100644 --- a/client_test.go +++ b/client_test.go @@ -617,6 +617,32 @@ func TestSetSMTPAuthCustom(t *testing.T) { } } +// TestClient_Close_double tests if a close on an already closed connection causes an error. +func TestClient_Close_double(t *testing.T) { + c, err := getTestConnection(true) + if err != nil { + t.Skipf("failed to create test client: %s. Skipping tests", err) + } + ctx := context.Background() + if err = c.DialWithContext(ctx); err != nil { + t.Errorf("failed to dial with context: %s", err) + return + } + if c.smtpClient == nil { + t.Errorf("DialWithContext didn't fail but no SMTP client found.") + return + } + if !c.smtpClient.HasConnection() { + t.Errorf("DialWithContext didn't fail but no connection found.") + } + if err = c.Close(); err != nil { + t.Errorf("failed to close connection: %s", err) + } + if err = c.Close(); err != nil { + t.Errorf("failed 2nd close connection: %s", err) + } +} + // TestClient_DialWithContext tests the DialWithContext method for the Client object func TestClient_DialWithContext(t *testing.T) { c, err := getTestConnection(true) @@ -2391,7 +2417,6 @@ func TestXOAuth2OK_faker(t *testing.T) { "250 8BITMIME", "250 OK", "235 2.7.0 Accepted", - "250 OK", "221 OK", } var wrote strings.Builder @@ -2412,10 +2437,10 @@ func TestXOAuth2OK_faker(t *testing.T) { if err != nil { t.Fatalf("unable to create new client: %v", err) } - if err := c.DialWithContext(context.Background()); err != nil { + if err = c.DialWithContext(context.Background()); err != nil { t.Fatalf("unexpected dial error: %v", err) } - if err := c.Close(); err != nil { + if err = c.Close(); err != nil { t.Fatalf("disconnect from test server failed: %v", err) } if !strings.Contains(wrote.String(), "AUTH XOAUTH2 dXNlcj11c2VyAWF1dGg9QmVhcmVyIHRva2VuAQE=\r\n") { @@ -2430,7 +2455,6 @@ func TestXOAuth2Unsupported_faker(t *testing.T) { "250-AUTH LOGIN PLAIN", "250 8BITMIME", "250 OK", - "250 OK", "221 OK", } var wrote strings.Builder @@ -2449,18 +2473,18 @@ func TestXOAuth2Unsupported_faker(t *testing.T) { if err != nil { t.Fatalf("unable to create new client: %v", err) } - if err := c.DialWithContext(context.Background()); err == nil { + if err = c.DialWithContext(context.Background()); err == nil { t.Fatal("expected dial error got nil") } else { if !errors.Is(err, ErrXOauth2AuthNotSupported) { t.Fatalf("expected %v; got %v", ErrXOauth2AuthNotSupported, err) } } - if err := c.Close(); err != nil { + if err = c.Close(); err != nil { t.Fatalf("disconnect from test server failed: %v", err) } client := strings.Split(wrote.String(), "\r\n") - if len(client) != 5 { + if len(client) != 4 { t.Fatalf("unexpected number of client requests got %d; want 5", len(client)) } if !strings.HasPrefix(client[0], "EHLO") { @@ -2469,10 +2493,7 @@ func TestXOAuth2Unsupported_faker(t *testing.T) { if client[1] != "NOOP" { t.Fatalf("expected NOOP, got %q", client[1]) } - if client[2] != "NOOP" { - t.Fatalf("expected NOOP, got %q", client[2]) - } - if client[3] != "QUIT" { + if client[2] != "QUIT" { t.Fatalf("expected QUIT, got %q", client[3]) } } diff --git a/smtp/smtp.go b/smtp/smtp.go index f9961c9..ce163a2 100644 --- a/smtp/smtp.go +++ b/smtp/smtp.go @@ -516,6 +516,8 @@ func (c *Client) Quit() error { } c.mutex.Lock() err = c.Text.Close() + c.Text = nil + c.conn = nil c.mutex.Unlock() return err @@ -555,7 +557,10 @@ func (c *Client) SetDSNRcptNotifyOption(d string) { // HasConnection checks if the client has an active connection. // Returns true if the `conn` field is not nil, indicating an active connection. func (c *Client) HasConnection() bool { - return c.conn != nil + c.mutex.RLock() + conn := c.conn + c.mutex.RUnlock() + return conn != nil } func (c *Client) UpdateDeadline(timeout time.Duration) error {