From 45776c052f417971339f6dc7aa88ecca11d78ab8 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Fri, 22 Nov 2024 14:58:05 +0100 Subject: [PATCH 01/11] Refactor client error handling and add concurrent send tests Updated tests to correctly assert the absence of an SMTP client on failure. Added concurrent sending tests for DialAndSendWithContext to improve test coverage and reliability. Also, refined the `AutoDiscover` and other client methods to ensure proper parameter use. --- client_test.go | 112 +++++++++++++++++++++++++++++++++---------------- 1 file changed, 76 insertions(+), 36 deletions(-) diff --git a/client_test.go b/client_test.go index f6d5161..36fb423 100644 --- a/client_test.go +++ b/client_test.go @@ -1772,11 +1772,8 @@ func TestClient_DialWithContext(t *testing.T) { t.Errorf("failed to close the client: %s", err) } }) - if client.smtpClient == nil { - t.Errorf("client with invalid HELO should still have a smtp client, got nil") - } - if !client.smtpClient.HasConnection() { - t.Errorf("client with invalid HELO should still have a smtp client connection, got nil") + if client.smtpClient != nil { + t.Error("client with invalid HELO should not have a smtp client") } }) t.Run("fail on base port and fallback", func(t *testing.T) { @@ -1825,11 +1822,8 @@ func TestClient_DialWithContext(t *testing.T) { if err = client.DialWithContext(ctxDial); err == nil { t.Fatalf("connection was supposed to fail, but didn't") } - if client.smtpClient == nil { - t.Fatalf("client has no smtp client") - } - if !client.smtpClient.HasConnection() { - t.Errorf("client has no connection") + if client.smtpClient != nil { + t.Fatalf("client is not supposed to have a smtp client") } }) t.Run("connect with failing auth", func(t *testing.T) { @@ -2297,6 +2291,7 @@ func TestClient_DialAndSendWithContext(t *testing.T) { t.Errorf("client was supposed to fail on dial") } }) + // https://github.com/wneessen/go-mail/issues/380 t.Run("concurrent sending via DialAndSendWithContext", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -2336,6 +2331,44 @@ func TestClient_DialAndSendWithContext(t *testing.T) { } wg.Wait() }) + // https://github.com/wneessen/go-mail/issues/385 + t.Run("concurrent sending via DialAndSendWithContext on receiver func", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := NewClient(DefaultHost, WithPort(serverPort), WithTLSPolicy(NoTLS)) + if err != nil { + t.Fatalf("failed to create new client: %s", err) + } + sender := testSender{client} + + ctxDial := context.Background() + wg := sync.WaitGroup{} + for i := 0; i < 5; i++ { + wg.Add(1) + msg := testMessage(t) + go func() { + defer wg.Done() + if goroutineErr := sender.Send(ctxDial, msg); goroutineErr != nil { + t.Errorf("failed to send message: %s", goroutineErr) + } + }() + } + wg.Wait() + }) } func TestClient_auth(t *testing.T) { @@ -2574,8 +2607,8 @@ func TestClient_authTypeAutoDiscover(t *testing.T) { } for _, tt := range tests { t.Run("AutoDiscover selects the strongest auth type: "+string(tt.expect), func(t *testing.T) { - client := &Client{smtpAuthType: SMTPAuthAutoDiscover, isEncrypted: tt.tls} - authType, err := client.authTypeAutoDiscover(tt.supported) + client := &Client{smtpAuthType: SMTPAuthAutoDiscover} + authType, err := client.authTypeAutoDiscover(tt.supported, tt.tls) if err != nil && !tt.shouldFail { t.Fatalf("failed to auto discover auth type: %s", err) } @@ -2748,7 +2781,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err != nil { + if err = client.sendSingleMsg(client.smtpClient, message); err != nil { t.Errorf("failed to send message: %s", err) } }) @@ -2791,7 +2824,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err == nil { + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { t.Errorf("client should have failed to send message") } }) @@ -2836,7 +2869,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err == nil { + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { t.Errorf("client should have failed to send message") } var sendErr *SendError @@ -2886,7 +2919,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err == nil { + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { t.Errorf("client should have failed to send message") } var sendErr *SendError @@ -2936,7 +2969,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err == nil { + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { t.Errorf("client should have failed to send message") } var sendErr *SendError @@ -2986,7 +3019,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err != nil { + if err = client.sendSingleMsg(client.smtpClient, message); err != nil { t.Errorf("failed to send message: %s", err) } }) @@ -3029,7 +3062,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err == nil { + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { t.Errorf("client should have failed to send message") } var sendErr *SendError @@ -3080,7 +3113,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err == nil { + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { t.Errorf("client should have failed to send message") } var sendErr *SendError @@ -3131,7 +3164,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err == nil { + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { t.Errorf("client should have failed to send message") } var sendErr *SendError @@ -3181,7 +3214,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err == nil { + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { t.Errorf("client should have failed to send message") } var sendErr *SendError @@ -3231,7 +3264,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err == nil { + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { t.Errorf("client should have failed to send message") } var sendErr *SendError @@ -3281,7 +3314,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err == nil { + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { t.Error("expected mail delivery to fail") } var sendErr *SendError @@ -3334,7 +3367,7 @@ func TestClient_checkConn(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.checkConn(); err != nil { + if err = client.checkConn(client.smtpClient); err != nil { t.Errorf("failed to check connection: %s", err) } }) @@ -3375,7 +3408,7 @@ func TestClient_checkConn(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.checkConn(); err == nil { + if err = client.checkConn(client.smtpClient); err == nil { t.Errorf("client should have failed on connection check") } if !errors.Is(err, ErrNoActiveConnection) { @@ -3387,7 +3420,7 @@ func TestClient_checkConn(t *testing.T) { if err != nil { t.Fatalf("failed to create new client: %s", err) } - if err = client.checkConn(); err == nil { + if err = client.checkConn(client.smtpClient); err == nil { t.Errorf("client should have failed on connection check") } if !errors.Is(err, ErrNoActiveConnection) { @@ -3611,24 +3644,20 @@ func TestClient_XOAuth2OnFaker(t *testing.T) { } if err = c.DialWithContext(context.Background()); err == nil { t.Fatal("expected dial error got nil") - } else { - if !errors.Is(err, ErrXOauth2AuthNotSupported) { - t.Fatalf("expected %v; got %v", ErrXOauth2AuthNotSupported, err) - } + } + if !errors.Is(err, ErrXOauth2AuthNotSupported) { + t.Fatalf("expected %v; got %v", ErrXOauth2AuthNotSupported, err) } if err = c.Close(); err != nil { t.Fatalf("disconnect from test server failed: %v", err) } client := strings.Split(wrote.String(), "\r\n") - if len(client) != 3 { - t.Fatalf("unexpected number of client requests got %d; want 3", len(client)) + if len(client) != 2 { + t.Fatalf("unexpected number of client requests got %d; want 2", len(client)) } if !strings.HasPrefix(client[0], "EHLO") { t.Fatalf("expected EHLO, got %q", client[0]) } - if client[1] != "QUIT" { - t.Fatalf("expected QUIT, got %q", client[3]) - } }) } @@ -3652,6 +3681,17 @@ func (f faker) SetDeadline(time.Time) error { return nil } func (f faker) SetReadDeadline(time.Time) error { return nil } func (f faker) SetWriteDeadline(time.Time) error { return nil } +type testSender struct { + client *Client +} + +func (t *testSender) Send(ctx context.Context, m *Msg) error { + if err := t.client.DialAndSendWithContext(ctx, m); err != nil { + return fmt.Errorf("failed to dial and send mail: %w", err) + } + return nil +} + // parseJSONLog parses a JSON encoded log from the provided buffer and returns a slice of logLine structs. // In case of a decode error, it reports the error to the testing framework. func parseJSONLog(t *testing.T, buf *bytes.Buffer) logData { From f9e869061e9a7085d199be3408cebfaa708f9386 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Fri, 22 Nov 2024 14:58:41 +0100 Subject: [PATCH 02/11] Add mutex locks to DSN option setters Mutex locks are added to ensure thread safety when setting DSN mail return and recipient notify options. This prevents data races in concurrent environments, improving the client's robustness. --- smtp/smtp.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/smtp/smtp.go b/smtp/smtp.go index 8a278ee..24a079f 100644 --- a/smtp/smtp.go +++ b/smtp/smtp.go @@ -601,12 +601,16 @@ func (c *Client) SetLogAuthData() { // SetDSNMailReturnOption sets the DSN mail return option for the Mail method func (c *Client) SetDSNMailReturnOption(d string) { + c.mutex.Lock() c.dsnmrtype = d + c.mutex.Unlock() } // SetDSNRcptNotifyOption sets the DSN recipient notify option for the Mail method func (c *Client) SetDSNRcptNotifyOption(d string) { + c.mutex.Lock() c.dsnrntype = d + c.mutex.Unlock() } // HasConnection checks if the client has an active connection. From 55884786be1d83687bba15ad105202103540c971 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Fri, 22 Nov 2024 14:59:19 +0100 Subject: [PATCH 03/11] Refactor SMTP client functions to improve modularity Refactor `DialWithContext` to delegate client creation to `SMTPClientFromDialWithContext`. Add new methods `SendWithSMTPClient`, `CloseWithSMTPClient`, and `ResetWithSMTPClient` to handle specific actions on `smtp.Client`. Simplify `auth`, `sendSingleMsg`, and `tls` methods by passing client as parameters. --- client.go | 261 ++++++++++++++++++++++++++++++++------------------ client_120.go | 14 ++- 2 files changed, 180 insertions(+), 95 deletions(-) diff --git a/client.go b/client.go index a7dead8..de6479e 100644 --- a/client.go +++ b/client.go @@ -142,9 +142,6 @@ type ( // host is the hostname of the SMTP server we are connecting to. host string - // isEncrypted indicates wether the Client connection is encrypted or not. - isEncrypted bool - // logAuthData indicates whether authentication-related data should be logged. logAuthData bool @@ -931,62 +928,131 @@ func (c *Client) SetLogAuthData(logAuth bool) { // Returns: // - An error if the connection to the SMTP server fails or any subsequent command fails. func (c *Client) DialWithContext(dialCtx context.Context) error { - c.mutex.Lock() - defer c.mutex.Unlock() - - ctx, cancel := context.WithDeadline(dialCtx, time.Now().Add(c.connTimeout)) - defer cancel() - - if c.dialContextFunc == nil { - netDialer := net.Dialer{} - c.dialContextFunc = netDialer.DialContext - - if c.useSSL { - tlsDialer := tls.Dialer{NetDialer: &netDialer, Config: c.tlsconfig} - c.isEncrypted = true - c.dialContextFunc = tlsDialer.DialContext - } - } - connection, err := c.dialContextFunc(ctx, "tcp", c.ServerAddr()) - if err != nil && c.fallbackPort != 0 { - // TODO: should we somehow log or append the previous error? - connection, err = c.dialContextFunc(ctx, "tcp", c.serverFallbackAddr()) - } + client, err := c.SMTPClientFromDialWithContext(dialCtx) if err != nil { return err } + c.mutex.Lock() + c.smtpClient = client + c.mutex.Unlock() + /* + ctx, cancel := context.WithDeadline(dialCtx, time.Now().Add(c.connTimeout)) + defer cancel() + + isEncrypted := false + if c.dialContextFunc == nil { + netDialer := net.Dialer{} + c.dialContextFunc = netDialer.DialContext + + if c.useSSL { + tlsDialer := tls.Dialer{NetDialer: &netDialer, Config: c.tlsconfig} + isEncrypted = true + c.dialContextFunc = tlsDialer.DialContext + } + } + connection, err := c.dialContextFunc(ctx, "tcp", c.ServerAddr()) + if err != nil && c.fallbackPort != 0 { + // TODO: should we somehow log or append the previous error? + connection, err = c.dialContextFunc(ctx, "tcp", c.serverFallbackAddr()) + } + if err != nil { + return err + } + + client, err := smtp.NewClient(connection, c.host) + if err != nil { + return err + } + if client == nil { + return fmt.Errorf("SMTP client is nil") + } + c.smtpClient = client + + if c.logger != nil { + c.smtpClient.SetLogger(c.logger) + } + if c.useDebugLog { + c.smtpClient.SetDebugLog(true) + } + if c.logAuthData { + c.smtpClient.SetLogAuthData() + } + if err = c.smtpClient.Hello(c.helo); err != nil { + return err + } + + if err = c.tls(c.smtpClient, &isEncrypted); err != nil { + return err + } + + if err = c.auth(c.smtpClient, isEncrypted); err != nil { + return err + } + + */ + + return nil +} + +// SMTPClientFromDialWithContext is similar to DialWithContext but instead of storing the smtp.Client +// on the Client it will return the smtp.Client instead. +func (c *Client) SMTPClientFromDialWithContext(ctxDial context.Context) (*smtp.Client, error) { + c.mutex.RLock() + defer c.mutex.RUnlock() + + ctx, cancel := context.WithDeadline(ctxDial, time.Now().Add(c.connTimeout)) + defer cancel() + + isEncrypted := false + dialContextFunc := c.dialContextFunc + if c.dialContextFunc == nil { + netDialer := net.Dialer{} + dialContextFunc = netDialer.DialContext + if c.useSSL { + tlsDialer := tls.Dialer{NetDialer: &netDialer, Config: c.tlsconfig} + isEncrypted = true + dialContextFunc = tlsDialer.DialContext + } + } + connection, err := dialContextFunc(ctx, "tcp", c.ServerAddr()) + if err != nil && c.fallbackPort != 0 { + // TODO: should we somehow log or append the previous error? + connection, err = dialContextFunc(ctx, "tcp", c.serverFallbackAddr()) + } + if err != nil { + return nil, err + } client, err := smtp.NewClient(connection, c.host) if err != nil { - return err + return nil, err } if client == nil { - return fmt.Errorf("SMTP client is nil") + return nil, fmt.Errorf("SMTP client is nil") } - c.smtpClient = client if c.logger != nil { - c.smtpClient.SetLogger(c.logger) + client.SetLogger(c.logger) } if c.useDebugLog { - c.smtpClient.SetDebugLog(true) + client.SetDebugLog(true) } if c.logAuthData { - c.smtpClient.SetLogAuthData() + client.SetLogAuthData() } - if err = c.smtpClient.Hello(c.helo); err != nil { - return err + if err = client.Hello(c.helo); err != nil { + return nil, err } - if err = c.tls(); err != nil { - return err + if err = c.tls(client, &isEncrypted); err != nil { + return nil, err } - if err = c.auth(); err != nil { - return err + if err = c.auth(client, isEncrypted); err != nil { + return nil, err } - return nil + return client, nil } // Close terminates the connection to the SMTP server, returning an error if the disconnection @@ -999,10 +1065,14 @@ func (c *Client) DialWithContext(dialCtx context.Context) error { // Returns: // - An error if the disconnection fails; otherwise, returns nil. func (c *Client) Close() error { - if c.smtpClient == nil || !c.smtpClient.HasConnection() { + return c.CloseWithSMTPClient(c.smtpClient) +} + +func (c *Client) CloseWithSMTPClient(client *smtp.Client) error { + if client == nil || !client.HasConnection() { return nil } - if err := c.smtpClient.Quit(); err != nil { + if err := client.Quit(); err != nil { return fmt.Errorf("failed to close SMTP client: %w", err) } @@ -1018,10 +1088,14 @@ func (c *Client) Close() error { // Returns: // - An error if the connection check fails or if sending the RSET command fails; otherwise, returns nil. func (c *Client) Reset() error { - if err := c.checkConn(); err != nil { + return c.ResetWithSMTPClient(c.smtpClient) +} + +func (c *Client) ResetWithSMTPClient(client *smtp.Client) error { + if err := c.checkConn(client); err != nil { return err } - if err := c.smtpClient.Reset(); err != nil { + if err := client.Reset(); err != nil { return fmt.Errorf("failed to send RSET to SMTP client: %w", err) } @@ -1061,19 +1135,20 @@ func (c *Client) DialAndSend(messages ...*Msg) error { // - An error if the connection fails, if sending the messages fails, or if closing the // connection fails; otherwise, returns nil. func (c *Client) DialAndSendWithContext(ctx context.Context, messages ...*Msg) error { - c.sendMutex.Lock() - defer c.sendMutex.Unlock() - if err := c.DialWithContext(ctx); err != nil { + //c.sendMutex.Lock() + //defer c.sendMutex.Unlock() + client, err := c.SMTPClientFromDialWithContext(ctx) + if err != nil { return fmt.Errorf("dial failed: %w", err) } defer func() { - _ = c.Close() + _ = c.CloseWithSMTPClient(client) }() - if err := c.Send(messages...); err != nil { + if err := c.SendWithSMTPClient(client, messages...); err != nil { return fmt.Errorf("send failed: %w", err) } - if err := c.Close(); err != nil { + if err := c.CloseWithSMTPClient(client); err != nil { return fmt.Errorf("failed to close connection: %w", err) } return nil @@ -1098,16 +1173,17 @@ func (c *Client) DialAndSendWithContext(ctx context.Context, messages ...*Msg) e // Returns: // - An error if the connection check fails, if no supported authentication method is found, // or if the authentication process fails. -func (c *Client) auth() error { +func (c *Client) auth(client *smtp.Client, isEnc bool) error { + var smtpAuth smtp.Auth if c.smtpAuth == nil && c.smtpAuthType != SMTPAuthNoAuth { - hasSMTPAuth, smtpAuthType := c.smtpClient.Extension("AUTH") + hasSMTPAuth, smtpAuthType := client.Extension("AUTH") if !hasSMTPAuth { return fmt.Errorf("server does not support SMTP AUTH") } authType := c.smtpAuthType if c.smtpAuthType == SMTPAuthAutoDiscover { - discoveredType, err := c.authTypeAutoDiscover(smtpAuthType) + discoveredType, err := c.authTypeAutoDiscover(smtpAuthType, isEnc) if err != nil { return err } @@ -1119,74 +1195,74 @@ func (c *Client) auth() error { if !strings.Contains(smtpAuthType, string(SMTPAuthPlain)) { return ErrPlainAuthNotSupported } - c.smtpAuth = smtp.PlainAuth("", c.user, c.pass, c.host, false) + smtpAuth = smtp.PlainAuth("", c.user, c.pass, c.host, false) case SMTPAuthPlainNoEnc: if !strings.Contains(smtpAuthType, string(SMTPAuthPlain)) { return ErrPlainAuthNotSupported } - c.smtpAuth = smtp.PlainAuth("", c.user, c.pass, c.host, true) + smtpAuth = smtp.PlainAuth("", c.user, c.pass, c.host, true) case SMTPAuthLogin: if !strings.Contains(smtpAuthType, string(SMTPAuthLogin)) { return ErrLoginAuthNotSupported } - c.smtpAuth = smtp.LoginAuth(c.user, c.pass, c.host, false) + smtpAuth = smtp.LoginAuth(c.user, c.pass, c.host, false) case SMTPAuthLoginNoEnc: if !strings.Contains(smtpAuthType, string(SMTPAuthLogin)) { return ErrLoginAuthNotSupported } - c.smtpAuth = smtp.LoginAuth(c.user, c.pass, c.host, true) + smtpAuth = smtp.LoginAuth(c.user, c.pass, c.host, true) case SMTPAuthCramMD5: if !strings.Contains(smtpAuthType, string(SMTPAuthCramMD5)) { return ErrCramMD5AuthNotSupported } - c.smtpAuth = smtp.CRAMMD5Auth(c.user, c.pass) + smtpAuth = smtp.CRAMMD5Auth(c.user, c.pass) case SMTPAuthXOAUTH2: if !strings.Contains(smtpAuthType, string(SMTPAuthXOAUTH2)) { return ErrXOauth2AuthNotSupported } - c.smtpAuth = smtp.XOAuth2Auth(c.user, c.pass) + smtpAuth = smtp.XOAuth2Auth(c.user, c.pass) case SMTPAuthSCRAMSHA1: if !strings.Contains(smtpAuthType, string(SMTPAuthSCRAMSHA1)) { return ErrSCRAMSHA1AuthNotSupported } - c.smtpAuth = smtp.ScramSHA1Auth(c.user, c.pass) + smtpAuth = smtp.ScramSHA1Auth(c.user, c.pass) case SMTPAuthSCRAMSHA256: if !strings.Contains(smtpAuthType, string(SMTPAuthSCRAMSHA256)) { return ErrSCRAMSHA256AuthNotSupported } - c.smtpAuth = smtp.ScramSHA256Auth(c.user, c.pass) + smtpAuth = smtp.ScramSHA256Auth(c.user, c.pass) case SMTPAuthSCRAMSHA1PLUS: if !strings.Contains(smtpAuthType, string(SMTPAuthSCRAMSHA1PLUS)) { return ErrSCRAMSHA1PLUSAuthNotSupported } - tlsConnState, err := c.smtpClient.GetTLSConnectionState() + tlsConnState, err := client.GetTLSConnectionState() if err != nil { return err } - c.smtpAuth = smtp.ScramSHA1PlusAuth(c.user, c.pass, tlsConnState) + smtpAuth = smtp.ScramSHA1PlusAuth(c.user, c.pass, tlsConnState) case SMTPAuthSCRAMSHA256PLUS: if !strings.Contains(smtpAuthType, string(SMTPAuthSCRAMSHA256PLUS)) { return ErrSCRAMSHA256PLUSAuthNotSupported } - tlsConnState, err := c.smtpClient.GetTLSConnectionState() + tlsConnState, err := client.GetTLSConnectionState() if err != nil { return err } - c.smtpAuth = smtp.ScramSHA256PlusAuth(c.user, c.pass, tlsConnState) + smtpAuth = smtp.ScramSHA256PlusAuth(c.user, c.pass, tlsConnState) default: return fmt.Errorf("unsupported SMTP AUTH type %q", c.smtpAuthType) } } - if c.smtpAuth != nil { - if err := c.smtpClient.Auth(c.smtpAuth); err != nil { + if smtpAuth != nil { + if err := client.Auth(smtpAuth); err != nil { return fmt.Errorf("SMTP AUTH failed: %w", err) } } return nil } -func (c *Client) authTypeAutoDiscover(supported string) (SMTPAuthType, error) { +func (c *Client) authTypeAutoDiscover(supported string, isEnc bool) (SMTPAuthType, error) { if supported == "" { return "", ErrNoSupportedAuthDiscovered } @@ -1194,7 +1270,7 @@ func (c *Client) authTypeAutoDiscover(supported string) (SMTPAuthType, error) { SMTPAuthSCRAMSHA256PLUS, SMTPAuthSCRAMSHA256, SMTPAuthSCRAMSHA1PLUS, SMTPAuthSCRAMSHA1, SMTPAuthXOAUTH2, SMTPAuthCramMD5, SMTPAuthPlain, SMTPAuthLogin, } - if !c.isEncrypted { + if !isEnc { preferList = []SMTPAuthType{SMTPAuthSCRAMSHA256, SMTPAuthSCRAMSHA1, SMTPAuthXOAUTH2, SMTPAuthCramMD5} } mechs := strings.Split(supported, " ") @@ -1231,13 +1307,13 @@ func sliceContains(slice []string, item string) bool { // // Returns: // - An error if any part of the sending process fails; otherwise, returns nil. -func (c *Client) sendSingleMsg(message *Msg) error { - c.mutex.Lock() - defer c.mutex.Unlock() - escSupport, _ := c.smtpClient.Extension("ENHANCEDSTATUSCODES") +func (c *Client) sendSingleMsg(client *smtp.Client, message *Msg) error { + c.mutex.RLock() + defer c.mutex.RUnlock() + escSupport, _ := client.Extension("ENHANCEDSTATUSCODES") if message.encoding == NoEncoding { - if ok, _ := c.smtpClient.Extension("8BITMIME"); !ok { + if ok, _ := client.Extension("8BITMIME"); !ok { return &SendError{Reason: ErrNoUnencoded, isTemp: false, affectedMsg: message} } } @@ -1260,16 +1336,16 @@ func (c *Client) sendSingleMsg(message *Msg) error { if c.requestDSN { if c.dsnReturnType != "" { - c.smtpClient.SetDSNMailReturnOption(string(c.dsnReturnType)) + client.SetDSNMailReturnOption(string(c.dsnReturnType)) } } - if err = c.smtpClient.Mail(from); err != nil { + if err = client.Mail(from); err != nil { retError := &SendError{ Reason: ErrSMTPMailFrom, errlist: []error{err}, isTemp: isTempError(err), affectedMsg: message, errcode: errorCode(err), enhancedStatusCode: enhancedStatusCode(err, escSupport), } - if resetSendErr := c.smtpClient.Reset(); resetSendErr != nil { + if resetSendErr := client.Reset(); resetSendErr != nil { retError.errlist = append(retError.errlist, resetSendErr) } return retError @@ -1279,9 +1355,9 @@ func (c *Client) sendSingleMsg(message *Msg) error { rcptSendErr.errlist = make([]error, 0) rcptSendErr.rcpt = make([]string, 0) rcptNotifyOpt := strings.Join(c.dsnRcptNotifyType, ",") - c.smtpClient.SetDSNRcptNotifyOption(rcptNotifyOpt) + client.SetDSNRcptNotifyOption(rcptNotifyOpt) for _, rcpt := range rcpts { - if err = c.smtpClient.Rcpt(rcpt); err != nil { + if err = client.Rcpt(rcpt); err != nil { rcptSendErr.Reason = ErrSMTPRcptTo rcptSendErr.errlist = append(rcptSendErr.errlist, err) rcptSendErr.rcpt = append(rcptSendErr.rcpt, rcpt) @@ -1292,12 +1368,12 @@ func (c *Client) sendSingleMsg(message *Msg) error { } } if hasError { - if resetSendErr := c.smtpClient.Reset(); resetSendErr != nil { + if resetSendErr := client.Reset(); resetSendErr != nil { rcptSendErr.errlist = append(rcptSendErr.errlist, resetSendErr) } return rcptSendErr } - writer, err := c.smtpClient.Data() + writer, err := client.Data() if err != nil { return &SendError{ Reason: ErrSMTPData, errlist: []error{err}, isTemp: isTempError(err), @@ -1322,7 +1398,7 @@ func (c *Client) sendSingleMsg(message *Msg) error { } message.isDelivered = true - if err = c.Reset(); err != nil { + if err = c.ResetWithSMTPClient(client); err != nil { return &SendError{ Reason: ErrSMTPReset, errlist: []error{err}, isTemp: isTempError(err), affectedMsg: message, errcode: errorCode(err), @@ -1344,21 +1420,24 @@ func (c *Client) sendSingleMsg(message *Msg) error { // Returns: // - An error if there is no active connection, if the NOOP command fails, or if extending // the deadline fails; otherwise, returns nil. -func (c *Client) checkConn() error { - if c.smtpClient == nil { +func (c *Client) checkConn(client *smtp.Client) error { + if client == nil { return ErrNoActiveConnection } - if !c.smtpClient.HasConnection() { + if !client.HasConnection() { return ErrNoActiveConnection } - if !c.noNoop { - if err := c.smtpClient.Noop(); err != nil { + c.mutex.RLock() + noNoop := c.noNoop + c.mutex.RUnlock() + if !noNoop { + if err := client.Noop(); err != nil { return ErrNoActiveConnection } } - if err := c.smtpClient.UpdateDeadline(c.connTimeout); err != nil { + if err := client.UpdateDeadline(c.connTimeout); err != nil { return ErrDeadlineExtendFailed } return nil @@ -1405,10 +1484,10 @@ func (c *Client) setDefaultHelo() error { // Returns: // - An error if there is no active connection, if STARTTLS is required but not supported, // or if there are issues during the TLS handshake; otherwise, returns nil. -func (c *Client) tls() error { +func (c *Client) tls(client *smtp.Client, isEnc *bool) error { if !c.useSSL && c.tlspolicy != NoTLS { hasStartTLS := false - extension, _ := c.smtpClient.Extension("STARTTLS") + extension, _ := client.Extension("STARTTLS") if c.tlspolicy == TLSMandatory { hasStartTLS = true if !extension { @@ -1422,21 +1501,21 @@ func (c *Client) tls() error { } } if hasStartTLS { - if err := c.smtpClient.StartTLS(c.tlsconfig); err != nil { + if err := client.StartTLS(c.tlsconfig); err != nil { return err } } - tlsConnState, err := c.smtpClient.GetTLSConnectionState() + tlsConnState, err := client.GetTLSConnectionState() if err != nil { switch { case errors.Is(err, smtp.ErrNonTLSConnection): - c.isEncrypted = false + *isEnc = false return nil default: return fmt.Errorf("failed to get TLS connection state: %w", err) } } - c.isEncrypted = tlsConnState.HandshakeComplete + *isEnc = tlsConnState.HandshakeComplete } return nil } diff --git a/client_120.go b/client_120.go index 67c5b5e..16e2d0b 100644 --- a/client_120.go +++ b/client_120.go @@ -9,6 +9,8 @@ package mail import ( "errors" + + "github.com/wneessen/go-mail/smtp" ) // Send attempts to send one or more Msg using the Client connection to the SMTP server. @@ -27,11 +29,15 @@ import ( // Returns: // - An error that aggregates any SendErrors encountered during the sending process; otherwise, returns nil. func (c *Client) Send(messages ...*Msg) (returnErr error) { + return c.SendWithSMTPClient(c.smtpClient, messages...) +} + +func (c *Client) SendWithSMTPClient(client *smtp.Client, messages ...*Msg) (returnErr error) { escSupport := false - if c.smtpClient != nil { - escSupport, _ = c.smtpClient.Extension("ENHANCEDSTATUSCODES") + if client != nil { + escSupport, _ = client.Extension("ENHANCEDSTATUSCODES") } - if err := c.checkConn(); err != nil { + if err := c.checkConn(client); err != nil { returnErr = &SendError{ Reason: ErrConnCheck, errlist: []error{err}, isTemp: isTempError(err), errcode: errorCode(err), enhancedStatusCode: enhancedStatusCode(err, escSupport), @@ -45,7 +51,7 @@ func (c *Client) Send(messages ...*Msg) (returnErr error) { }() for id, message := range messages { - if sendErr := c.sendSingleMsg(message); sendErr != nil { + if sendErr := c.sendSingleMsg(client, message); sendErr != nil { messages[id].sendError = sendErr errs = append(errs, sendErr) } From be4201b05a885c48ff1f0c416549f1f6a680b413 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Fri, 22 Nov 2024 15:29:22 +0100 Subject: [PATCH 04/11] Refactor debug logging and logger settings in Client Separated debug logging and logger setting methods to include SMTP client parameter for better encapsulation. Removed commented-out code for cleaner and more manageable codebase. --- client.go | 82 ++++++++++++------------------------------------------- 1 file changed, 18 insertions(+), 64 deletions(-) diff --git a/client.go b/client.go index de6479e..637e299 100644 --- a/client.go +++ b/client.go @@ -812,9 +812,15 @@ func (c *Client) SetSSLPort(ssl bool, fallback bool) { // Parameters: // - val: A boolean value indicating whether to enable (true) or disable (false) debug logging. func (c *Client) SetDebugLog(val bool) { + c.SetDebugLogWithSMTPClient(c.smtpClient, val) +} + +func (c *Client) SetDebugLogWithSMTPClient(client *smtp.Client, val bool) { + c.mutex.Lock() + defer c.mutex.Unlock() c.useDebugLog = val - if c.smtpClient != nil { - c.smtpClient.SetDebugLog(val) + if client != nil { + client.SetDebugLog(val) } } @@ -827,9 +833,15 @@ func (c *Client) SetDebugLog(val bool) { // Parameters: // - logger: A logger that satisfies the log.Logger interface to be set for the Client. func (c *Client) SetLogger(logger log.Logger) { + c.SetLoggerWithSMTPClient(c.smtpClient, logger) +} + +func (c *Client) SetLoggerWithSMTPClient(client *smtp.Client, logger log.Logger) { + c.mutex.Lock() + defer c.mutex.Unlock() c.logger = logger - if c.smtpClient != nil { - c.smtpClient.SetLogger(logger) + if client != nil { + client.SetLogger(logger) } } @@ -935,62 +947,6 @@ func (c *Client) DialWithContext(dialCtx context.Context) error { c.mutex.Lock() c.smtpClient = client c.mutex.Unlock() - /* - ctx, cancel := context.WithDeadline(dialCtx, time.Now().Add(c.connTimeout)) - defer cancel() - - isEncrypted := false - if c.dialContextFunc == nil { - netDialer := net.Dialer{} - c.dialContextFunc = netDialer.DialContext - - if c.useSSL { - tlsDialer := tls.Dialer{NetDialer: &netDialer, Config: c.tlsconfig} - isEncrypted = true - c.dialContextFunc = tlsDialer.DialContext - } - } - connection, err := c.dialContextFunc(ctx, "tcp", c.ServerAddr()) - if err != nil && c.fallbackPort != 0 { - // TODO: should we somehow log or append the previous error? - connection, err = c.dialContextFunc(ctx, "tcp", c.serverFallbackAddr()) - } - if err != nil { - return err - } - - client, err := smtp.NewClient(connection, c.host) - if err != nil { - return err - } - if client == nil { - return fmt.Errorf("SMTP client is nil") - } - c.smtpClient = client - - if c.logger != nil { - c.smtpClient.SetLogger(c.logger) - } - if c.useDebugLog { - c.smtpClient.SetDebugLog(true) - } - if c.logAuthData { - c.smtpClient.SetLogAuthData() - } - if err = c.smtpClient.Hello(c.helo); err != nil { - return err - } - - if err = c.tls(c.smtpClient, &isEncrypted); err != nil { - return err - } - - if err = c.auth(c.smtpClient, isEncrypted); err != nil { - return err - } - - */ - return nil } @@ -1135,8 +1091,6 @@ func (c *Client) DialAndSend(messages ...*Msg) error { // - An error if the connection fails, if sending the messages fails, or if closing the // connection fails; otherwise, returns nil. func (c *Client) DialAndSendWithContext(ctx context.Context, messages ...*Msg) error { - //c.sendMutex.Lock() - //defer c.sendMutex.Unlock() client, err := c.SMTPClientFromDialWithContext(ctx) if err != nil { return fmt.Errorf("dial failed: %w", err) @@ -1145,10 +1099,10 @@ func (c *Client) DialAndSendWithContext(ctx context.Context, messages ...*Msg) e _ = c.CloseWithSMTPClient(client) }() - if err := c.SendWithSMTPClient(client, messages...); err != nil { + if err = c.SendWithSMTPClient(client, messages...); err != nil { return fmt.Errorf("send failed: %w", err) } - if err := c.CloseWithSMTPClient(client); err != nil { + if err = c.CloseWithSMTPClient(client); err != nil { return fmt.Errorf("failed to close connection: %w", err) } return nil From 3e504e6338495d0cdbfed401642b193909d5668a Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Fri, 22 Nov 2024 15:29:31 +0100 Subject: [PATCH 05/11] Lock sendMutex to ensure thread safety in Send Added a mutex lock and unlock around the Send function to prevent concurrent access issues and ensure thread safety. This change helps avoid race conditions when multiple goroutines attempt to send messages simultaneously. --- client_120.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/client_120.go b/client_120.go index 16e2d0b..bc157a9 100644 --- a/client_120.go +++ b/client_120.go @@ -29,6 +29,8 @@ import ( // Returns: // - An error that aggregates any SendErrors encountered during the sending process; otherwise, returns nil. func (c *Client) Send(messages ...*Msg) (returnErr error) { + c.sendMutex.Lock() + defer c.sendMutex.Unlock() return c.SendWithSMTPClient(c.smtpClient, messages...) } From 3553b657697cc9fef3180b91121523d0750494c0 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Fri, 22 Nov 2024 15:38:04 +0100 Subject: [PATCH 06/11] Add new Send function to client.go and remove duplicates The new Send function in client.go adds thread safety by using a mutex. This change also removes duplicate Send functions from client_119.go and client_120.go, consolidating the logic in one place for easier maintenance. --- client.go | 6 ++++++ client_119.go | 17 +++++++++++------ client_120.go | 5 ----- 3 files changed, 17 insertions(+), 11 deletions(-) diff --git a/client.go b/client.go index 637e299..4ee68f0 100644 --- a/client.go +++ b/client.go @@ -1108,6 +1108,12 @@ func (c *Client) DialAndSendWithContext(ctx context.Context, messages ...*Msg) e return nil } +func (c *Client) Send(messages ...*Msg) (returnErr error) { + c.sendMutex.Lock() + defer c.sendMutex.Unlock() + return c.SendWithSMTPClient(c.smtpClient, messages...) +} + // auth attempts to authenticate the client using SMTP AUTH mechanisms. It checks the connection, // determines the supported authentication methods, and applies the appropriate authentication // type. An error is returned if authentication fails. diff --git a/client_119.go b/client_119.go index a28f747..fb8c982 100644 --- a/client_119.go +++ b/client_119.go @@ -7,7 +7,11 @@ package mail -import "errors" +import ( + "errors" + + "github.com/wneessen/go-mail/smtp" +) // Send attempts to send one or more Msg using the Client connection to the SMTP server. // If the Client has no active connection to the server, Send will fail with an error. For each @@ -26,12 +30,13 @@ import "errors" // Returns: // - An error that represents the sending result, which may include multiple SendErrors if // any occurred; otherwise, returns nil. -func (c *Client) Send(messages ...*Msg) error { + +func (c *Client) SendWithSMTPClient(client *smtp.Client, messages ...*Msg) error { escSupport := false - if c.smtpClient != nil { - escSupport, _ = c.smtpClient.Extension("ENHANCEDSTATUSCODES") + if client != nil { + escSupport, _ = client.Extension("ENHANCEDSTATUSCODES") } - if err := c.checkConn(); err != nil { + if err := c.checkConn(client); err != nil { return &SendError{ Reason: ErrConnCheck, errlist: []error{err}, isTemp: isTempError(err), errcode: errorCode(err), enhancedStatusCode: enhancedStatusCode(err, escSupport), @@ -39,7 +44,7 @@ func (c *Client) Send(messages ...*Msg) error { } var errs []*SendError for id, message := range messages { - if sendErr := c.sendSingleMsg(message); sendErr != nil { + if sendErr := c.sendSingleMsg(client, message); sendErr != nil { messages[id].sendError = sendErr var msgSendErr *SendError diff --git a/client_120.go b/client_120.go index bc157a9..e9ce8dc 100644 --- a/client_120.go +++ b/client_120.go @@ -28,11 +28,6 @@ import ( // // Returns: // - An error that aggregates any SendErrors encountered during the sending process; otherwise, returns nil. -func (c *Client) Send(messages ...*Msg) (returnErr error) { - c.sendMutex.Lock() - defer c.sendMutex.Unlock() - return c.SendWithSMTPClient(c.smtpClient, messages...) -} func (c *Client) SendWithSMTPClient(client *smtp.Client, messages ...*Msg) (returnErr error) { escSupport := false From 4c107f4645526a2d2e76c88c323363308f48708b Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Fri, 22 Nov 2024 15:55:45 +0100 Subject: [PATCH 07/11] Refactor Client methods to handle smtp.Client parameter Updated several Client methods to accept a smtp.Client pointer, allowing for more flexible and explicit SMTP client management. Added detailed parameter descriptions and extended error handling documentation. --- client.go | 68 +++++++++++++++++++++++++++++++++++++++++++++++++-- client_119.go | 10 ++++---- client_120.go | 20 ++++++++------- 3 files changed, 82 insertions(+), 16 deletions(-) diff --git a/client.go b/client.go index 4ee68f0..662d988 100644 --- a/client.go +++ b/client.go @@ -803,7 +803,7 @@ func (c *Client) SetSSLPort(ssl bool, fallback bool) { } // SetDebugLog sets or overrides whether the Client is using debug logging. The debug logger will log incoming -// and outgoing communication between the Client and the server to os.Stderr. +// and outgoing communication between the client and the server to log.Logger that is defined on the Client. // // Note: The SMTP communication might include unencrypted authentication data, depending on whether you are using // SMTP authentication and the type of authentication mechanism. This could pose a data protection risk. Use @@ -815,6 +815,17 @@ func (c *Client) SetDebugLog(val bool) { c.SetDebugLogWithSMTPClient(c.smtpClient, val) } +// SetDebugLogWithSMTPClient sets or overrides whether the provided smtp.Client is using debug logging. +// The debug logger will log incoming and outgoing communication between the client and the server to +// log.Logger that is defined on the Client. +// +// Note: The SMTP communication might include unencrypted authentication data, depending on whether you are using +// SMTP authentication and the type of authentication mechanism. This could pose a data protection risk. Use +// debug logging with caution. +// +// Parameters: +// - client: A pointer to the smtp.Client that handles the connection to the server. +// - val: A boolean value indicating whether to enable (true) or disable (false) debug logging. func (c *Client) SetDebugLogWithSMTPClient(client *smtp.Client, val bool) { c.mutex.Lock() defer c.mutex.Unlock() @@ -836,6 +847,15 @@ func (c *Client) SetLogger(logger log.Logger) { c.SetLoggerWithSMTPClient(c.smtpClient, logger) } +// SetLoggerWithSMTPClient sets or overrides the custom logger currently used by the provided smtp.Client. +// The logger must satisfy the log.Logger interface and is only utilized when debug logging is enabled on +// the provided smtp.Client. +// +// By default, log.Stdlog is used if no custom logger is provided. +// +// Parameters: +// - client: A pointer to the smtp.Client that handles the connection to the server. +// - logger: A logger that satisfies the log.Logger interface to be set for the Client. func (c *Client) SetLoggerWithSMTPClient(client *smtp.Client, logger log.Logger) { c.mutex.Lock() defer c.mutex.Unlock() @@ -1024,6 +1044,19 @@ func (c *Client) Close() error { return c.CloseWithSMTPClient(c.smtpClient) } +// CloseWithSMTPClient terminates the connection of the provided smtp.Client to the SMTP server, +// returning an error if the disconnection fails. If the connection is already closed, this +// method is a no-op and disregards any error. +// +// This function checks if the smtp.Client connection is active. If not, it simply returns +// without any action. If the connection is active, it attempts to gracefully close the +// connection using the Quit method. +// +// Parameters: +// - client: A pointer to the smtp.Client that handles the connection to the server. +// +// Returns: +// - An error if the disconnection fails; otherwise, returns nil. func (c *Client) CloseWithSMTPClient(client *smtp.Client) error { if client == nil || !client.HasConnection() { return nil @@ -1042,11 +1075,24 @@ func (c *Client) CloseWithSMTPClient(client *smtp.Client) error { // the command fails, an error is returned. // // Returns: -// - An error if the connection check fails or if sending the RSET command fails; otherwise, returns nil. +// - An error if the connection check fails or if sending the RSET command fails; +// otherwise, returns nil. func (c *Client) Reset() error { return c.ResetWithSMTPClient(c.smtpClient) } +// ResetWithSMTPClient sends an SMTP RSET command to the provided smtp.Client, to reset +// the state of the current SMTP session. +// +// This method checks the connection to the SMTP server and, if the connection is valid, +// it sends an RSET command to reset the session state. If the connection is invalid or +// the command fails, an error is returned. +// +// Parameters: +// - client: A pointer to the smtp.Client that handles the connection to the server. +// +// Returns: +// - An error if the connection check fails or if sending the RSET command fails; otherwise, returns nil. func (c *Client) ResetWithSMTPClient(client *smtp.Client) error { if err := c.checkConn(client); err != nil { return err @@ -1108,6 +1154,24 @@ func (c *Client) DialAndSendWithContext(ctx context.Context, messages ...*Msg) e return nil } +// Send attempts to send one or more Msg using the SMTP client that is assigned to the Client. +// If the Client has no active connection to the server, Send will fail with an error. For +// each of the provided Msg, it will associate a SendError with the Msg in case of a +// transmission or delivery error. +// +// This method first checks for an active connection to the SMTP server. If the connection is +// not valid, it returns a SendError. It then iterates over the provided messages, attempting +// to send each one. If an error occurs during sending, the method records the error and +// associates it with the corresponding Msg. If multiple errors are encountered, it aggregates +// them into a single SendError to be returned. +// +// Parameters: +// - client: A pointer to the smtp.Client that holds the connection to the SMTP server +// - messages: A variadic list of pointers to Msg objects to be sent. +// +// Returns: +// - An error that represents the sending result, which may include multiple SendErrors if +// any occurred; otherwise, returns nil. func (c *Client) Send(messages ...*Msg) (returnErr error) { c.sendMutex.Lock() defer c.sendMutex.Unlock() diff --git a/client_119.go b/client_119.go index fb8c982..96d18a7 100644 --- a/client_119.go +++ b/client_119.go @@ -13,10 +13,10 @@ import ( "github.com/wneessen/go-mail/smtp" ) -// Send attempts to send one or more Msg using the Client connection to the SMTP server. -// If the Client has no active connection to the server, Send will fail with an error. For each -// of the provided Msg, it will associate a SendError with the Msg in case of a transmission -// or delivery error. +// SendWithSMTPClient attempts to send one or more Msg using a provided smtp.Client with an +// established connection to the SMTP server. If the smtp.Client has no active connection to +// the server, SendWithSMTPClient will fail with an error. For each of the provided Msg, it +// will associate a SendError with the Msg in case of a transmission or delivery error. // // This method first checks for an active connection to the SMTP server. If the connection is // not valid, it returns a SendError. It then iterates over the provided messages, attempting @@ -25,12 +25,12 @@ import ( // them into a single SendError to be returned. // // Parameters: +// - client: A pointer to the smtp.Client that holds the connection to the SMTP server // - messages: A variadic list of pointers to Msg objects to be sent. // // Returns: // - An error that represents the sending result, which may include multiple SendErrors if // any occurred; otherwise, returns nil. - func (c *Client) SendWithSMTPClient(client *smtp.Client, messages ...*Msg) error { escSupport := false if client != nil { diff --git a/client_120.go b/client_120.go index e9ce8dc..622e149 100644 --- a/client_120.go +++ b/client_120.go @@ -13,22 +13,24 @@ import ( "github.com/wneessen/go-mail/smtp" ) -// Send attempts to send one or more Msg using the Client connection to the SMTP server. -// If the Client has no active connection to the server, Send will fail with an error. For each -// of the provided Msg, it will associate a SendError with the Msg in case of a transmission -// or delivery error. +// SendWithSMTPClient attempts to send one or more Msg using a provided smtp.Client with an +// established connection to the SMTP server. If the smtp.Client has no active connection to +// the server, SendWithSMTPClient will fail with an error. For each of the provided Msg, it +// will associate a SendError with the Msg in case of a transmission or delivery error. // // This method first checks for an active connection to the SMTP server. If the connection is -// not valid, it returns an error wrapped in a SendError. It then iterates over the provided -// messages, attempting to send each one. If an error occurs during sending, the method records -// the error and associates it with the corresponding Msg. +// not valid, it returns a SendError. It then iterates over the provided messages, attempting +// to send each one. If an error occurs during sending, the method records the error and +// associates it with the corresponding Msg. If multiple errors are encountered, it aggregates +// them into a single SendError to be returned. // // Parameters: +// - client: A pointer to the smtp.Client that holds the connection to the SMTP server // - messages: A variadic list of pointers to Msg objects to be sent. // // Returns: -// - An error that aggregates any SendErrors encountered during the sending process; otherwise, returns nil. - +// - An error that represents the sending result, which may include multiple SendErrors if +// any occurred; otherwise, returns nil. func (c *Client) SendWithSMTPClient(client *smtp.Client, messages ...*Msg) (returnErr error) { escSupport := false if client != nil { From 7aba5212c47142eeafdfec1e6bf6e484e7a017c5 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Fri, 22 Nov 2024 16:06:28 +0100 Subject: [PATCH 08/11] Rename method and improve context handling Renamed `SMTPClientFromDialWithContext` to `DialToSMTPClientWithContext` for clarity and consistency. Updated method parameters and documentation to standardize context usage across the codebase. --- client.go | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/client.go b/client.go index 662d988..98aea82 100644 --- a/client.go +++ b/client.go @@ -955,12 +955,12 @@ func (c *Client) SetLogAuthData(logAuth bool) { // SMTP server. // // Parameters: -// - dialCtx: The context.Context used to control the connection timeout and cancellation. +// - ctxDial: The context.Context used to control the connection timeout and cancellation. // // Returns: // - An error if the connection to the SMTP server fails or any subsequent command fails. -func (c *Client) DialWithContext(dialCtx context.Context) error { - client, err := c.SMTPClientFromDialWithContext(dialCtx) +func (c *Client) DialWithContext(ctxDial context.Context) error { + client, err := c.DialToSMTPClientWithContext(ctxDial) if err != nil { return err } @@ -970,9 +970,23 @@ func (c *Client) DialWithContext(dialCtx context.Context) error { return nil } -// SMTPClientFromDialWithContext is similar to DialWithContext but instead of storing the smtp.Client -// on the Client it will return the smtp.Client instead. -func (c *Client) SMTPClientFromDialWithContext(ctxDial context.Context) (*smtp.Client, error) { +// DialToSMTPClientWithContext establishes and configures a smtp.Client connection using +// the provided context. +// +// This function uses the provided context to manage the connection deadline and cancellation. +// It dials the SMTP server using the Client's configured DialContextFunc or a default dialer. +// If SSL is enabled, it uses a TLS connection. After successfully connecting, it initializes +// an smtp.Client, sends the HELO/EHLO command, and optionally performs STARTTLS and SMTP AUTH +// based on the Client's configuration. Debug and authentication logging are enabled if +// configured. +// +// Parameters: +// - ctxDial: The context used to control the connection timeout and cancellation. +// +// Returns: +// - A pointer to the initialized smtp.Client. +// - An error if the connection fails, the smtp.Client cannot be created, or any subsequent commands fail. +func (c *Client) DialToSMTPClientWithContext(ctxDial context.Context) (*smtp.Client, error) { c.mutex.RLock() defer c.mutex.RUnlock() @@ -1137,7 +1151,7 @@ func (c *Client) DialAndSend(messages ...*Msg) error { // - An error if the connection fails, if sending the messages fails, or if closing the // connection fails; otherwise, returns nil. func (c *Client) DialAndSendWithContext(ctx context.Context, messages ...*Msg) error { - client, err := c.SMTPClientFromDialWithContext(ctx) + client, err := c.DialToSMTPClientWithContext(ctx) if err != nil { return fmt.Errorf("dial failed: %w", err) } From b4d3b165edd1ab8069637d7916c6d915edac93aa Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Fri, 22 Nov 2024 16:21:49 +0100 Subject: [PATCH 09/11] Add test for dialing SMTP client with context Introduce a new test, `TestClient_DialToSMTPClientWithContext`, to verify client behavior when establishing and closing an SMTP connection with context handling. Also added a sub-test to simulate and confirm failure scenarios during SMTP connection establishment. --- client_test.go | 77 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/client_test.go b/client_test.go index 36fb423..10dd898 100644 --- a/client_test.go +++ b/client_test.go @@ -2742,6 +2742,83 @@ func TestClient_Send(t *testing.T) { }) } +func TestClient_DialToSMTPClientWithContext(t *testing.T) { + t.Run("establish a new client connection", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + ctxDial, cancelDial := context.WithTimeout(ctx, time.Millisecond*500) + t.Cleanup(cancelDial) + + client, err := NewClient(DefaultHost, WithPort(serverPort), WithTLSPolicy(NoTLS)) + if err != nil { + t.Fatalf("failed to create new client: %s", err) + } + smtpClient, err := client.DialToSMTPClientWithContext(ctxDial) + if err != nil { + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + t.Skip("failed to connect to the test server due to timeout") + } + t.Fatalf("failed to connect to test server: %s", err) + } + t.Cleanup(func() { + if err := client.CloseWithSMTPClient(smtpClient); err != nil { + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + t.Skip("failed to close the test server connection due to timeout") + } + t.Errorf("failed to close client: %s", err) + } + }) + if smtpClient == nil { + t.Fatal("expected SMTP client, got nil") + } + if !smtpClient.HasConnection() { + t.Fatal("expected connection on smtp client") + } + if ok, _ := smtpClient.Extension("DSN"); !ok { + t.Error("expected DSN extension but it was not found") + } + }) + t.Run("dial to SMTP server fails on first client writeFile", func(t *testing.T) { + var fake faker + fake.ReadWriter = struct { + io.Reader + io.Writer + }{ + failReadWriteSeekCloser{}, + failReadWriteSeekCloser{}, + } + + ctxDial, cancelDial := context.WithTimeout(context.Background(), time.Millisecond*500) + t.Cleanup(cancelDial) + + client, err := NewClient(DefaultHost, WithDialContextFunc(getFakeDialFunc(fake))) + if err != nil { + t.Fatalf("failed to create new client: %s", err) + } + _, err = client.DialToSMTPClientWithContext(ctxDial) + if err == nil { + t.Fatal("expected connection to fake to fail") + } + }) + +} + func TestClient_sendSingleMsg(t *testing.T) { t.Run("connect and send email", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) From c61aad4fcb8cfbbfa14c4b9225d5c93f7fb07464 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Fri, 22 Nov 2024 16:22:38 +0100 Subject: [PATCH 10/11] Remove extra blank line in client_test.go This change removes an unnecessary blank line at the end of the TestClient_sendSingleMsg function in the client_test.go file. Keeping the code clean and properly formatted improves readability and maintainability. --- client_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/client_test.go b/client_test.go index 10dd898..195a797 100644 --- a/client_test.go +++ b/client_test.go @@ -2816,7 +2816,6 @@ func TestClient_DialToSMTPClientWithContext(t *testing.T) { t.Fatal("expected connection to fake to fail") } }) - } func TestClient_sendSingleMsg(t *testing.T) { From 6ebc60d1df74e2fb764f6dec0221ee5152cb4f8a Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Fri, 22 Nov 2024 16:22:53 +0100 Subject: [PATCH 11/11] Remove redundant nil check for SMTP client The removed check for a nil SMTP client was redundant because the previous error handling already covers this case. Streamlining this part of the code improves readability and reduces unnecessary checks. --- client.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/client.go b/client.go index 98aea82..028e277 100644 --- a/client.go +++ b/client.go @@ -1017,9 +1017,6 @@ func (c *Client) DialToSMTPClientWithContext(ctxDial context.Context) (*smtp.Cli if err != nil { return nil, err } - if client == nil { - return nil, fmt.Errorf("SMTP client is nil") - } if c.logger != nil { client.SetLogger(c.logger)