Add STARTTLS and SSL test cases for SMTP client

Extended test cases to include scenarios for STARTTLS and SSL connections. This includes handling TLS configurations, testing certificate handling, and tests for various authentication methods under TLS.
This commit is contained in:
Winni Neessen 2024-10-23 23:33:33 +02:00
parent 563ccbab4a
commit c63b8b124e
Signed by: wneessen
GPG key ID: 385AC9889632126E

View file

@ -34,7 +34,7 @@ const (
// TestServerAddr is the address the simple SMTP test server listens on // TestServerAddr is the address the simple SMTP test server listens on
TestServerAddr = "127.0.0.1" TestServerAddr = "127.0.0.1"
// TestServerPortBase is the base port for the simple SMTP test server // TestServerPortBase is the base port for the simple SMTP test server
TestServerPortBase = 2025 TestServerPortBase = 12025
// TestPasswordValid is the password that the test server accepts as valid for SMTP auth // TestPasswordValid is the password that the test server accepts as valid for SMTP auth
TestPasswordValid = "V3ryS3cr3t+" TestPasswordValid = "V3ryS3cr3t+"
// TestUserValid is the username that the test server accepts as valid for SMTP auth // TestUserValid is the username that the test server accepts as valid for SMTP auth
@ -44,6 +44,44 @@ const (
// PortAdder is an atomic counter used to increment port numbers for the test SMTP server instances. // PortAdder is an atomic counter used to increment port numbers for the test SMTP server instances.
var PortAdder atomic.Int32 var PortAdder atomic.Int32
// localhostCert is a PEM-encoded TLS cert generated from src/crypto/tls:
//
// go run generate_cert.go --rsa-bits 1024 --host 127.0.0.1,::1,example.com \
// --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h
var localhostCert = []byte(`
-----BEGIN CERTIFICATE-----
MIICFDCCAX2gAwIBAgIRAK0xjnaPuNDSreeXb+z+0u4wDQYJKoZIhvcNAQELBQAw
EjEQMA4GA1UEChMHQWNtZSBDbzAgFw03MDAxMDEwMDAwMDBaGA8yMDg0MDEyOTE2
MDAwMFowEjEQMA4GA1UEChMHQWNtZSBDbzCBnzANBgkqhkiG9w0BAQEFAAOBjQAw
gYkCgYEA0nFbQQuOWsjbGtejcpWz153OlziZM4bVjJ9jYruNw5n2Ry6uYQAffhqa
JOInCmmcVe2siJglsyH9aRh6vKiobBbIUXXUU1ABd56ebAzlt0LobLlx7pZEMy30
LqIi9E6zmL3YvdGzpYlkFRnRrqwEtWYbGBf3znO250S56CCWH2UCAwEAAaNoMGYw
DgYDVR0PAQH/BAQDAgKkMBMGA1UdJQQMMAoGCCsGAQUFBwMBMA8GA1UdEwEB/wQF
MAMBAf8wLgYDVR0RBCcwJYILZXhhbXBsZS5jb22HBH8AAAGHEAAAAAAAAAAAAAAA
AAAAAAEwDQYJKoZIhvcNAQELBQADgYEAbZtDS2dVuBYvb+MnolWnCNqvw1w5Gtgi
NmvQQPOMgM3m+oQSCPRTNGSg25e1Qbo7bgQDv8ZTnq8FgOJ/rbkyERw2JckkHpD4
n4qcK27WkEDBtQFlPihIM8hLIuzWoi/9wygiElTy/tVL3y7fGCvY2/k1KBthtZGF
tN8URjVmyEo=
-----END CERTIFICATE-----`)
// localhostKey is the private key for localhostCert.
var localhostKey = []byte(testingKey(`
-----BEGIN RSA TESTING KEY-----
MIICXgIBAAKBgQDScVtBC45ayNsa16NylbPXnc6XOJkzhtWMn2Niu43DmfZHLq5h
AB9+Gpok4icKaZxV7ayImCWzIf1pGHq8qKhsFshRddRTUAF3np5sDOW3QuhsuXHu
lkQzLfQuoiL0TrOYvdi90bOliWQVGdGurAS1ZhsYF/fOc7bnRLnoIJYfZQIDAQAB
AoGBAMst7OgpKyFV6c3JwyI/jWqxDySL3caU+RuTTBaodKAUx2ZEmNJIlx9eudLA
kucHvoxsM/eRxlxkhdFxdBcwU6J+zqooTnhu/FE3jhrT1lPrbhfGhyKnUrB0KKMM
VY3IQZyiehpxaeXAwoAou6TbWoTpl9t8ImAqAMY8hlULCUqlAkEA+9+Ry5FSYK/m
542LujIcCaIGoG1/Te6Sxr3hsPagKC2rH20rDLqXwEedSFOpSS0vpzlPAzy/6Rbb
PHTJUhNdwwJBANXkA+TkMdbJI5do9/mn//U0LfrCR9NkcoYohxfKz8JuhgRQxzF2
6jpo3q7CdTuuRixLWVfeJzcrAyNrVcBq87cCQFkTCtOMNC7fZnCTPUv+9q1tcJyB
vNjJu3yvoEZeIeuzouX9TJE21/33FaeDdsXbRhQEj23cqR38qFHsF1qAYNMCQQDP
QXLEiJoClkR2orAmqjPLVhR3t2oB3INcnEjLNSq8LHyQEfXyaFfu4U9l5+fRPL2i
jiC0k/9L5dHUsF0XZothAkEA23ddgRs+Id/HxtojqqUT27B8MT/IGNrYsp4DvS/c
qgkeluku4GjxRlDMBuXk94xOBEinUs+p/hwP1Alll80Tpg==
-----END RSA TESTING KEY-----`))
// logLine represents a log entry with time, level, message, and direction details. // logLine represents a log entry with time, level, message, and direction details.
type logLine struct { type logLine struct {
Time time.Time `json:"time"` Time time.Time `json:"time"`
@ -1668,7 +1706,7 @@ func TestClient_DialWithContext(t *testing.T) {
ctxDial, cancelDial := context.WithTimeout(ctx, time.Millisecond*500) ctxDial, cancelDial := context.WithTimeout(ctx, time.Millisecond*500)
t.Cleanup(cancelDial) t.Cleanup(cancelDial)
client, err := NewClient(DefaultHost, WithPort(serverPort), WithTLSPolicy(NoTLS), WithDebugLog()) client, err := NewClient(DefaultHost, WithPort(serverPort), WithTLSPolicy(NoTLS))
if err != nil { if err != nil {
t.Fatalf("failed to create new client: %s", err) t.Fatalf("failed to create new client: %s", err)
} }
@ -1744,13 +1782,13 @@ func TestClient_DialWithContext(t *testing.T) {
} }
}) })
t.Run("connect should fail on HELO", func(t *testing.T) { t.Run("connect should fail on HELO", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctxFail, cancelFail := context.WithCancel(context.Background())
defer cancel() defer cancelFail()
PortAdder.Add(1) PortAdder.Add(1)
failServerPort := int(TestServerPortBase + PortAdder.Load()) failServerPort := int(TestServerPortBase + PortAdder.Load())
failFeatureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" failFeatureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
go func() { go func() {
if err := simpleSMTPServer(ctx, &serverProps{ if err := simpleSMTPServer(ctxFail, &serverProps{
FailOnHelo: true, FailOnHelo: true,
FeatureSet: failFeatureSet, FeatureSet: failFeatureSet,
ListenPort: failServerPort, ListenPort: failServerPort,
@ -1778,7 +1816,113 @@ func TestClient_DialWithContext(t *testing.T) {
t.Errorf("client has no connection") t.Errorf("client has no connection")
} }
}) })
// TODO: Implement tests for TLS/SSL and custom DialCtxFunc t.Run("connect with failing auth", func(t *testing.T) {
ctxDial, cancelDial := context.WithTimeout(ctx, time.Millisecond*500)
t.Cleanup(cancelDial)
client, err := NewClient(DefaultHost, WithPort(serverPort), WithTLSPolicy(NoTLS),
WithSMTPAuth(SMTPAuthPlain), WithUsername("invalid"), WithPassword("invalid"))
if err != nil {
t.Fatalf("failed to create new client: %s", err)
}
if err = client.DialWithContext(ctxDial); err == nil {
t.Fatalf("connection was supposed to fail, but didn't")
}
})
t.Run("connect with STARTTLS", func(t *testing.T) {
ctxTLS, cancelTLS := context.WithCancel(context.Background())
defer cancelTLS()
PortAdder.Add(1)
tlsServerPort := int(TestServerPortBase + PortAdder.Load())
tlsFeatureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250-STARTTLS\r\n250 SMTPUTF8"
go func() {
if err := simpleSMTPServer(ctxTLS, &serverProps{
FeatureSet: tlsFeatureSet,
ListenPort: tlsServerPort,
}); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 300)
ctxDial, cancelDial := context.WithTimeout(ctx, time.Millisecond*500)
t.Cleanup(cancelDial)
tlsConfig := &tls.Config{InsecureSkipVerify: true}
client, err := NewClient(DefaultHost, WithPort(tlsServerPort), WithTLSPolicy(TLSMandatory),
WithTLSConfig(tlsConfig), WithSMTPAuth(SMTPAuthPlain), WithUsername(TestUserValid),
WithPassword(TestPasswordValid))
if err != nil {
t.Fatalf("failed to create new client: %s", err)
}
if err = client.DialWithContext(ctxDial); err != nil {
t.Fatalf("failed to connect to the test server: %s", err)
}
})
t.Run("want STARTTLS, but server does not support it", func(t *testing.T) {
ctxTLS, cancelTLS := context.WithCancel(context.Background())
defer cancelTLS()
PortAdder.Add(1)
tlsServerPort := int(TestServerPortBase + PortAdder.Load())
tlsFeatureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
go func() {
if err := simpleSMTPServer(ctxTLS, &serverProps{
FeatureSet: tlsFeatureSet,
ListenPort: tlsServerPort,
}); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 300)
ctxDial, cancelDial := context.WithTimeout(ctx, time.Millisecond*500)
t.Cleanup(cancelDial)
tlsConfig := &tls.Config{InsecureSkipVerify: true}
client, err := NewClient(DefaultHost, WithPort(tlsServerPort), WithTLSPolicy(TLSMandatory),
WithTLSConfig(tlsConfig), WithSMTPAuth(SMTPAuthPlain), WithUsername(TestUserValid),
WithPassword(TestPasswordValid))
if err != nil {
t.Fatalf("failed to create new client: %s", err)
}
if err = client.DialWithContext(ctxDial); err == nil {
t.Fatalf("connection was supposed to fail, but didn't")
}
})
t.Run("connect with SSL", func(t *testing.T) {
ctxSSL, cancelSSL := context.WithCancel(context.Background())
defer cancelSSL()
PortAdder.Add(1)
sslServerPort := int(TestServerPortBase + PortAdder.Load())
sslFeatureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
go func() {
if err := simpleSMTPServer(ctxSSL, &serverProps{
SSLListener: true,
FeatureSet: sslFeatureSet,
ListenPort: sslServerPort,
}); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 300)
ctxDial, cancelDial := context.WithTimeout(ctx, time.Millisecond*500)
t.Cleanup(cancelDial)
tlsConfig := &tls.Config{InsecureSkipVerify: true}
client, err := NewClient(DefaultHost, WithPort(sslServerPort), WithSSL(),
WithTLSConfig(tlsConfig), WithSMTPAuth(SMTPAuthPlain), WithUsername(TestUserValid),
WithPassword(TestPasswordValid))
if err != nil {
t.Fatalf("failed to create new client: %s", err)
}
if err = client.DialWithContext(ctxDial); err != nil {
t.Fatalf("failed to connect to the test server: %s", err)
}
if err := client.Close(); err != nil {
t.Fatalf("failed to close client: %s", err)
}
})
} }
/* /*
@ -3641,12 +3785,19 @@ func parseJSONLog(t *testing.T, buf *bytes.Buffer) logData {
return logdata return logdata
} }
// testingKey replaces the substring "TESTING KEY" with "PRIVATE KEY" in the given string s.
func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") }
// serverProps represents the configuration properties for the SMTP server.
type serverProps struct { type serverProps struct {
FailOnHelo bool FailOnHelo bool
FailOnQuit bool FailOnQuit bool
FailOnReset bool FailOnReset bool
FeatureSet string FailOnSTARTTLS bool
ListenPort int FeatureSet string
ListenPort int
SSLListener bool
IsTLS bool
} }
// simpleSMTPServer starts a simple TCP server that resonds to SMTP commands. // simpleSMTPServer starts a simple TCP server that resonds to SMTP commands.
@ -3656,9 +3807,23 @@ func simpleSMTPServer(ctx context.Context, props *serverProps) error {
if props == nil { if props == nil {
return fmt.Errorf("no server properties provided") return fmt.Errorf("no server properties provided")
} }
listener, err := net.Listen(TestServerProto, fmt.Sprintf("%s:%d", TestServerAddr, props.ListenPort))
var listener net.Listener
var err error
if props.SSLListener {
keypair, err := tls.X509KeyPair(localhostCert, localhostKey)
if err != nil {
return fmt.Errorf("failed to read TLS keypair: %s", err)
}
tlsConfig := &tls.Config{Certificates: []tls.Certificate{keypair}}
listener, err = tls.Listen(TestServerProto, fmt.Sprintf("%s:%d", TestServerAddr, props.ListenPort),
tlsConfig)
} else {
listener, err = net.Listen(TestServerProto, fmt.Sprintf("%s:%d", TestServerAddr, props.ListenPort))
}
if err != nil { if err != nil {
return fmt.Errorf("unable to listen on %s://%s: %w", TestServerProto, TestServerAddr, err) return fmt.Errorf("unable to listen on %s://%s: %w (SSL: %t)", TestServerProto, TestServerAddr, err,
props.SSLListener)
} }
defer func() { defer func() {
@ -3707,9 +3872,11 @@ func handleTestServerConnection(connection net.Conn, props *serverProps) {
_ = writeLine("250 2.0.0 OK") _ = writeLine("250 2.0.0 OK")
} }
if err := writeLine("220 go-mail test server ready ESMTP"); err != nil { if !props.IsTLS {
fmt.Printf("unable to write to client: %s\n", err) if err := writeLine("220 go-mail test server ready ESMTP"); err != nil {
return fmt.Printf("unable to write to client: %s\n", err)
return
}
} }
for { for {
@ -3850,6 +4017,21 @@ func handleTestServerConnection(connection net.Conn, props *serverProps) {
} }
_ = writeLine("221 2.0.0 Bye") _ = writeLine("221 2.0.0 Bye")
return return
case strings.EqualFold(data, "starttls"):
if props.FailOnSTARTTLS {
_ = writeLine("500 5.1.2 Error: starttls failed")
break
}
keypair, err := tls.X509KeyPair(localhostCert, localhostKey)
if err != nil {
_ = writeLine("500 5.1.2 Error: starttls failed - " + err.Error())
break
}
_ = writeLine("220 Ready to start TLS")
tlsConfig := &tls.Config{Certificates: []tls.Certificate{keypair}}
connection = tls.Server(connection, tlsConfig)
props.IsTLS = true
handleTestServerConnection(connection, props)
default: default:
_ = writeLine("500 5.5.2 Error: bad syntax") _ = writeLine("500 5.5.2 Error: bad syntax")
} }