mirror of
https://github.com/wneessen/go-mail.git
synced 2024-12-22 18:50:37 +01:00
Refactor TLS config initialization in tests
Replace repetitive TLS configuration code with a reusable `getTLSConfig` helper function for consistency and maintainability. Additionally, update port configuration and add new tests for mail data transmission.
This commit is contained in:
parent
75bfdd2855
commit
e9c7bdbb4e
1 changed files with 100 additions and 20 deletions
|
@ -22,6 +22,7 @@ import (
|
|||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -42,7 +43,7 @@ const (
|
|||
// TestServerAddr is the address the simple SMTP test server listens on
|
||||
TestServerAddr = "127.0.0.1"
|
||||
// TestServerPortBase is the base port for the simple SMTP test server
|
||||
TestServerPortBase = 12025
|
||||
TestServerPortBase = 30025
|
||||
)
|
||||
|
||||
// PortAdder is an atomic counter used to increment port numbers for the test SMTP server instances.
|
||||
|
@ -1087,13 +1088,8 @@ func TestScramAuth(t *testing.T) {
|
|||
var client *Client
|
||||
switch tt.tls {
|
||||
case true:
|
||||
cert, err := tls.X509KeyPair(localhostCert, localhostKey)
|
||||
if err != nil {
|
||||
fmt.Printf("error creating TLS cert: %s", err)
|
||||
return
|
||||
}
|
||||
tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true}
|
||||
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", TestServerAddr, serverPort), &tlsConfig)
|
||||
tlsConfig := getTLSConfig(t)
|
||||
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", TestServerAddr, serverPort), tlsConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to dial TLS server: %v", err)
|
||||
}
|
||||
|
@ -1161,13 +1157,8 @@ func TestScramAuth(t *testing.T) {
|
|||
var client *Client
|
||||
switch tt.tls {
|
||||
case true:
|
||||
cert, err := tls.X509KeyPair(localhostCert, localhostKey)
|
||||
if err != nil {
|
||||
fmt.Printf("error creating TLS cert: %s", err)
|
||||
return
|
||||
}
|
||||
tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true}
|
||||
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", TestServerAddr, serverPort), &tlsConfig)
|
||||
tlsConfig := getTLSConfig(t)
|
||||
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", TestServerAddr, serverPort), tlsConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to dial TLS server: %v", err)
|
||||
}
|
||||
|
@ -1719,7 +1710,7 @@ func TestClient_StartTLS(t *testing.T) {
|
|||
t.Errorf("failed to close client: %s", err)
|
||||
}
|
||||
})
|
||||
tlsConfig := &tls.Config{InsecureSkipVerify: true}
|
||||
tlsConfig := getTLSConfig(t)
|
||||
if err = client.StartTLS(tlsConfig); err != nil {
|
||||
t.Errorf("failed to initialize STARTTLS session: %s", err)
|
||||
}
|
||||
|
@ -1753,7 +1744,7 @@ func TestClient_StartTLS(t *testing.T) {
|
|||
t.Errorf("failed to close client: %s", err)
|
||||
}
|
||||
})
|
||||
tlsConfig := &tls.Config{InsecureSkipVerify: true}
|
||||
tlsConfig := getTLSConfig(t)
|
||||
if err = client.StartTLS(tlsConfig); err == nil {
|
||||
t.Error("STARTTLS should fail on EHLO")
|
||||
}
|
||||
|
@ -1786,7 +1777,7 @@ func TestClient_StartTLS(t *testing.T) {
|
|||
t.Errorf("failed to close client: %s", err)
|
||||
}
|
||||
})
|
||||
tlsConfig := &tls.Config{InsecureSkipVerify: true}
|
||||
tlsConfig := getTLSConfig(t)
|
||||
if err = client.StartTLS(tlsConfig); err == nil {
|
||||
t.Error("STARTTLS should fail for server not supporting it")
|
||||
}
|
||||
|
@ -1821,7 +1812,8 @@ func TestClient_TLSConnectionState(t *testing.T) {
|
|||
t.Errorf("failed to close client: %s", err)
|
||||
}
|
||||
})
|
||||
tlsConfig := &tls.Config{InsecureSkipVerify: true, MinVersion: tls.VersionTLS12}
|
||||
tlsConfig := getTLSConfig(t)
|
||||
tlsConfig.MinVersion = tls.VersionTLS12
|
||||
if err = client.StartTLS(tlsConfig); err != nil {
|
||||
t.Errorf("failed to initialize STARTTLS session: %s", err)
|
||||
}
|
||||
|
@ -2472,6 +2464,79 @@ func TestClient_Rcpt(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestClient_Data(t *testing.T) {
|
||||
t.Run("normal mail data transmission succeeds", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
PortAdder.Add(1)
|
||||
serverPort := int(TestServerPortBase + PortAdder.Load())
|
||||
featureSet := "250-DSN\r\n250 STARTTLS"
|
||||
go func() {
|
||||
if err := simpleSMTPServer(ctx, t, &serverProps{
|
||||
FeatureSet: featureSet,
|
||||
ListenPort: serverPort,
|
||||
},
|
||||
); err != nil {
|
||||
t.Errorf("failed to start test server: %s", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
time.Sleep(time.Millisecond * 30)
|
||||
client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to dial to test server: %s", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if err = client.Close(); err != nil {
|
||||
t.Errorf("failed to close client: %s", err)
|
||||
}
|
||||
})
|
||||
writer, err := client.Data()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create data writer: %s", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if err = writer.Close(); err != nil {
|
||||
t.Errorf("failed to close data writer: %s", err)
|
||||
}
|
||||
})
|
||||
if _, err = writer.Write([]byte("test message")); err != nil {
|
||||
t.Errorf("failed to write data to test server: %s", err)
|
||||
}
|
||||
})
|
||||
t.Run("mail data transmission fails on DATA command", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
PortAdder.Add(1)
|
||||
serverPort := int(TestServerPortBase + PortAdder.Load())
|
||||
featureSet := "250-DSN\r\n250 STARTTLS"
|
||||
go func() {
|
||||
if err := simpleSMTPServer(ctx, t, &serverProps{
|
||||
FailOnDataInit: true,
|
||||
FeatureSet: featureSet,
|
||||
ListenPort: serverPort,
|
||||
},
|
||||
); err != nil {
|
||||
t.Errorf("failed to start test server: %s", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
time.Sleep(time.Millisecond * 30)
|
||||
client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to dial to test server: %s", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if err = client.Close(); err != nil {
|
||||
t.Errorf("failed to close client: %s", err)
|
||||
}
|
||||
})
|
||||
if _, err = client.Data(); err == nil {
|
||||
t.Error("expected data writer to fail")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/*
|
||||
func TestBasic(t *testing.T) {
|
||||
server := strings.Join(strings.Split(basicServer, "\n"), "\r\n")
|
||||
|
@ -4251,7 +4316,7 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server
|
|||
break
|
||||
}
|
||||
writeLine("220 Ready to start TLS")
|
||||
tlsConfig := &tls.Config{Certificates: []tls.Certificate{keypair}}
|
||||
tlsConfig := &tls.Config{Certificates: []tls.Certificate{keypair}, ServerName: "example.com"}
|
||||
connection = tls.Server(connection, tlsConfig)
|
||||
props.IsTLS = true
|
||||
handleTestServerConnection(connection, t, props)
|
||||
|
@ -4444,3 +4509,18 @@ type failWriter struct{}
|
|||
func (w *failWriter) Write([]byte) (int, error) {
|
||||
return 0, errors.New("broken writer")
|
||||
}
|
||||
|
||||
func getTLSConfig(t *testing.T) *tls.Config {
|
||||
t.Helper()
|
||||
cert, err := tls.X509KeyPair(localhostCert, localhostKey)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to load host certifcate: %s", err)
|
||||
}
|
||||
testRootCAs := x509.NewCertPool()
|
||||
testRootCAs.AppendCertsFromPEM(localhostCert)
|
||||
return &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
RootCAs: testRootCAs,
|
||||
ServerName: "example.com",
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue