diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index 238176a..da7ad8b 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -1391,252 +1391,204 @@ func TestScramAuth_handleServerFirstResponse(t *testing.T) { t.Errorf("expected error to be %q, got %q", expectedErr, err) } }) - } -/* +func TestCRAMMD5Auth(t *testing.T) { + t.Run("CRAM-MD5 on test server succeeds", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-AUTH CRAM-MD5\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + auth := CRAMMD5Auth("username", "password") + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + if err = client.Auth(auth); err != nil { + t.Errorf("failed to auth to test server: %s", err) + } + }) + t.Run("CRAM-MD5 on test server fails", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-AUTH CRAM-MD5\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnAuth: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) - - - - -func TestAuthSCRAMSHA1_OK(t *testing.T) { - hostname := "127.0.0.1" - port := "2585" - - go func() { - startSMTPServer(false, hostname, port, sha1.New) - }() - time.Sleep(time.Millisecond * 500) - - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port)) - if err != nil { - t.Errorf("failed to dial server: %v", err) - } - client, err := NewClient(conn, hostname) - if err != nil { - t.Errorf("failed to create client: %v", err) - } - if err = client.Hello(hostname); err != nil { - t.Errorf("failed to send HELO: %v", err) - } - if err = client.Auth(ScramSHA1Auth("username", "password")); err != nil { - t.Errorf("failed to authenticate: %v", err) - } + auth := CRAMMD5Auth("username", "password") + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + if err = client.Auth(auth); err == nil { + t.Error("auth should fail on test server") + } + }) } -func TestAuthSCRAMSHA256_OK(t *testing.T) { - hostname := "127.0.0.1" - port := "2586" +func TestNewClient(t *testing.T) { + t.Run("new client via Dial succeeds", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) - go func() { - startSMTPServer(false, hostname, port, sha256.New) - }() - time.Sleep(time.Millisecond * 500) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Errorf("failed to create client: %s", err) + } + if err := client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + t.Run("new client via Dial fails on server not started", func(t *testing.T) { + _, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, 64000)) + if err == nil { + t.Error("dial on non-existant server should fail") + } + }) + t.Run("new client fails on server not available", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnDial: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port)) - if err != nil { - t.Errorf("failed to dial server: %v", err) - } - client, err := NewClient(conn, hostname) - if err != nil { - t.Errorf("failed to create client: %v", err) - } - if err = client.Hello(hostname); err != nil { - t.Errorf("failed to send HELO: %v", err) - } - if err = client.Auth(ScramSHA256Auth("username", "password")); err != nil { - t.Errorf("failed to authenticate: %v", err) - } + _, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err == nil { + t.Error("connection to non-available server should fail") + } + }) + t.Run("new client fails on faker that fails on close", func(t *testing.T) { + server := "442 service not available\r\n" + var wrote strings.Builder + var fake faker + fake.failOnClose = true + fake.ReadWriter = struct { + io.Reader + io.Writer + }{ + strings.NewReader(server), + &wrote, + } + _, err := NewClient(fake, "faker.host") + if err == nil { + t.Error("connection to non-available server should fail on close") + } + }) } -func TestAuthSCRAMSHA1PLUS_OK(t *testing.T) { - hostname := "127.0.0.1" - port := "2590" +func TestClient_hello(t *testing.T) { + t.Run("client fails on EHLO but not on HELO", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnEhlo: true, + FailOnHelo: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) - go func() { - startSMTPServer(true, hostname, port, sha1.New) - }() - time.Sleep(time.Millisecond * 500) - - cert, err := tls.X509KeyPair(localhostCert, localhostKey) - if err != nil { - fmt.Printf("error creating TLS cert: %s", err) - return - } - tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true} - - conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port), &tlsConfig) - if err != nil { - t.Errorf("failed to dial server: %v", err) - } - client, err := NewClient(conn, hostname) - if err != nil { - t.Errorf("failed to create client: %v", err) - } - if err = client.Hello(hostname); err != nil { - t.Errorf("failed to send HELO: %v", err) - } - - tlsConnState := conn.ConnectionState() - if err = client.Auth(ScramSHA1PlusAuth("username", "password", &tlsConnState)); err != nil { - t.Errorf("failed to authenticate: %v", err) - } + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Errorf("failed to dial to test server: %s", err) + } + if err = client.hello(); err == nil { + t.Error("helo should fail on test server") + } + }) } -func TestAuthSCRAMSHA256PLUS_OK(t *testing.T) { - hostname := "127.0.0.1" - port := "2591" +func TestClient_Hello(t *testing.T) { + t.Run("normal client HELO/EHLO", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) - go func() { - startSMTPServer(true, hostname, port, sha256.New) - }() - time.Sleep(time.Millisecond * 500) - - cert, err := tls.X509KeyPair(localhostCert, localhostKey) - if err != nil { - fmt.Printf("error creating TLS cert: %s", err) - return - } - tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true} - - conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port), &tlsConfig) - if err != nil { - t.Errorf("failed to dial server: %v", err) - } - client, err := NewClient(conn, hostname) - if err != nil { - t.Errorf("failed to create client: %v", err) - } - if err = client.Hello(hostname); err != nil { - t.Errorf("failed to send HELO: %v", err) - } - - tlsConnState := conn.ConnectionState() - if err = client.Auth(ScramSHA256PlusAuth("username", "password", &tlsConnState)); err != nil { - t.Errorf("failed to authenticate: %v", err) - } -} - -func TestAuthSCRAMSHA1_fail(t *testing.T) { - hostname := "127.0.0.1" - port := "2587" - - go func() { - startSMTPServer(false, hostname, port, sha1.New) - }() - time.Sleep(time.Millisecond * 500) - - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port)) - if err != nil { - t.Errorf("failed to dial server: %v", err) - } - client, err := NewClient(conn, hostname) - if err != nil { - t.Errorf("failed to create client: %v", err) - } - if err = client.Hello(hostname); err != nil { - t.Errorf("failed to send HELO: %v", err) - } - if err = client.Auth(ScramSHA1Auth("username", "invalid")); err == nil { - t.Errorf("expected auth error, got nil") - } -} - -func TestAuthSCRAMSHA256_fail(t *testing.T) { - hostname := "127.0.0.1" - port := "2588" - - go func() { - startSMTPServer(false, hostname, port, sha256.New) - }() - time.Sleep(time.Millisecond * 500) - - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port)) - if err != nil { - t.Errorf("failed to dial server: %v", err) - } - client, err := NewClient(conn, hostname) - if err != nil { - t.Errorf("failed to create client: %v", err) - } - if err = client.Hello(hostname); err != nil { - t.Errorf("failed to send HELO: %v", err) - } - if err = client.Auth(ScramSHA256Auth("username", "invalid")); err == nil { - t.Errorf("expected auth error, got nil") - } -} - -func TestAuthSCRAMSHA1PLUS_fail(t *testing.T) { - hostname := "127.0.0.1" - port := "2592" - - go func() { - startSMTPServer(true, hostname, port, sha1.New) - }() - time.Sleep(time.Millisecond * 500) - - cert, err := tls.X509KeyPair(localhostCert, localhostKey) - if err != nil { - fmt.Printf("error creating TLS cert: %s", err) - return - } - tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true} - - conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port), &tlsConfig) - if err != nil { - t.Errorf("failed to dial server: %v", err) - } - client, err := NewClient(conn, hostname) - if err != nil { - t.Errorf("failed to create client: %v", err) - } - if err = client.Hello(hostname); err != nil { - t.Errorf("failed to send HELO: %v", err) - } - tlsConnState := conn.ConnectionState() - if err = client.Auth(ScramSHA1PlusAuth("username", "invalid", &tlsConnState)); err == nil { - t.Errorf("expected auth error, got nil") - } -} - -func TestAuthSCRAMSHA256PLUS_fail(t *testing.T) { - hostname := "127.0.0.1" - port := "2593" - - go func() { - startSMTPServer(true, hostname, port, sha1.New) - }() - time.Sleep(time.Millisecond * 500) - - cert, err := tls.X509KeyPair(localhostCert, localhostKey) - if err != nil { - fmt.Printf("error creating TLS cert: %s", err) - return - } - tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true} - - conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port), &tlsConfig) - if err != nil { - t.Errorf("failed to dial server: %v", err) - } - client, err := NewClient(conn, hostname) - if err != nil { - t.Errorf("failed to create client: %v", err) - } - if err = client.Hello(hostname); err != nil { - t.Errorf("failed to send HELO: %v", err) - } - tlsConnState := conn.ConnectionState() - if err = client.Auth(ScramSHA256PlusAuth("username", "invalid", &tlsConnState)); err == nil { - t.Errorf("expected auth error, got nil") - } + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Errorf("failed to dial to test server: %s", err) + } + if err = client.Hello(TestServerAddr); err != nil { + t.Errorf("failed to send HELO/EHLO to test server: %s", err) + } + }) } // Issue 17794: don't send a trailing space on AUTH command when there's no password. -func TestClientAuthTrimSpace(t *testing.T) { +func TestClient_Auth_trimSpace(t *testing.T) { server := "220 hello world\r\n" + "200 some more" var wrote strings.Builder @@ -1655,7 +1607,7 @@ func TestClientAuthTrimSpace(t *testing.T) { c.tls = true c.didHello = true _ = c.Auth(toServerEmptyAuth{}) - if err := c.Close(); err != nil { + if err = c.Close(); err != nil { t.Errorf("close failed: %s", err) } if got, want := wrote.String(), "AUTH FOOAUTH\r\n*\r\nQUIT\r\n"; got != want { @@ -1663,19 +1615,15 @@ func TestClientAuthTrimSpace(t *testing.T) { } } -// toServerEmptyAuth is an implementation of Auth that only implements -// the Start method, and returns "FOOAUTH", nil, nil. Notably, it returns -// zero bytes for "toServer" so we can test that we don't send spaces at -// the end of the line. See TestClientAuthTrimSpace. -type toServerEmptyAuth struct{} +/* + + + + + + -func (toServerEmptyAuth) Start(_ *ServerInfo) (proto string, toServer []byte, err error) { - return "FOOAUTH", nil, nil -} -func (toServerEmptyAuth) Next(_ []byte, _ bool) (toServer []byte, err error) { - panic("unexpected call") -} func TestBasic(t *testing.T) { server := strings.Join(strings.Split(basicServer, "\n"), "\r\n") @@ -3158,46 +3106,6 @@ func sendMail(hostPort string) error { return SendMail(hostPort, nil, from, to, []byte("Subject: test\n\nhowdy!")) } -// localhostCert is a PEM-encoded TLS cert generated from src/crypto/tls: -// -// go run generate_cert.go --rsa-bits 1024 --host 127.0.0.1,::1,example.com \ -// --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h -var localhostCert = []byte(` ------BEGIN CERTIFICATE----- -MIICFDCCAX2gAwIBAgIRAK0xjnaPuNDSreeXb+z+0u4wDQYJKoZIhvcNAQELBQAw -EjEQMA4GA1UEChMHQWNtZSBDbzAgFw03MDAxMDEwMDAwMDBaGA8yMDg0MDEyOTE2 -MDAwMFowEjEQMA4GA1UEChMHQWNtZSBDbzCBnzANBgkqhkiG9w0BAQEFAAOBjQAw -gYkCgYEA0nFbQQuOWsjbGtejcpWz153OlziZM4bVjJ9jYruNw5n2Ry6uYQAffhqa -JOInCmmcVe2siJglsyH9aRh6vKiobBbIUXXUU1ABd56ebAzlt0LobLlx7pZEMy30 -LqIi9E6zmL3YvdGzpYlkFRnRrqwEtWYbGBf3znO250S56CCWH2UCAwEAAaNoMGYw -DgYDVR0PAQH/BAQDAgKkMBMGA1UdJQQMMAoGCCsGAQUFBwMBMA8GA1UdEwEB/wQF -MAMBAf8wLgYDVR0RBCcwJYILZXhhbXBsZS5jb22HBH8AAAGHEAAAAAAAAAAAAAAA -AAAAAAEwDQYJKoZIhvcNAQELBQADgYEAbZtDS2dVuBYvb+MnolWnCNqvw1w5Gtgi -NmvQQPOMgM3m+oQSCPRTNGSg25e1Qbo7bgQDv8ZTnq8FgOJ/rbkyERw2JckkHpD4 -n4qcK27WkEDBtQFlPihIM8hLIuzWoi/9wygiElTy/tVL3y7fGCvY2/k1KBthtZGF -tN8URjVmyEo= ------END CERTIFICATE-----`) - -// localhostKey is the private key for localhostCert. -var localhostKey = []byte(testingKey(` ------BEGIN RSA TESTING KEY----- -MIICXgIBAAKBgQDScVtBC45ayNsa16NylbPXnc6XOJkzhtWMn2Niu43DmfZHLq5h -AB9+Gpok4icKaZxV7ayImCWzIf1pGHq8qKhsFshRddRTUAF3np5sDOW3QuhsuXHu -lkQzLfQuoiL0TrOYvdi90bOliWQVGdGurAS1ZhsYF/fOc7bnRLnoIJYfZQIDAQAB -AoGBAMst7OgpKyFV6c3JwyI/jWqxDySL3caU+RuTTBaodKAUx2ZEmNJIlx9eudLA -kucHvoxsM/eRxlxkhdFxdBcwU6J+zqooTnhu/FE3jhrT1lPrbhfGhyKnUrB0KKMM -VY3IQZyiehpxaeXAwoAou6TbWoTpl9t8ImAqAMY8hlULCUqlAkEA+9+Ry5FSYK/m -542LujIcCaIGoG1/Te6Sxr3hsPagKC2rH20rDLqXwEedSFOpSS0vpzlPAzy/6Rbb -PHTJUhNdwwJBANXkA+TkMdbJI5do9/mn//U0LfrCR9NkcoYohxfKz8JuhgRQxzF2 -6jpo3q7CdTuuRixLWVfeJzcrAyNrVcBq87cCQFkTCtOMNC7fZnCTPUv+9q1tcJyB -vNjJu3yvoEZeIeuzouX9TJE21/33FaeDdsXbRhQEj23cqR38qFHsF1qAYNMCQQDP -QXLEiJoClkR2orAmqjPLVhR3t2oB3INcnEjLNSq8LHyQEfXyaFfu4U9l5+fRPL2i -jiC0k/9L5dHUsF0XZothAkEA23ddgRs+Id/HxtojqqUT27B8MT/IGNrYsp4DvS/c -qgkeluku4GjxRlDMBuXk94xOBEinUs+p/hwP1Alll80Tpg== ------END RSA TESTING KEY-----`)) - -func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") } - var flaky = flag.Bool("flaky", false, "run known-flaky tests too") func SkipFlaky(t testing.TB, issue int) { @@ -3207,271 +3115,22 @@ func SkipFlaky(t testing.TB, issue int) { } } -// testSCRAMSMTPServer represents a test server for SCRAM-based SMTP authentication. -// It does not do any acutal computation of the challenges but verifies that the expected -// fields are present. We have actual real authentication tests for all SCRAM modes in the -// go-mail client_test.go -type testSCRAMSMTPServer struct { - authMechanism string - nonce string - hostname string - port string - tlsServer bool - h func() hash.Hash -} - -func (s *testSCRAMSMTPServer) handleConnection(conn net.Conn) { - defer func() { - _ = conn.Close() - }() - - reader := bufio.NewReader(conn) - writer := bufio.NewWriter(conn) - writeLine := func(data string) error { - _, err := writer.WriteString(data + "\r\n") - if err != nil { - return fmt.Errorf("unable to write line: %w", err) - } - return writer.Flush() - } - writeOK := func() { - _ = writeLine("250 2.0.0 OK") - } - - if err := writeLine("220 go-mail test server ready ESMTP"); err != nil { - return - } - - data, err := reader.ReadString('\n') - if err != nil { - return - } - data = strings.TrimSpace(data) - if strings.HasPrefix(data, "EHLO") { - _ = writeLine(fmt.Sprintf("250-%s", s.hostname)) - _ = writeLine("250-AUTH SCRAM-SHA-1 SCRAM-SHA-256") - writeOK() - } else { - _ = writeLine("500 Invalid command") - return - } - - for { - data, err = reader.ReadString('\n') - if err != nil { - fmt.Printf("failed to read data: %v", err) - } - data = strings.TrimSpace(data) - if strings.HasPrefix(data, "AUTH") { - parts := strings.Split(data, " ") - if len(parts) < 2 { - _ = writeLine("500 Syntax error") - return - } - - authMechanism := parts[1] - if authMechanism != "SCRAM-SHA-1" && authMechanism != "SCRAM-SHA-256" && - authMechanism != "SCRAM-SHA-1-PLUS" && authMechanism != "SCRAM-SHA-256-PLUS" { - _ = writeLine("504 Unrecognized authentication mechanism") - return - } - s.authMechanism = authMechanism - _ = writeLine("334 ") - s.handleSCRAMAuth(conn) - return - } else { - _ = writeLine("500 Invalid command") - } - } -} - -func (s *testSCRAMSMTPServer) handleSCRAMAuth(conn net.Conn) { - reader := bufio.NewReader(conn) - writer := bufio.NewWriter(conn) - writeLine := func(data string) error { - _, err := writer.WriteString(data + "\r\n") - if err != nil { - return fmt.Errorf("unable to write line: %w", err) - } - return writer.Flush() - } - var authMsg string - - data, err := reader.ReadString('\n') - if err != nil { - _ = writeLine("535 Authentication failed") - return - } - data = strings.TrimSpace(data) - decodedMessage, err := base64.StdEncoding.DecodeString(data) - if err != nil { - _ = writeLine("535 Authentication failed") - return - } - splits := strings.Split(string(decodedMessage), ",") - if len(splits) != 4 { - _ = writeLine("535 Authentication failed - expected 4 parts") - return - } - if !s.tlsServer && splits[0] != "n" { - _ = writeLine("535 Authentication failed - expected n to be in the first part") - return - } - if s.tlsServer && !strings.HasPrefix(splits[0], "p=") { - _ = writeLine("535 Authentication failed - expected p= to be in the first part") - return - } - if splits[2] != "n=username" { - _ = writeLine("535 Authentication failed - expected n=username to be in the third part") - return - } - if !strings.HasPrefix(splits[3], "r=") { - _ = writeLine("535 Authentication failed - expected r= to be in the fourth part") - return - } - authMsg = splits[2] + "," + splits[3] - - clientNonce := s.extractNonce(string(decodedMessage)) - if clientNonce == "" { - _ = writeLine("535 Authentication failed") - return - } - - s.nonce = clientNonce + "server_nonce" - serverFirstMessage := fmt.Sprintf("r=%s,s=%s,i=4096", s.nonce, - base64.StdEncoding.EncodeToString([]byte("salt"))) - _ = writeLine(fmt.Sprintf("334 %s", base64.StdEncoding.EncodeToString([]byte(serverFirstMessage)))) - authMsg = authMsg + "," + serverFirstMessage - - data, err = reader.ReadString('\n') - if err != nil { - _ = writeLine("535 Authentication failed") - return - } - data = strings.TrimSpace(data) - decodedFinalMessage, err := base64.StdEncoding.DecodeString(data) - if err != nil { - _ = writeLine("535 Authentication failed") - return - } - splits = strings.Split(string(decodedFinalMessage), ",") - - if !s.tlsServer && splits[0] != "c=biws" { - _ = writeLine("535 Authentication failed - expected c=biws to be in the first part") - return - } - if s.tlsServer { - if !strings.HasPrefix(splits[0], "c=") { - _ = writeLine("535 Authentication failed - expected c= to be in the first part") - return - } - channelBind, err := base64.StdEncoding.DecodeString(splits[0][2:]) - if err != nil { - _ = writeLine("535 Authentication failed - base64 channel bind is not valid - " + err.Error()) - return - } - if !strings.HasPrefix(string(channelBind), "p=") { - _ = writeLine("535 Authentication failed - expected channel binding to start with p=-") - return - } - cbType := string(channelBind[2:]) - if !strings.HasPrefix(cbType, "tls-unique") && !strings.HasPrefix(cbType, "tls-exporter") { - _ = writeLine("535 Authentication failed - expected channel binding type tls-unique or tls-exporter") - return - } - } - - if !strings.HasPrefix(splits[1], "r=") { - _ = writeLine("535 Authentication failed - expected r to be in the second part") - return - } - if !strings.Contains(splits[1], "server_nonce") { - _ = writeLine("535 Authentication failed - expected server_nonce to be in the second part") - return - } - if !strings.HasPrefix(splits[2], "p=") { - _ = writeLine("535 Authentication failed - expected p to be in the third part") - return - } - - authMsg = authMsg + "," + splits[0] + "," + splits[1] - saltedPwd := pbkdf2.Key([]byte("password"), []byte("salt"), 4096, s.h().Size(), s.h) - mac := hmac.New(s.h, saltedPwd) - mac.Write([]byte("Server Key")) - skey := mac.Sum(nil) - mac.Reset() - - mac = hmac.New(s.h, skey) - mac.Write([]byte(authMsg)) - ssig := mac.Sum(nil) - mac.Reset() - - serverFinalMessage := fmt.Sprintf("v=%s", base64.StdEncoding.EncodeToString(ssig)) - _ = writeLine(fmt.Sprintf("334 %s", base64.StdEncoding.EncodeToString([]byte(serverFinalMessage)))) - - _, err = reader.ReadString('\n') - if err != nil { - _ = writeLine("535 Authentication failed") - return - } - - _ = writeLine("235 Authentication successful") -} - -func (s *testSCRAMSMTPServer) extractNonce(message string) string { - parts := strings.Split(message, ",") - for _, part := range parts { - if strings.HasPrefix(part, "r=") { - return part[2:] - } - } - return "" -} - -func startSMTPServer(tlsServer bool, hostname, port string, h func() hash.Hash) { - server := &testSCRAMSMTPServer{ - hostname: hostname, - port: port, - tlsServer: tlsServer, - h: h, - } - listener, err := net.Listen("tcp", fmt.Sprintf("%s:%s", hostname, port)) - if err != nil { - fmt.Printf("Failed to start SMTP server: %v", err) - } - defer func() { - _ = listener.Close() - }() - - cert, err := tls.X509KeyPair(localhostCert, localhostKey) - if err != nil { - fmt.Printf("error creating TLS cert: %s", err) - return - } - tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}} - - for { - conn, err := listener.Accept() - if err != nil { - fmt.Printf("Failed to accept connection: %v", err) - continue - } - if server.tlsServer { - conn = tls.Server(conn, &tlsConfig) - } - go server.handleConnection(conn) - } -} - */ // faker is a struct embedding io.ReadWriter to simulate network connections for testing purposes. type faker struct { io.ReadWriter + failOnRead bool + failOnClose bool } -func (f faker) Close() error { return nil } +func (f faker) Close() error { + if f.failOnClose { + return fmt.Errorf("faker: failed to close connection") + } + return nil +} func (f faker) LocalAddr() net.Addr { return nil } func (f faker) RemoteAddr() net.Addr { return nil } func (f faker) SetDeadline(time.Time) error { return nil } @@ -3486,6 +3145,8 @@ type serverProps struct { FailOnAuth bool FailOnDataInit bool FailOnDataClose bool + FailOnDial bool + FailOnEhlo bool FailOnHelo bool FailOnMailFrom bool FailOnNoop bool @@ -3582,6 +3243,10 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server } if !props.IsTLS { + if props.FailOnDial { + writeLine("421 4.4.1 Service not available") + return + } writeLine("220 go-mail test server ready ESMTP") } @@ -3595,9 +3260,9 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server var datastring string data = strings.TrimSpace(data) switch { - case strings.HasPrefix(data, "EHLO"), strings.HasPrefix(data, "HELO"): + case strings.HasPrefix(data, "HELO"): if len(strings.Split(data, " ")) != 2 { - writeLine("501 Syntax: EHLO hostname") + writeLine("501 Syntax: HELO hostname") break } if props.FailOnHelo { @@ -3605,6 +3270,16 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server break } writeLine("250-localhost.localdomain\r\n" + props.FeatureSet) + case strings.HasPrefix(data, "EHLO"): + if len(strings.Split(data, " ")) != 2 { + writeLine("501 Syntax: EHLO hostname") + break + } + if props.FailOnEhlo { + writeLine("500 5.5.2 Error: fail on EHLO") + break + } + writeLine("250-localhost.localdomain\r\n" + props.FeatureSet) case strings.HasPrefix(data, "MAIL FROM:"): if props.FailOnMailFrom { writeLine("500 5.5.2 Error: fail on MAIL FROM") @@ -3886,3 +3561,17 @@ type randReader struct{} func (r *randReader) Read([]byte) (int, error) { return 0, errors.New("broken reader") } + +// toServerEmptyAuth is an implementation of Auth that only implements +// the Start method, and returns "FOOAUTH", nil, nil. Notably, it returns +// zero bytes for "toServer" so we can test that we don't send spaces at +// the end of the line. See TestClientAuthTrimSpace. +type toServerEmptyAuth struct{} + +func (toServerEmptyAuth) Start(_ *ServerInfo) (proto string, toServer []byte, err error) { + return "FOOAUTH", nil, nil +} + +func (toServerEmptyAuth) Next(_ []byte, _ bool) (toServer []byte, err error) { + return nil, fmt.Errorf("unexpected call") +}