Refactor TLS config initialization in tests

Replace repetitive TLS configuration code with a reusable `getTLSConfig` helper function for consistency and maintainability. Additionally, update port configuration and add new tests for mail data transmission.
This commit is contained in:
Winni Neessen 2024-11-11 12:54:10 +01:00
parent 75bfdd2855
commit e9c7bdbb4e
Signed by: wneessen
GPG key ID: 385AC9889632126E

View file

@ -22,6 +22,7 @@ import (
"crypto/sha1" "crypto/sha1"
"crypto/sha256" "crypto/sha256"
"crypto/tls" "crypto/tls"
"crypto/x509"
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt" "fmt"
@ -42,7 +43,7 @@ const (
// TestServerAddr is the address the simple SMTP test server listens on // TestServerAddr is the address the simple SMTP test server listens on
TestServerAddr = "127.0.0.1" TestServerAddr = "127.0.0.1"
// TestServerPortBase is the base port for the simple SMTP test server // TestServerPortBase is the base port for the simple SMTP test server
TestServerPortBase = 12025 TestServerPortBase = 30025
) )
// PortAdder is an atomic counter used to increment port numbers for the test SMTP server instances. // PortAdder is an atomic counter used to increment port numbers for the test SMTP server instances.
@ -1087,13 +1088,8 @@ func TestScramAuth(t *testing.T) {
var client *Client var client *Client
switch tt.tls { switch tt.tls {
case true: case true:
cert, err := tls.X509KeyPair(localhostCert, localhostKey) tlsConfig := getTLSConfig(t)
if err != nil { conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", TestServerAddr, serverPort), tlsConfig)
fmt.Printf("error creating TLS cert: %s", err)
return
}
tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true}
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", TestServerAddr, serverPort), &tlsConfig)
if err != nil { if err != nil {
t.Fatalf("failed to dial TLS server: %v", err) t.Fatalf("failed to dial TLS server: %v", err)
} }
@ -1161,13 +1157,8 @@ func TestScramAuth(t *testing.T) {
var client *Client var client *Client
switch tt.tls { switch tt.tls {
case true: case true:
cert, err := tls.X509KeyPair(localhostCert, localhostKey) tlsConfig := getTLSConfig(t)
if err != nil { conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", TestServerAddr, serverPort), tlsConfig)
fmt.Printf("error creating TLS cert: %s", err)
return
}
tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true}
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", TestServerAddr, serverPort), &tlsConfig)
if err != nil { if err != nil {
t.Fatalf("failed to dial TLS server: %v", err) t.Fatalf("failed to dial TLS server: %v", err)
} }
@ -1719,7 +1710,7 @@ func TestClient_StartTLS(t *testing.T) {
t.Errorf("failed to close client: %s", err) t.Errorf("failed to close client: %s", err)
} }
}) })
tlsConfig := &tls.Config{InsecureSkipVerify: true} tlsConfig := getTLSConfig(t)
if err = client.StartTLS(tlsConfig); err != nil { if err = client.StartTLS(tlsConfig); err != nil {
t.Errorf("failed to initialize STARTTLS session: %s", err) t.Errorf("failed to initialize STARTTLS session: %s", err)
} }
@ -1753,7 +1744,7 @@ func TestClient_StartTLS(t *testing.T) {
t.Errorf("failed to close client: %s", err) t.Errorf("failed to close client: %s", err)
} }
}) })
tlsConfig := &tls.Config{InsecureSkipVerify: true} tlsConfig := getTLSConfig(t)
if err = client.StartTLS(tlsConfig); err == nil { if err = client.StartTLS(tlsConfig); err == nil {
t.Error("STARTTLS should fail on EHLO") t.Error("STARTTLS should fail on EHLO")
} }
@ -1786,7 +1777,7 @@ func TestClient_StartTLS(t *testing.T) {
t.Errorf("failed to close client: %s", err) t.Errorf("failed to close client: %s", err)
} }
}) })
tlsConfig := &tls.Config{InsecureSkipVerify: true} tlsConfig := getTLSConfig(t)
if err = client.StartTLS(tlsConfig); err == nil { if err = client.StartTLS(tlsConfig); err == nil {
t.Error("STARTTLS should fail for server not supporting it") t.Error("STARTTLS should fail for server not supporting it")
} }
@ -1821,7 +1812,8 @@ func TestClient_TLSConnectionState(t *testing.T) {
t.Errorf("failed to close client: %s", err) t.Errorf("failed to close client: %s", err)
} }
}) })
tlsConfig := &tls.Config{InsecureSkipVerify: true, MinVersion: tls.VersionTLS12} tlsConfig := getTLSConfig(t)
tlsConfig.MinVersion = tls.VersionTLS12
if err = client.StartTLS(tlsConfig); err != nil { if err = client.StartTLS(tlsConfig); err != nil {
t.Errorf("failed to initialize STARTTLS session: %s", err) t.Errorf("failed to initialize STARTTLS session: %s", err)
} }
@ -2472,6 +2464,79 @@ func TestClient_Rcpt(t *testing.T) {
}) })
} }
func TestClient_Data(t *testing.T) {
t.Run("normal mail data transmission succeeds", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
PortAdder.Add(1)
serverPort := int(TestServerPortBase + PortAdder.Load())
featureSet := "250-DSN\r\n250 STARTTLS"
go func() {
if err := simpleSMTPServer(ctx, t, &serverProps{
FeatureSet: featureSet,
ListenPort: serverPort,
},
); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 30)
client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort))
if err != nil {
t.Fatalf("failed to dial to test server: %s", err)
}
t.Cleanup(func() {
if err = client.Close(); err != nil {
t.Errorf("failed to close client: %s", err)
}
})
writer, err := client.Data()
if err != nil {
t.Fatalf("failed to create data writer: %s", err)
}
t.Cleanup(func() {
if err = writer.Close(); err != nil {
t.Errorf("failed to close data writer: %s", err)
}
})
if _, err = writer.Write([]byte("test message")); err != nil {
t.Errorf("failed to write data to test server: %s", err)
}
})
t.Run("mail data transmission fails on DATA command", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
PortAdder.Add(1)
serverPort := int(TestServerPortBase + PortAdder.Load())
featureSet := "250-DSN\r\n250 STARTTLS"
go func() {
if err := simpleSMTPServer(ctx, t, &serverProps{
FailOnDataInit: true,
FeatureSet: featureSet,
ListenPort: serverPort,
},
); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 30)
client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort))
if err != nil {
t.Fatalf("failed to dial to test server: %s", err)
}
t.Cleanup(func() {
if err = client.Close(); err != nil {
t.Errorf("failed to close client: %s", err)
}
})
if _, err = client.Data(); err == nil {
t.Error("expected data writer to fail")
}
})
}
/* /*
func TestBasic(t *testing.T) { func TestBasic(t *testing.T) {
server := strings.Join(strings.Split(basicServer, "\n"), "\r\n") server := strings.Join(strings.Split(basicServer, "\n"), "\r\n")
@ -4251,7 +4316,7 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server
break break
} }
writeLine("220 Ready to start TLS") writeLine("220 Ready to start TLS")
tlsConfig := &tls.Config{Certificates: []tls.Certificate{keypair}} tlsConfig := &tls.Config{Certificates: []tls.Certificate{keypair}, ServerName: "example.com"}
connection = tls.Server(connection, tlsConfig) connection = tls.Server(connection, tlsConfig)
props.IsTLS = true props.IsTLS = true
handleTestServerConnection(connection, t, props) handleTestServerConnection(connection, t, props)
@ -4444,3 +4509,18 @@ type failWriter struct{}
func (w *failWriter) Write([]byte) (int, error) { func (w *failWriter) Write([]byte) (int, error) {
return 0, errors.New("broken writer") return 0, errors.New("broken writer")
} }
func getTLSConfig(t *testing.T) *tls.Config {
t.Helper()
cert, err := tls.X509KeyPair(localhostCert, localhostKey)
if err != nil {
t.Fatalf("unable to load host certifcate: %s", err)
}
testRootCAs := x509.NewCertPool()
testRootCAs.AppendCertsFromPEM(localhostCert)
return &tls.Config{
Certificates: []tls.Certificate{cert},
RootCAs: testRootCAs,
ServerName: "example.com",
}
}