From 572751ac10678c1ce09b4a10e5f496dbcfe1181a Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Wed, 23 Oct 2024 18:20:52 +0200 Subject: [PATCH] Add test for invalid HELO handling in SMTP client Introduce a new test case to ensure the SMTP client fails gracefully when an invalid HELO command is used. This includes validating error handling and maintaining the client connection integrity. Also, optimize EHLO/HELO command handling by enhancing syntax checking and error response generation. --- client_test.go | 69 +++++++++++++++++++++++++++----------------------- 1 file changed, 38 insertions(+), 31 deletions(-) diff --git a/client_test.go b/client_test.go index 41ed1e6..ff48c23 100644 --- a/client_test.go +++ b/client_test.go @@ -1664,6 +1664,31 @@ func TestClient_DialWithContext(t *testing.T) { t.Errorf("client with invalid host should not have a smtp client") } }) + t.Run("fail on invalid HELO", func(t *testing.T) { + ctxDial, cancelDial := context.WithTimeout(ctx, time.Millisecond*500) + t.Cleanup(cancelDial) + + client, err := NewClient(DefaultHost, WithPort(serverPort), WithTLSPolicy(NoTLS)) + if err != nil { + t.Fatalf("failed to create new client: %s", err) + } + client.helo = "" + + if err = client.DialWithContext(ctxDial); err == nil { + t.Errorf("client with invalid HELO should fail") + } + t.Cleanup(func() { + if err := client.Close(); err != nil { + t.Errorf("failed to close the client: %s", err) + } + }) + if client.smtpClient == nil { + t.Errorf("client with invalid HELO should still have a smtp client, got nil") + } + if !client.smtpClient.HasConnection() { + t.Errorf("client with invalid HELO should still have a smtp client connection, got nil") + } + }) t.Run("fail on base port and fallback", func(t *testing.T) { ctxDial, cancelDial := context.WithTimeout(ctx, time.Millisecond*500) t.Cleanup(cancelDial) @@ -1760,20 +1785,6 @@ func TestClient_DialWithContext(t *testing.T) { -// TestClient_DialWithContextInvalidHost tests the DialWithContext method with intentional breaking -// for the Client object -func TestClient_DialWithContextInvalidHost(t *testing.T) { - c, err := getTestConnection(true) - if err != nil { - t.Skipf("failed to create test client: %s. Skipping tests", err) - } - c.host = "invalid.addr" - ctx := context.Background() - if err = c.DialWithContext(ctx); err == nil { - t.Errorf("dial succeeded but was supposed to fail") - return - } -} // TestClient_DialWithContextInvalidHELO tests the DialWithContext method with intentional breaking // for the Client object @@ -3701,24 +3712,8 @@ func handleTestServerConnection(connection net.Conn, props *serverProps) { return } - data, err := reader.ReadString('\n') - if err != nil { - return - } - if !strings.HasPrefix(data, "EHLO") && !strings.HasPrefix(data, "HELO") { - fmt.Printf("expected EHLO, got %q", data) - return - } - if props.FailOnHelo { - _ = writeLine("500 5.5.2 Error: fail on HELO") - return - } - if err = writeLine("250-localhost.localdomain\r\n" + props.FeatureSet); err != nil { - return - } - for { - data, err = reader.ReadString('\n') + data, err := reader.ReadString('\n') if err != nil { break } @@ -3727,6 +3722,18 @@ func handleTestServerConnection(connection net.Conn, props *serverProps) { var datastring string data = strings.TrimSpace(data) switch { + case strings.HasPrefix(data, "EHLO"), strings.HasPrefix(data, "HELO"): + if len(strings.Split(data, " ")) != 2 { + _ = writeLine("501 Syntax: EHLO hostname") + return + } + if props.FailOnHelo { + _ = writeLine("500 5.5.2 Error: fail on HELO") + return + } + if err = writeLine("250-localhost.localdomain\r\n" + props.FeatureSet); err != nil { + return + } case strings.HasPrefix(data, "MAIL FROM:"): from := strings.TrimPrefix(data, "MAIL FROM:") from = strings.ReplaceAll(from, "BODY=8BITMIME", "")