mirror of
https://github.com/wneessen/go-mail.git
synced 2024-11-15 02:12:55 +01:00
Refactor TLS config initialization in tests
Replace repetitive TLS configuration code with a reusable `getTLSConfig` helper function for consistency and maintainability. Additionally, update port configuration and add new tests for mail data transmission.
This commit is contained in:
parent
75bfdd2855
commit
e9c7bdbb4e
1 changed files with 100 additions and 20 deletions
|
@ -22,6 +22,7 @@ import (
|
||||||
"crypto/sha1"
|
"crypto/sha1"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -42,7 +43,7 @@ const (
|
||||||
// TestServerAddr is the address the simple SMTP test server listens on
|
// TestServerAddr is the address the simple SMTP test server listens on
|
||||||
TestServerAddr = "127.0.0.1"
|
TestServerAddr = "127.0.0.1"
|
||||||
// TestServerPortBase is the base port for the simple SMTP test server
|
// TestServerPortBase is the base port for the simple SMTP test server
|
||||||
TestServerPortBase = 12025
|
TestServerPortBase = 30025
|
||||||
)
|
)
|
||||||
|
|
||||||
// PortAdder is an atomic counter used to increment port numbers for the test SMTP server instances.
|
// PortAdder is an atomic counter used to increment port numbers for the test SMTP server instances.
|
||||||
|
@ -1087,13 +1088,8 @@ func TestScramAuth(t *testing.T) {
|
||||||
var client *Client
|
var client *Client
|
||||||
switch tt.tls {
|
switch tt.tls {
|
||||||
case true:
|
case true:
|
||||||
cert, err := tls.X509KeyPair(localhostCert, localhostKey)
|
tlsConfig := getTLSConfig(t)
|
||||||
if err != nil {
|
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", TestServerAddr, serverPort), tlsConfig)
|
||||||
fmt.Printf("error creating TLS cert: %s", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true}
|
|
||||||
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", TestServerAddr, serverPort), &tlsConfig)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to dial TLS server: %v", err)
|
t.Fatalf("failed to dial TLS server: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -1161,13 +1157,8 @@ func TestScramAuth(t *testing.T) {
|
||||||
var client *Client
|
var client *Client
|
||||||
switch tt.tls {
|
switch tt.tls {
|
||||||
case true:
|
case true:
|
||||||
cert, err := tls.X509KeyPair(localhostCert, localhostKey)
|
tlsConfig := getTLSConfig(t)
|
||||||
if err != nil {
|
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", TestServerAddr, serverPort), tlsConfig)
|
||||||
fmt.Printf("error creating TLS cert: %s", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true}
|
|
||||||
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", TestServerAddr, serverPort), &tlsConfig)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to dial TLS server: %v", err)
|
t.Fatalf("failed to dial TLS server: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -1719,7 +1710,7 @@ func TestClient_StartTLS(t *testing.T) {
|
||||||
t.Errorf("failed to close client: %s", err)
|
t.Errorf("failed to close client: %s", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
tlsConfig := &tls.Config{InsecureSkipVerify: true}
|
tlsConfig := getTLSConfig(t)
|
||||||
if err = client.StartTLS(tlsConfig); err != nil {
|
if err = client.StartTLS(tlsConfig); err != nil {
|
||||||
t.Errorf("failed to initialize STARTTLS session: %s", err)
|
t.Errorf("failed to initialize STARTTLS session: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -1753,7 +1744,7 @@ func TestClient_StartTLS(t *testing.T) {
|
||||||
t.Errorf("failed to close client: %s", err)
|
t.Errorf("failed to close client: %s", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
tlsConfig := &tls.Config{InsecureSkipVerify: true}
|
tlsConfig := getTLSConfig(t)
|
||||||
if err = client.StartTLS(tlsConfig); err == nil {
|
if err = client.StartTLS(tlsConfig); err == nil {
|
||||||
t.Error("STARTTLS should fail on EHLO")
|
t.Error("STARTTLS should fail on EHLO")
|
||||||
}
|
}
|
||||||
|
@ -1786,7 +1777,7 @@ func TestClient_StartTLS(t *testing.T) {
|
||||||
t.Errorf("failed to close client: %s", err)
|
t.Errorf("failed to close client: %s", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
tlsConfig := &tls.Config{InsecureSkipVerify: true}
|
tlsConfig := getTLSConfig(t)
|
||||||
if err = client.StartTLS(tlsConfig); err == nil {
|
if err = client.StartTLS(tlsConfig); err == nil {
|
||||||
t.Error("STARTTLS should fail for server not supporting it")
|
t.Error("STARTTLS should fail for server not supporting it")
|
||||||
}
|
}
|
||||||
|
@ -1821,7 +1812,8 @@ func TestClient_TLSConnectionState(t *testing.T) {
|
||||||
t.Errorf("failed to close client: %s", err)
|
t.Errorf("failed to close client: %s", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
tlsConfig := &tls.Config{InsecureSkipVerify: true, MinVersion: tls.VersionTLS12}
|
tlsConfig := getTLSConfig(t)
|
||||||
|
tlsConfig.MinVersion = tls.VersionTLS12
|
||||||
if err = client.StartTLS(tlsConfig); err != nil {
|
if err = client.StartTLS(tlsConfig); err != nil {
|
||||||
t.Errorf("failed to initialize STARTTLS session: %s", err)
|
t.Errorf("failed to initialize STARTTLS session: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -2472,6 +2464,79 @@ func TestClient_Rcpt(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestClient_Data(t *testing.T) {
|
||||||
|
t.Run("normal mail data transmission succeeds", func(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
PortAdder.Add(1)
|
||||||
|
serverPort := int(TestServerPortBase + PortAdder.Load())
|
||||||
|
featureSet := "250-DSN\r\n250 STARTTLS"
|
||||||
|
go func() {
|
||||||
|
if err := simpleSMTPServer(ctx, t, &serverProps{
|
||||||
|
FeatureSet: featureSet,
|
||||||
|
ListenPort: serverPort,
|
||||||
|
},
|
||||||
|
); err != nil {
|
||||||
|
t.Errorf("failed to start test server: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
time.Sleep(time.Millisecond * 30)
|
||||||
|
client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to dial to test server: %s", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
if err = client.Close(); err != nil {
|
||||||
|
t.Errorf("failed to close client: %s", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
writer, err := client.Data()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create data writer: %s", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
if err = writer.Close(); err != nil {
|
||||||
|
t.Errorf("failed to close data writer: %s", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
if _, err = writer.Write([]byte("test message")); err != nil {
|
||||||
|
t.Errorf("failed to write data to test server: %s", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("mail data transmission fails on DATA command", func(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
PortAdder.Add(1)
|
||||||
|
serverPort := int(TestServerPortBase + PortAdder.Load())
|
||||||
|
featureSet := "250-DSN\r\n250 STARTTLS"
|
||||||
|
go func() {
|
||||||
|
if err := simpleSMTPServer(ctx, t, &serverProps{
|
||||||
|
FailOnDataInit: true,
|
||||||
|
FeatureSet: featureSet,
|
||||||
|
ListenPort: serverPort,
|
||||||
|
},
|
||||||
|
); err != nil {
|
||||||
|
t.Errorf("failed to start test server: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
time.Sleep(time.Millisecond * 30)
|
||||||
|
client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to dial to test server: %s", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
if err = client.Close(); err != nil {
|
||||||
|
t.Errorf("failed to close client: %s", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
if _, err = client.Data(); err == nil {
|
||||||
|
t.Error("expected data writer to fail")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
func TestBasic(t *testing.T) {
|
func TestBasic(t *testing.T) {
|
||||||
server := strings.Join(strings.Split(basicServer, "\n"), "\r\n")
|
server := strings.Join(strings.Split(basicServer, "\n"), "\r\n")
|
||||||
|
@ -4251,7 +4316,7 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
writeLine("220 Ready to start TLS")
|
writeLine("220 Ready to start TLS")
|
||||||
tlsConfig := &tls.Config{Certificates: []tls.Certificate{keypair}}
|
tlsConfig := &tls.Config{Certificates: []tls.Certificate{keypair}, ServerName: "example.com"}
|
||||||
connection = tls.Server(connection, tlsConfig)
|
connection = tls.Server(connection, tlsConfig)
|
||||||
props.IsTLS = true
|
props.IsTLS = true
|
||||||
handleTestServerConnection(connection, t, props)
|
handleTestServerConnection(connection, t, props)
|
||||||
|
@ -4444,3 +4509,18 @@ type failWriter struct{}
|
||||||
func (w *failWriter) Write([]byte) (int, error) {
|
func (w *failWriter) Write([]byte) (int, error) {
|
||||||
return 0, errors.New("broken writer")
|
return 0, errors.New("broken writer")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getTLSConfig(t *testing.T) *tls.Config {
|
||||||
|
t.Helper()
|
||||||
|
cert, err := tls.X509KeyPair(localhostCert, localhostKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to load host certifcate: %s", err)
|
||||||
|
}
|
||||||
|
testRootCAs := x509.NewCertPool()
|
||||||
|
testRootCAs.AppendCertsFromPEM(localhostCert)
|
||||||
|
return &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
RootCAs: testRootCAs,
|
||||||
|
ServerName: "example.com",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue