diff --git a/client_test.go b/client_test.go index 41ed1e6..ff48c23 100644 --- a/client_test.go +++ b/client_test.go @@ -1664,6 +1664,31 @@ func TestClient_DialWithContext(t *testing.T) { t.Errorf("client with invalid host should not have a smtp client") } }) + t.Run("fail on invalid HELO", func(t *testing.T) { + ctxDial, cancelDial := context.WithTimeout(ctx, time.Millisecond*500) + t.Cleanup(cancelDial) + + client, err := NewClient(DefaultHost, WithPort(serverPort), WithTLSPolicy(NoTLS)) + if err != nil { + t.Fatalf("failed to create new client: %s", err) + } + client.helo = "" + + if err = client.DialWithContext(ctxDial); err == nil { + t.Errorf("client with invalid HELO should fail") + } + t.Cleanup(func() { + if err := client.Close(); err != nil { + t.Errorf("failed to close the client: %s", err) + } + }) + if client.smtpClient == nil { + t.Errorf("client with invalid HELO should still have a smtp client, got nil") + } + if !client.smtpClient.HasConnection() { + t.Errorf("client with invalid HELO should still have a smtp client connection, got nil") + } + }) t.Run("fail on base port and fallback", func(t *testing.T) { ctxDial, cancelDial := context.WithTimeout(ctx, time.Millisecond*500) t.Cleanup(cancelDial) @@ -1760,20 +1785,6 @@ func TestClient_DialWithContext(t *testing.T) { -// TestClient_DialWithContextInvalidHost tests the DialWithContext method with intentional breaking -// for the Client object -func TestClient_DialWithContextInvalidHost(t *testing.T) { - c, err := getTestConnection(true) - if err != nil { - t.Skipf("failed to create test client: %s. Skipping tests", err) - } - c.host = "invalid.addr" - ctx := context.Background() - if err = c.DialWithContext(ctx); err == nil { - t.Errorf("dial succeeded but was supposed to fail") - return - } -} // TestClient_DialWithContextInvalidHELO tests the DialWithContext method with intentional breaking // for the Client object @@ -3701,24 +3712,8 @@ func handleTestServerConnection(connection net.Conn, props *serverProps) { return } - data, err := reader.ReadString('\n') - if err != nil { - return - } - if !strings.HasPrefix(data, "EHLO") && !strings.HasPrefix(data, "HELO") { - fmt.Printf("expected EHLO, got %q", data) - return - } - if props.FailOnHelo { - _ = writeLine("500 5.5.2 Error: fail on HELO") - return - } - if err = writeLine("250-localhost.localdomain\r\n" + props.FeatureSet); err != nil { - return - } - for { - data, err = reader.ReadString('\n') + data, err := reader.ReadString('\n') if err != nil { break } @@ -3727,6 +3722,18 @@ func handleTestServerConnection(connection net.Conn, props *serverProps) { var datastring string data = strings.TrimSpace(data) switch { + case strings.HasPrefix(data, "EHLO"), strings.HasPrefix(data, "HELO"): + if len(strings.Split(data, " ")) != 2 { + _ = writeLine("501 Syntax: EHLO hostname") + return + } + if props.FailOnHelo { + _ = writeLine("500 5.5.2 Error: fail on HELO") + return + } + if err = writeLine("250-localhost.localdomain\r\n" + props.FeatureSet); err != nil { + return + } case strings.HasPrefix(data, "MAIL FROM:"): from := strings.TrimPrefix(data, "MAIL FROM:") from = strings.ReplaceAll(from, "BODY=8BITMIME", "")