From e9c7bdbb4e963c5e489a020a77f7e6c2e64b8abc Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Mon, 11 Nov 2024 12:54:10 +0100 Subject: [PATCH] Refactor TLS config initialization in tests Replace repetitive TLS configuration code with a reusable `getTLSConfig` helper function for consistency and maintainability. Additionally, update port configuration and add new tests for mail data transmission. --- smtp/smtp_test.go | 120 ++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 100 insertions(+), 20 deletions(-) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index 66d3ee5..8613dd1 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -22,6 +22,7 @@ import ( "crypto/sha1" "crypto/sha256" "crypto/tls" + "crypto/x509" "encoding/base64" "errors" "fmt" @@ -42,7 +43,7 @@ const ( // TestServerAddr is the address the simple SMTP test server listens on TestServerAddr = "127.0.0.1" // TestServerPortBase is the base port for the simple SMTP test server - TestServerPortBase = 12025 + TestServerPortBase = 30025 ) // PortAdder is an atomic counter used to increment port numbers for the test SMTP server instances. @@ -1087,13 +1088,8 @@ func TestScramAuth(t *testing.T) { var client *Client switch tt.tls { case true: - cert, err := tls.X509KeyPair(localhostCert, localhostKey) - if err != nil { - fmt.Printf("error creating TLS cert: %s", err) - return - } - tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true} - conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", TestServerAddr, serverPort), &tlsConfig) + tlsConfig := getTLSConfig(t) + conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", TestServerAddr, serverPort), tlsConfig) if err != nil { t.Fatalf("failed to dial TLS server: %v", err) } @@ -1161,13 +1157,8 @@ func TestScramAuth(t *testing.T) { var client *Client switch tt.tls { case true: - cert, err := tls.X509KeyPair(localhostCert, localhostKey) - if err != nil { - fmt.Printf("error creating TLS cert: %s", err) - return - } - tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true} - conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", TestServerAddr, serverPort), &tlsConfig) + tlsConfig := getTLSConfig(t) + conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", TestServerAddr, serverPort), tlsConfig) if err != nil { t.Fatalf("failed to dial TLS server: %v", err) } @@ -1719,7 +1710,7 @@ func TestClient_StartTLS(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - tlsConfig := &tls.Config{InsecureSkipVerify: true} + tlsConfig := getTLSConfig(t) if err = client.StartTLS(tlsConfig); err != nil { t.Errorf("failed to initialize STARTTLS session: %s", err) } @@ -1753,7 +1744,7 @@ func TestClient_StartTLS(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - tlsConfig := &tls.Config{InsecureSkipVerify: true} + tlsConfig := getTLSConfig(t) if err = client.StartTLS(tlsConfig); err == nil { t.Error("STARTTLS should fail on EHLO") } @@ -1786,7 +1777,7 @@ func TestClient_StartTLS(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - tlsConfig := &tls.Config{InsecureSkipVerify: true} + tlsConfig := getTLSConfig(t) if err = client.StartTLS(tlsConfig); err == nil { t.Error("STARTTLS should fail for server not supporting it") } @@ -1821,7 +1812,8 @@ func TestClient_TLSConnectionState(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - tlsConfig := &tls.Config{InsecureSkipVerify: true, MinVersion: tls.VersionTLS12} + tlsConfig := getTLSConfig(t) + tlsConfig.MinVersion = tls.VersionTLS12 if err = client.StartTLS(tlsConfig); err != nil { t.Errorf("failed to initialize STARTTLS session: %s", err) } @@ -2472,6 +2464,79 @@ func TestClient_Rcpt(t *testing.T) { }) } +func TestClient_Data(t *testing.T) { + t.Run("normal mail data transmission succeeds", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + writer, err := client.Data() + if err != nil { + t.Fatalf("failed to create data writer: %s", err) + } + t.Cleanup(func() { + if err = writer.Close(); err != nil { + t.Errorf("failed to close data writer: %s", err) + } + }) + if _, err = writer.Write([]byte("test message")); err != nil { + t.Errorf("failed to write data to test server: %s", err) + } + }) + t.Run("mail data transmission fails on DATA command", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnDataInit: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if _, err = client.Data(); err == nil { + t.Error("expected data writer to fail") + } + }) +} + /* func TestBasic(t *testing.T) { server := strings.Join(strings.Split(basicServer, "\n"), "\r\n") @@ -4251,7 +4316,7 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server break } writeLine("220 Ready to start TLS") - tlsConfig := &tls.Config{Certificates: []tls.Certificate{keypair}} + tlsConfig := &tls.Config{Certificates: []tls.Certificate{keypair}, ServerName: "example.com"} connection = tls.Server(connection, tlsConfig) props.IsTLS = true handleTestServerConnection(connection, t, props) @@ -4444,3 +4509,18 @@ type failWriter struct{} func (w *failWriter) Write([]byte) (int, error) { return 0, errors.New("broken writer") } + +func getTLSConfig(t *testing.T) *tls.Config { + t.Helper() + cert, err := tls.X509KeyPair(localhostCert, localhostKey) + if err != nil { + t.Fatalf("unable to load host certifcate: %s", err) + } + testRootCAs := x509.NewCertPool() + testRootCAs.AppendCertsFromPEM(localhostCert) + return &tls.Config{ + Certificates: []tls.Certificate{cert}, + RootCAs: testRootCAs, + ServerName: "example.com", + } +}