diff --git a/client_test.go b/client_test.go index c009acf..41ed1e6 100644 --- a/client_test.go +++ b/client_test.go @@ -1647,6 +1647,23 @@ func TestClient_DialWithContext(t *testing.T) { t.Fatalf("client has no connection") } }) + t.Run("fail on invalid host", func(t *testing.T) { + ctxDial, cancelDial := context.WithTimeout(ctx, time.Millisecond*500) + t.Cleanup(cancelDial) + + client, err := NewClient(DefaultHost, WithPort(serverPort), WithTLSPolicy(NoTLS)) + if err != nil { + t.Fatalf("failed to create new client: %s", err) + } + client.host = "invalid.addr" + + if err = client.DialWithContext(ctxDial); err == nil { + t.Errorf("client with invalid host should fail") + } + if client.smtpClient != nil { + t.Errorf("client with invalid host should not have a smtp client") + } + }) t.Run("fail on base port and fallback", func(t *testing.T) { ctxDial, cancelDial := context.WithTimeout(ctx, time.Millisecond*500) t.Cleanup(cancelDial) @@ -1701,119 +1718,48 @@ func TestClient_DialWithContext(t *testing.T) { t.Errorf("logAuthData not working, no authentication info found in logs") } }) + t.Run("connect should fail on HELO", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + failServerPort := int(TestServerPortBase + PortAdder.Load()) + failFeatureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, &serverProps{ + FailOnHelo: true, + FeatureSet: failFeatureSet, + ListenPort: failServerPort, + }); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 300) + + ctxDial, cancelDial := context.WithTimeout(ctx, time.Millisecond*500) + t.Cleanup(cancelDial) + + client, err := NewClient(DefaultHost, WithPort(failServerPort), WithTLSPolicy(NoTLS)) + if err != nil { + t.Fatalf("failed to create new client: %s", err) + } + if err = client.DialWithContext(ctxDial); err == nil { + t.Fatalf("connection was supposed to fail, but didn't") + } + if client.smtpClient == nil { + t.Fatalf("client has no smtp client") + } + if !client.smtpClient.HasConnection() { + t.Errorf("client has no connection") + } + }) + // TODO: Implement tests for TLS/SSL and custom DialCtxFunc } /* -// TestClient_DialWithContext tests the DialWithContext method for the Client object -func TestClient_DialWithContext(t *testing.T) { - c, err := getTestConnection(true) - if err != nil { - t.Skipf("failed to create test client: %s. Skipping tests", err) - } - ctx := context.Background() - if err = c.DialWithContext(ctx); err != nil { - t.Errorf("failed to dial with context: %s", err) - return - } - if c.smtpClient == nil { - t.Errorf("DialWithContext didn't fail but no SMTP client found.") - return - } - if !c.smtpClient.HasConnection() { - t.Errorf("DialWithContext didn't fail but no connection found.") - } - if err := c.Close(); err != nil { - t.Errorf("failed to close connection: %s", err) - } -} - -// TestClient_DialWithContext_Fallback tests the Client.DialWithContext method with the fallback -// port functionality -func TestClient_DialWithContext_Fallback(t *testing.T) { - c, err := getTestConnectionNoTestPort(true) - if err != nil { - t.Skipf("failed to create test client: %s. Skipping tests", err) - } - c.SetTLSPortPolicy(TLSOpportunistic) - c.port = 999 - ctx := context.Background() - if err = c.DialWithContext(ctx); err != nil { - t.Errorf("failed to dial with context: %s", err) - return - } - if c.smtpClient == nil { - t.Errorf("DialWithContext didn't fail but no SMTP client found.") - return - } - if !c.smtpClient.HasConnection() { - t.Errorf("DialWithContext didn't fail but no connection found.") - } - if err = c.Close(); err != nil { - t.Errorf("failed to close connection: %s", err) - } - - c.port = 999 - c.fallbackPort = 999 - if err = c.DialWithContext(ctx); err == nil { - t.Error("dial with context was supposed to fail, but didn't") - return - } -} - -// TestClient_DialWithContext_Debug tests the DialWithContext method for the Client object with debug -// logging enabled on the SMTP client -func TestClient_DialWithContext_Debug(t *testing.T) { - c, err := getTestClient(true) - if err != nil { - t.Skipf("failed to create test client: %s. Skipping tests", err) - } - ctx := context.Background() - if err = c.DialWithContext(ctx); err != nil { - t.Errorf("failed to dial with context: %s", err) - return - } - if c.smtpClient == nil { - t.Errorf("DialWithContext didn't fail but no SMTP client found.") - return - } - if !c.smtpClient.HasConnection() { - t.Errorf("DialWithContext didn't fail but no connection found.") - } - c.SetDebugLog(true) - if err = c.Close(); err != nil { - t.Errorf("failed to close connection: %s", err) - } -} - -// TestClient_DialWithContext_Debug_custom tests the DialWithContext method for the Client -// object with debug logging enabled and a custom logger on the SMTP client -func TestClient_DialWithContext_Debug_custom(t *testing.T) { - c, err := getTestClient(true) - if err != nil { - t.Skipf("failed to create test client: %s. Skipping tests", err) - } - ctx := context.Background() - if err = c.DialWithContext(ctx); err != nil { - t.Errorf("failed to dial with context: %s", err) - return - } - if c.smtpClient == nil { - t.Errorf("DialWithContext didn't fail but no SMTP client found.") - return - } - if !c.smtpClient.HasConnection() { - t.Errorf("DialWithContext didn't fail but no connection found.") - } - c.SetDebugLog(true) - c.SetLogger(log.New(os.Stderr, log.LevelDebug)) - if err = c.Close(); err != nil { - t.Errorf("failed to close connection: %s", err) - } -} - // TestClient_DialWithContextInvalidHost tests the DialWithContext method with intentional breaking // for the Client object func TestClient_DialWithContextInvalidHost(t *testing.T) { @@ -3685,8 +3631,9 @@ func parseJSONLog(t *testing.T, buf *bytes.Buffer) logData { } type serverProps struct { - FailOnReset bool + FailOnHelo bool FailOnQuit bool + FailOnReset bool FeatureSet string ListenPort int } @@ -3762,6 +3709,10 @@ func handleTestServerConnection(connection net.Conn, props *serverProps) { fmt.Printf("expected EHLO, got %q", data) return } + if props.FailOnHelo { + _ = writeLine("500 5.5.2 Error: fail on HELO") + return + } if err = writeLine("250-localhost.localdomain\r\n" + props.FeatureSet); err != nil { return }