diff --git a/client_test.go b/client_test.go index fa16e4b..288156c 100644 --- a/client_test.go +++ b/client_test.go @@ -34,7 +34,7 @@ const ( // TestServerAddr is the address the simple SMTP test server listens on TestServerAddr = "127.0.0.1" // TestServerPortBase is the base port for the simple SMTP test server - TestServerPortBase = 2025 + TestServerPortBase = 12025 // TestPasswordValid is the password that the test server accepts as valid for SMTP auth TestPasswordValid = "V3ryS3cr3t+" // TestUserValid is the username that the test server accepts as valid for SMTP auth @@ -44,6 +44,44 @@ const ( // PortAdder is an atomic counter used to increment port numbers for the test SMTP server instances. var PortAdder atomic.Int32 +// localhostCert is a PEM-encoded TLS cert generated from src/crypto/tls: +// +// go run generate_cert.go --rsa-bits 1024 --host 127.0.0.1,::1,example.com \ +// --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h +var localhostCert = []byte(` +-----BEGIN CERTIFICATE----- +MIICFDCCAX2gAwIBAgIRAK0xjnaPuNDSreeXb+z+0u4wDQYJKoZIhvcNAQELBQAw +EjEQMA4GA1UEChMHQWNtZSBDbzAgFw03MDAxMDEwMDAwMDBaGA8yMDg0MDEyOTE2 +MDAwMFowEjEQMA4GA1UEChMHQWNtZSBDbzCBnzANBgkqhkiG9w0BAQEFAAOBjQAw +gYkCgYEA0nFbQQuOWsjbGtejcpWz153OlziZM4bVjJ9jYruNw5n2Ry6uYQAffhqa +JOInCmmcVe2siJglsyH9aRh6vKiobBbIUXXUU1ABd56ebAzlt0LobLlx7pZEMy30 +LqIi9E6zmL3YvdGzpYlkFRnRrqwEtWYbGBf3znO250S56CCWH2UCAwEAAaNoMGYw +DgYDVR0PAQH/BAQDAgKkMBMGA1UdJQQMMAoGCCsGAQUFBwMBMA8GA1UdEwEB/wQF +MAMBAf8wLgYDVR0RBCcwJYILZXhhbXBsZS5jb22HBH8AAAGHEAAAAAAAAAAAAAAA +AAAAAAEwDQYJKoZIhvcNAQELBQADgYEAbZtDS2dVuBYvb+MnolWnCNqvw1w5Gtgi +NmvQQPOMgM3m+oQSCPRTNGSg25e1Qbo7bgQDv8ZTnq8FgOJ/rbkyERw2JckkHpD4 +n4qcK27WkEDBtQFlPihIM8hLIuzWoi/9wygiElTy/tVL3y7fGCvY2/k1KBthtZGF +tN8URjVmyEo= +-----END CERTIFICATE-----`) + +// localhostKey is the private key for localhostCert. +var localhostKey = []byte(testingKey(` +-----BEGIN RSA TESTING KEY----- +MIICXgIBAAKBgQDScVtBC45ayNsa16NylbPXnc6XOJkzhtWMn2Niu43DmfZHLq5h +AB9+Gpok4icKaZxV7ayImCWzIf1pGHq8qKhsFshRddRTUAF3np5sDOW3QuhsuXHu +lkQzLfQuoiL0TrOYvdi90bOliWQVGdGurAS1ZhsYF/fOc7bnRLnoIJYfZQIDAQAB +AoGBAMst7OgpKyFV6c3JwyI/jWqxDySL3caU+RuTTBaodKAUx2ZEmNJIlx9eudLA +kucHvoxsM/eRxlxkhdFxdBcwU6J+zqooTnhu/FE3jhrT1lPrbhfGhyKnUrB0KKMM +VY3IQZyiehpxaeXAwoAou6TbWoTpl9t8ImAqAMY8hlULCUqlAkEA+9+Ry5FSYK/m +542LujIcCaIGoG1/Te6Sxr3hsPagKC2rH20rDLqXwEedSFOpSS0vpzlPAzy/6Rbb +PHTJUhNdwwJBANXkA+TkMdbJI5do9/mn//U0LfrCR9NkcoYohxfKz8JuhgRQxzF2 +6jpo3q7CdTuuRixLWVfeJzcrAyNrVcBq87cCQFkTCtOMNC7fZnCTPUv+9q1tcJyB +vNjJu3yvoEZeIeuzouX9TJE21/33FaeDdsXbRhQEj23cqR38qFHsF1qAYNMCQQDP +QXLEiJoClkR2orAmqjPLVhR3t2oB3INcnEjLNSq8LHyQEfXyaFfu4U9l5+fRPL2i +jiC0k/9L5dHUsF0XZothAkEA23ddgRs+Id/HxtojqqUT27B8MT/IGNrYsp4DvS/c +qgkeluku4GjxRlDMBuXk94xOBEinUs+p/hwP1Alll80Tpg== +-----END RSA TESTING KEY-----`)) + // logLine represents a log entry with time, level, message, and direction details. type logLine struct { Time time.Time `json:"time"` @@ -1668,7 +1706,7 @@ func TestClient_DialWithContext(t *testing.T) { ctxDial, cancelDial := context.WithTimeout(ctx, time.Millisecond*500) t.Cleanup(cancelDial) - client, err := NewClient(DefaultHost, WithPort(serverPort), WithTLSPolicy(NoTLS), WithDebugLog()) + client, err := NewClient(DefaultHost, WithPort(serverPort), WithTLSPolicy(NoTLS)) if err != nil { t.Fatalf("failed to create new client: %s", err) } @@ -1744,13 +1782,13 @@ func TestClient_DialWithContext(t *testing.T) { } }) t.Run("connect should fail on HELO", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctxFail, cancelFail := context.WithCancel(context.Background()) + defer cancelFail() PortAdder.Add(1) failServerPort := int(TestServerPortBase + PortAdder.Load()) failFeatureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" go func() { - if err := simpleSMTPServer(ctx, &serverProps{ + if err := simpleSMTPServer(ctxFail, &serverProps{ FailOnHelo: true, FeatureSet: failFeatureSet, ListenPort: failServerPort, @@ -1778,7 +1816,113 @@ func TestClient_DialWithContext(t *testing.T) { t.Errorf("client has no connection") } }) - // TODO: Implement tests for TLS/SSL and custom DialCtxFunc + t.Run("connect with failing auth", func(t *testing.T) { + ctxDial, cancelDial := context.WithTimeout(ctx, time.Millisecond*500) + t.Cleanup(cancelDial) + + client, err := NewClient(DefaultHost, WithPort(serverPort), WithTLSPolicy(NoTLS), + WithSMTPAuth(SMTPAuthPlain), WithUsername("invalid"), WithPassword("invalid")) + if err != nil { + t.Fatalf("failed to create new client: %s", err) + } + if err = client.DialWithContext(ctxDial); err == nil { + t.Fatalf("connection was supposed to fail, but didn't") + } + }) + t.Run("connect with STARTTLS", func(t *testing.T) { + ctxTLS, cancelTLS := context.WithCancel(context.Background()) + defer cancelTLS() + PortAdder.Add(1) + tlsServerPort := int(TestServerPortBase + PortAdder.Load()) + tlsFeatureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250-STARTTLS\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctxTLS, &serverProps{ + FeatureSet: tlsFeatureSet, + ListenPort: tlsServerPort, + }); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 300) + ctxDial, cancelDial := context.WithTimeout(ctx, time.Millisecond*500) + t.Cleanup(cancelDial) + + tlsConfig := &tls.Config{InsecureSkipVerify: true} + client, err := NewClient(DefaultHost, WithPort(tlsServerPort), WithTLSPolicy(TLSMandatory), + WithTLSConfig(tlsConfig), WithSMTPAuth(SMTPAuthPlain), WithUsername(TestUserValid), + WithPassword(TestPasswordValid)) + if err != nil { + t.Fatalf("failed to create new client: %s", err) + } + if err = client.DialWithContext(ctxDial); err != nil { + t.Fatalf("failed to connect to the test server: %s", err) + } + }) + t.Run("want STARTTLS, but server does not support it", func(t *testing.T) { + ctxTLS, cancelTLS := context.WithCancel(context.Background()) + defer cancelTLS() + PortAdder.Add(1) + tlsServerPort := int(TestServerPortBase + PortAdder.Load()) + tlsFeatureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctxTLS, &serverProps{ + FeatureSet: tlsFeatureSet, + ListenPort: tlsServerPort, + }); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 300) + ctxDial, cancelDial := context.WithTimeout(ctx, time.Millisecond*500) + t.Cleanup(cancelDial) + + tlsConfig := &tls.Config{InsecureSkipVerify: true} + client, err := NewClient(DefaultHost, WithPort(tlsServerPort), WithTLSPolicy(TLSMandatory), + WithTLSConfig(tlsConfig), WithSMTPAuth(SMTPAuthPlain), WithUsername(TestUserValid), + WithPassword(TestPasswordValid)) + if err != nil { + t.Fatalf("failed to create new client: %s", err) + } + if err = client.DialWithContext(ctxDial); err == nil { + t.Fatalf("connection was supposed to fail, but didn't") + } + }) + t.Run("connect with SSL", func(t *testing.T) { + ctxSSL, cancelSSL := context.WithCancel(context.Background()) + defer cancelSSL() + PortAdder.Add(1) + sslServerPort := int(TestServerPortBase + PortAdder.Load()) + sslFeatureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctxSSL, &serverProps{ + SSLListener: true, + FeatureSet: sslFeatureSet, + ListenPort: sslServerPort, + }); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 300) + ctxDial, cancelDial := context.WithTimeout(ctx, time.Millisecond*500) + t.Cleanup(cancelDial) + + tlsConfig := &tls.Config{InsecureSkipVerify: true} + client, err := NewClient(DefaultHost, WithPort(sslServerPort), WithSSL(), + WithTLSConfig(tlsConfig), WithSMTPAuth(SMTPAuthPlain), WithUsername(TestUserValid), + WithPassword(TestPasswordValid)) + if err != nil { + t.Fatalf("failed to create new client: %s", err) + } + if err = client.DialWithContext(ctxDial); err != nil { + t.Fatalf("failed to connect to the test server: %s", err) + } + if err := client.Close(); err != nil { + t.Fatalf("failed to close client: %s", err) + } + }) } /* @@ -3641,12 +3785,19 @@ func parseJSONLog(t *testing.T, buf *bytes.Buffer) logData { return logdata } +// testingKey replaces the substring "TESTING KEY" with "PRIVATE KEY" in the given string s. +func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") } + +// serverProps represents the configuration properties for the SMTP server. type serverProps struct { - FailOnHelo bool - FailOnQuit bool - FailOnReset bool - FeatureSet string - ListenPort int + FailOnHelo bool + FailOnQuit bool + FailOnReset bool + FailOnSTARTTLS bool + FeatureSet string + ListenPort int + SSLListener bool + IsTLS bool } // simpleSMTPServer starts a simple TCP server that resonds to SMTP commands. @@ -3656,9 +3807,23 @@ func simpleSMTPServer(ctx context.Context, props *serverProps) error { if props == nil { return fmt.Errorf("no server properties provided") } - listener, err := net.Listen(TestServerProto, fmt.Sprintf("%s:%d", TestServerAddr, props.ListenPort)) + + var listener net.Listener + var err error + if props.SSLListener { + keypair, err := tls.X509KeyPair(localhostCert, localhostKey) + if err != nil { + return fmt.Errorf("failed to read TLS keypair: %s", err) + } + tlsConfig := &tls.Config{Certificates: []tls.Certificate{keypair}} + listener, err = tls.Listen(TestServerProto, fmt.Sprintf("%s:%d", TestServerAddr, props.ListenPort), + tlsConfig) + } else { + listener, err = net.Listen(TestServerProto, fmt.Sprintf("%s:%d", TestServerAddr, props.ListenPort)) + } if err != nil { - return fmt.Errorf("unable to listen on %s://%s: %w", TestServerProto, TestServerAddr, err) + return fmt.Errorf("unable to listen on %s://%s: %w (SSL: %t)", TestServerProto, TestServerAddr, err, + props.SSLListener) } defer func() { @@ -3707,9 +3872,11 @@ func handleTestServerConnection(connection net.Conn, props *serverProps) { _ = writeLine("250 2.0.0 OK") } - if err := writeLine("220 go-mail test server ready ESMTP"); err != nil { - fmt.Printf("unable to write to client: %s\n", err) - return + if !props.IsTLS { + if err := writeLine("220 go-mail test server ready ESMTP"); err != nil { + fmt.Printf("unable to write to client: %s\n", err) + return + } } for { @@ -3850,6 +4017,21 @@ func handleTestServerConnection(connection net.Conn, props *serverProps) { } _ = writeLine("221 2.0.0 Bye") return + case strings.EqualFold(data, "starttls"): + if props.FailOnSTARTTLS { + _ = writeLine("500 5.1.2 Error: starttls failed") + break + } + keypair, err := tls.X509KeyPair(localhostCert, localhostKey) + if err != nil { + _ = writeLine("500 5.1.2 Error: starttls failed - " + err.Error()) + break + } + _ = writeLine("220 Ready to start TLS") + tlsConfig := &tls.Config{Certificates: []tls.Certificate{keypair}} + connection = tls.Server(connection, tlsConfig) + props.IsTLS = true + handleTestServerConnection(connection, props) default: _ = writeLine("500 5.5.2 Error: bad syntax") }