From 45776c052f417971339f6dc7aa88ecca11d78ab8 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Fri, 22 Nov 2024 14:58:05 +0100 Subject: [PATCH] Refactor client error handling and add concurrent send tests Updated tests to correctly assert the absence of an SMTP client on failure. Added concurrent sending tests for DialAndSendWithContext to improve test coverage and reliability. Also, refined the `AutoDiscover` and other client methods to ensure proper parameter use. --- client_test.go | 112 +++++++++++++++++++++++++++++++++---------------- 1 file changed, 76 insertions(+), 36 deletions(-) diff --git a/client_test.go b/client_test.go index f6d5161..36fb423 100644 --- a/client_test.go +++ b/client_test.go @@ -1772,11 +1772,8 @@ func TestClient_DialWithContext(t *testing.T) { t.Errorf("failed to close the client: %s", err) } }) - if client.smtpClient == nil { - t.Errorf("client with invalid HELO should still have a smtp client, got nil") - } - if !client.smtpClient.HasConnection() { - t.Errorf("client with invalid HELO should still have a smtp client connection, got nil") + if client.smtpClient != nil { + t.Error("client with invalid HELO should not have a smtp client") } }) t.Run("fail on base port and fallback", func(t *testing.T) { @@ -1825,11 +1822,8 @@ func TestClient_DialWithContext(t *testing.T) { if err = client.DialWithContext(ctxDial); err == nil { t.Fatalf("connection was supposed to fail, but didn't") } - if client.smtpClient == nil { - t.Fatalf("client has no smtp client") - } - if !client.smtpClient.HasConnection() { - t.Errorf("client has no connection") + if client.smtpClient != nil { + t.Fatalf("client is not supposed to have a smtp client") } }) t.Run("connect with failing auth", func(t *testing.T) { @@ -2297,6 +2291,7 @@ func TestClient_DialAndSendWithContext(t *testing.T) { t.Errorf("client was supposed to fail on dial") } }) + // https://github.com/wneessen/go-mail/issues/380 t.Run("concurrent sending via DialAndSendWithContext", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -2336,6 +2331,44 @@ func TestClient_DialAndSendWithContext(t *testing.T) { } wg.Wait() }) + // https://github.com/wneessen/go-mail/issues/385 + t.Run("concurrent sending via DialAndSendWithContext on receiver func", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := NewClient(DefaultHost, WithPort(serverPort), WithTLSPolicy(NoTLS)) + if err != nil { + t.Fatalf("failed to create new client: %s", err) + } + sender := testSender{client} + + ctxDial := context.Background() + wg := sync.WaitGroup{} + for i := 0; i < 5; i++ { + wg.Add(1) + msg := testMessage(t) + go func() { + defer wg.Done() + if goroutineErr := sender.Send(ctxDial, msg); goroutineErr != nil { + t.Errorf("failed to send message: %s", goroutineErr) + } + }() + } + wg.Wait() + }) } func TestClient_auth(t *testing.T) { @@ -2574,8 +2607,8 @@ func TestClient_authTypeAutoDiscover(t *testing.T) { } for _, tt := range tests { t.Run("AutoDiscover selects the strongest auth type: "+string(tt.expect), func(t *testing.T) { - client := &Client{smtpAuthType: SMTPAuthAutoDiscover, isEncrypted: tt.tls} - authType, err := client.authTypeAutoDiscover(tt.supported) + client := &Client{smtpAuthType: SMTPAuthAutoDiscover} + authType, err := client.authTypeAutoDiscover(tt.supported, tt.tls) if err != nil && !tt.shouldFail { t.Fatalf("failed to auto discover auth type: %s", err) } @@ -2748,7 +2781,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err != nil { + if err = client.sendSingleMsg(client.smtpClient, message); err != nil { t.Errorf("failed to send message: %s", err) } }) @@ -2791,7 +2824,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err == nil { + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { t.Errorf("client should have failed to send message") } }) @@ -2836,7 +2869,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err == nil { + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { t.Errorf("client should have failed to send message") } var sendErr *SendError @@ -2886,7 +2919,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err == nil { + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { t.Errorf("client should have failed to send message") } var sendErr *SendError @@ -2936,7 +2969,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err == nil { + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { t.Errorf("client should have failed to send message") } var sendErr *SendError @@ -2986,7 +3019,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err != nil { + if err = client.sendSingleMsg(client.smtpClient, message); err != nil { t.Errorf("failed to send message: %s", err) } }) @@ -3029,7 +3062,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err == nil { + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { t.Errorf("client should have failed to send message") } var sendErr *SendError @@ -3080,7 +3113,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err == nil { + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { t.Errorf("client should have failed to send message") } var sendErr *SendError @@ -3131,7 +3164,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err == nil { + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { t.Errorf("client should have failed to send message") } var sendErr *SendError @@ -3181,7 +3214,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err == nil { + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { t.Errorf("client should have failed to send message") } var sendErr *SendError @@ -3231,7 +3264,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err == nil { + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { t.Errorf("client should have failed to send message") } var sendErr *SendError @@ -3281,7 +3314,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err == nil { + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { t.Error("expected mail delivery to fail") } var sendErr *SendError @@ -3334,7 +3367,7 @@ func TestClient_checkConn(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.checkConn(); err != nil { + if err = client.checkConn(client.smtpClient); err != nil { t.Errorf("failed to check connection: %s", err) } }) @@ -3375,7 +3408,7 @@ func TestClient_checkConn(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.checkConn(); err == nil { + if err = client.checkConn(client.smtpClient); err == nil { t.Errorf("client should have failed on connection check") } if !errors.Is(err, ErrNoActiveConnection) { @@ -3387,7 +3420,7 @@ func TestClient_checkConn(t *testing.T) { if err != nil { t.Fatalf("failed to create new client: %s", err) } - if err = client.checkConn(); err == nil { + if err = client.checkConn(client.smtpClient); err == nil { t.Errorf("client should have failed on connection check") } if !errors.Is(err, ErrNoActiveConnection) { @@ -3611,24 +3644,20 @@ func TestClient_XOAuth2OnFaker(t *testing.T) { } if err = c.DialWithContext(context.Background()); err == nil { t.Fatal("expected dial error got nil") - } else { - if !errors.Is(err, ErrXOauth2AuthNotSupported) { - t.Fatalf("expected %v; got %v", ErrXOauth2AuthNotSupported, err) - } + } + if !errors.Is(err, ErrXOauth2AuthNotSupported) { + t.Fatalf("expected %v; got %v", ErrXOauth2AuthNotSupported, err) } if err = c.Close(); err != nil { t.Fatalf("disconnect from test server failed: %v", err) } client := strings.Split(wrote.String(), "\r\n") - if len(client) != 3 { - t.Fatalf("unexpected number of client requests got %d; want 3", len(client)) + if len(client) != 2 { + t.Fatalf("unexpected number of client requests got %d; want 2", len(client)) } if !strings.HasPrefix(client[0], "EHLO") { t.Fatalf("expected EHLO, got %q", client[0]) } - if client[1] != "QUIT" { - t.Fatalf("expected QUIT, got %q", client[3]) - } }) } @@ -3652,6 +3681,17 @@ func (f faker) SetDeadline(time.Time) error { return nil } func (f faker) SetReadDeadline(time.Time) error { return nil } func (f faker) SetWriteDeadline(time.Time) error { return nil } +type testSender struct { + client *Client +} + +func (t *testSender) Send(ctx context.Context, m *Msg) error { + if err := t.client.DialAndSendWithContext(ctx, m); err != nil { + return fmt.Errorf("failed to dial and send mail: %w", err) + } + return nil +} + // parseJSONLog parses a JSON encoded log from the provided buffer and returns a slice of logLine structs. // In case of a decode error, it reports the error to the testing framework. func parseJSONLog(t *testing.T, buf *bytes.Buffer) logData {