From 580981b15881f08afa832021364e35b9e4c8aa3b Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Wed, 2 Oct 2024 18:02:23 +0200 Subject: [PATCH 1/5] Refactor error handling in SMTP authentication Centralized error definitions in `smtp/auth.go` and updated references in `auth_login.go` and `auth_plain.go`. This improves code maintainability and error consistency across the package. --- smtp/auth.go | 13 +++++++++++++ smtp/auth_login.go | 8 ++------ smtp/auth_plain.go | 10 +++------- 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/smtp/auth.go b/smtp/auth.go index 30948e1..a62e74d 100644 --- a/smtp/auth.go +++ b/smtp/auth.go @@ -13,6 +13,19 @@ package smtp +import "errors" + +var ( + // ErrUnencrypted is an error indicating that the connection is not encrypted. + ErrUnencrypted = errors.New("unencrypted connection") + // ErrUnexpectedServerChallange is an error indicating that the server issued an unexpected challenge. + ErrUnexpectedServerChallange = errors.New("unexpected server challenge") + // ErrUnexpectedServerResponse is an error indicating that the server issued an unexpected response. + ErrUnexpectedServerResponse = errors.New("unexpected server response") + // ErrWrongHostname is an error indicating that the provided hostname does not match the expected value. + ErrWrongHostname = errors.New("wrong host name") +) + // Auth is implemented by an SMTP authentication mechanism. type Auth interface { // Start begins an authentication with a server. diff --git a/smtp/auth_login.go b/smtp/auth_login.go index 715861c..847ad62 100644 --- a/smtp/auth_login.go +++ b/smtp/auth_login.go @@ -5,13 +5,9 @@ package smtp import ( - "errors" "fmt" ) -// ErrUnencrypted is an error indicating that the connection is not encrypted. -var ErrUnencrypted = errors.New("unencrypted connection") - // loginAuth is the type that satisfies the Auth interface for the "SMTP LOGIN" auth type loginAuth struct { username, password string @@ -55,7 +51,7 @@ func (a *loginAuth) Start(server *ServerInfo) (string, []byte, error) { return "", nil, ErrUnencrypted } if server.Name != a.host { - return "", nil, errors.New("wrong host name") + return "", nil, ErrWrongHostname } a.respStep = 0 return "LOGIN", nil, nil @@ -73,7 +69,7 @@ func (a *loginAuth) Next(fromServer []byte, more bool) ([]byte, error) { a.respStep++ return []byte(a.password), nil default: - return nil, fmt.Errorf("unexpected server response: %s", string(fromServer)) + return nil, fmt.Errorf("%w: %s", ErrUnexpectedServerResponse, string(fromServer)) } } return nil, nil diff --git a/smtp/auth_plain.go b/smtp/auth_plain.go index 2430c96..e6e0ad9 100644 --- a/smtp/auth_plain.go +++ b/smtp/auth_plain.go @@ -13,10 +13,6 @@ package smtp -import ( - "errors" -) - // plainAuth is the type that satisfies the Auth interface for the "SMTP PLAIN" auth type plainAuth struct { identity, username, password string @@ -42,10 +38,10 @@ func (a *plainAuth) Start(server *ServerInfo) (string, []byte, error) { // That might just be the attacker saying // "it's ok, you can trust me with your password." if !server.TLS && !isLocalhost(server.Name) { - return "", nil, errors.New("unencrypted connection") + return "", nil, ErrUnencrypted } if server.Name != a.host { - return "", nil, errors.New("wrong host name") + return "", nil, ErrWrongHostname } resp := []byte(a.identity + "\x00" + a.username + "\x00" + a.password) return "PLAIN", resp, nil @@ -54,7 +50,7 @@ func (a *plainAuth) Start(server *ServerInfo) (string, []byte, error) { func (a *plainAuth) Next(_ []byte, more bool) ([]byte, error) { if more { // We've already sent everything. - return nil, errors.New("unexpected server challenge") + return nil, ErrUnexpectedServerChallange } return nil, nil } From e4dd62475a2acacd1a431cd1feff866231b225ee Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Wed, 2 Oct 2024 18:02:34 +0200 Subject: [PATCH 2/5] Improve error handling in SCRAM-SHA-X-PLUS authentication Refactor error return to include more specific information and add a check for TLS connection state in SCRAM-SHA-X-PLUS authentication flow. This ensures clearer error messages and verifies essential prerequisites for secure authentication. --- smtp/auth_scram.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/smtp/auth_scram.go b/smtp/auth_scram.go index c70b210..a21aef5 100644 --- a/smtp/auth_scram.go +++ b/smtp/auth_scram.go @@ -112,7 +112,7 @@ func (a *scramAuth) Next(fromServer []byte, more bool) ([]byte, error) { return resp, nil default: a.reset() - return nil, errors.New("unexpected server response") + return nil, fmt.Errorf("%w: %s", ErrUnexpectedServerResponse, string(fromServer)) } } return nil, nil @@ -147,6 +147,9 @@ func (a *scramAuth) initialClientMessage() ([]byte, error) { // SCRAM-SHA-X-PLUS auth requires channel binding if a.isPlus { + if a.tlsConnState == nil { + return nil, errors.New("tls connection state is required for SCRAM-SHA-X-PLUS") + } bindType := "tls-unique" connState := a.tlsConnState bindData := connState.TLSUnique From a8e89a125829f3643fdbe45140437856e15b206a Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Wed, 2 Oct 2024 18:02:46 +0200 Subject: [PATCH 3/5] Add support for SCRAM-SHA authentication mechanisms Introduced new test cases for SCRAM-SHA-1, SCRAM-SHA-256, and their PLUS variants in `smtp_test.go`. Updated the authTest structure to include a `hasNonce` flag and implemented logic to handle nonce validation and success message processing. --- smtp/smtp_test.go | 67 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index d5b02a7..1848b5c 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -38,6 +38,7 @@ type authTest struct { name string responses []string sf []bool + hasNonce bool } var authTests = []authTest{ @@ -47,6 +48,7 @@ var authTests = []authTest{ "PLAIN", []string{"\x00user\x00pass"}, []bool{false, false}, + false, }, { PlainAuth("foo", "bar", "baz", "testserver"), @@ -54,6 +56,15 @@ var authTests = []authTest{ "PLAIN", []string{"foo\x00bar\x00baz"}, []bool{false, false}, + false, + }, + { + PlainAuth("foo", "bar", "baz", "testserver"), + []string{"foo"}, + "PLAIN", + []string{"foo\x00bar\x00baz", ""}, + []bool{true}, + false, }, { LoginAuth("user", "pass", "testserver"), @@ -61,6 +72,7 @@ var authTests = []authTest{ "LOGIN", []string{"", "user", "pass"}, []bool{false, false}, + false, }, { LoginAuth("user", "pass", "testserver"), @@ -68,6 +80,7 @@ var authTests = []authTest{ "LOGIN", []string{"", "user", "pass"}, []bool{false, false}, + false, }, { LoginAuth("user", "pass", "testserver"), @@ -75,6 +88,7 @@ var authTests = []authTest{ "LOGIN", []string{"", "user", "pass"}, []bool{false, false}, + false, }, { LoginAuth("user", "pass", "testserver"), @@ -82,6 +96,7 @@ var authTests = []authTest{ "LOGIN", []string{"", "user", "pass", ""}, []bool{false, false, true}, + false, }, { CRAMMD5Auth("user", "pass"), @@ -89,6 +104,7 @@ var authTests = []authTest{ "CRAM-MD5", []string{"", "user 287eb355114cf5c471c26a875f1ca4ae"}, []bool{false, false}, + false, }, { XOAuth2Auth("username", "token"), @@ -96,6 +112,47 @@ var authTests = []authTest{ "XOAUTH2", []string{"user=username\x01auth=Bearer token\x01\x01", ""}, []bool{false}, + false, + }, + { + ScramSHA1Auth("username", "password"), + []string{"", "r=foo"}, + "SCRAM-SHA-1", + []string{"", "n,,n=username,r=", ""}, + []bool{false, true}, + true, + }, + { + ScramSHA1Auth("username", "password"), + []string{"", "v=foo"}, + "SCRAM-SHA-1", + []string{"", "n,,n=username,r=", ""}, + []bool{false, true}, + true, + }, + { + ScramSHA256Auth("username", "password"), + []string{""}, + "SCRAM-SHA-256", + []string{"", "n,,n=username,r=", ""}, + []bool{false}, + true, + }, + { + ScramSHA1PlusAuth("username", "password", nil), + []string{""}, + "SCRAM-SHA-1-PLUS", + []string{"", "", ""}, + []bool{true}, + true, + }, + { + ScramSHA256PlusAuth("username", "password", nil), + []string{""}, + "SCRAM-SHA-256-PLUS", + []string{"", "", ""}, + []bool{true}, + true, }, } @@ -121,10 +178,20 @@ testLoop: t.Errorf("#%d error: %s", i, err) continue testLoop } + if test.hasNonce { + if !bytes.HasPrefix(resp, expected) { + t.Errorf("#%d got response: %s, expected response to start with: %s", i, resp, expected) + } + continue testLoop + } if !bytes.Equal(resp, expected) { t.Errorf("#%d got %s, expected %s", i, resp, expected) continue testLoop } + _, err = test.auth.Next([]byte("2.7.0 Authentication successful"), false) + if err != nil { + t.Errorf("#%d success message error: %s", i, err) + } } } } From 03062c5183b71f247820ab8e6b0edcb1a823e47b Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Thu, 3 Oct 2024 12:32:06 +0200 Subject: [PATCH 4/5] Add SCRAM-SHA authentication tests for SMTP Introduce new unit tests to verify SCRAM-SHA-1 and SCRAM-SHA-256 authentication for the SMTP client. These tests cover both successful and failing authentication cases, and include a mock SMTP server to facilitate testing. --- smtp/smtp_test.go | 288 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 288 insertions(+) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index 1848b5c..dc8dbdb 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -18,6 +18,7 @@ import ( "bytes" "crypto/tls" "crypto/x509" + "encoding/base64" "flag" "fmt" "io" @@ -368,6 +369,106 @@ func TestXOAuth2Error(t *testing.T) { } } +func TestAuthSCRAMSHA1_OK(t *testing.T) { + hostname := "127.0.0.1" + port := "2585" + + go func() { + startSMTPServer(false, hostname, port) + }() + time.Sleep(time.Millisecond * 500) + + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port)) + if err != nil { + t.Errorf("failed to dial server: %v", err) + } + client, err := NewClient(conn, hostname) + if err != nil { + t.Errorf("failed to create client: %v", err) + } + if err = client.Hello(hostname); err != nil { + t.Errorf("failed to send HELO: %v", err) + } + if err = client.Auth(ScramSHA1Auth("username", "password")); err != nil { + t.Errorf("failed to authenticate: %v", err) + } +} + +func TestAuthSCRAMSHA256_OK(t *testing.T) { + hostname := "127.0.0.1" + port := "2586" + + go func() { + startSMTPServer(false, hostname, port) + }() + time.Sleep(time.Millisecond * 500) + + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port)) + if err != nil { + t.Errorf("failed to dial server: %v", err) + } + client, err := NewClient(conn, hostname) + if err != nil { + t.Errorf("failed to create client: %v", err) + } + if err = client.Hello(hostname); err != nil { + t.Errorf("failed to send HELO: %v", err) + } + if err = client.Auth(ScramSHA256Auth("username", "password")); err != nil { + t.Errorf("failed to authenticate: %v", err) + } +} + +func TestAuthSCRAMSHA1_fail(t *testing.T) { + hostname := "127.0.0.1" + port := "2587" + + go func() { + startSMTPServer(true, hostname, port) + }() + time.Sleep(time.Millisecond * 500) + + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port)) + if err != nil { + t.Errorf("failed to dial server: %v", err) + } + client, err := NewClient(conn, hostname) + if err != nil { + t.Errorf("failed to create client: %v", err) + } + if err = client.Hello(hostname); err != nil { + t.Errorf("failed to send HELO: %v", err) + } + if err = client.Auth(ScramSHA1Auth("username", "password")); err == nil { + t.Errorf("expected auth error, got nil") + } +} + +func TestAuthSCRAMSHA256_fail(t *testing.T) { + hostname := "127.0.0.1" + port := "2588" + + go func() { + startSMTPServer(true, hostname, port) + }() + time.Sleep(time.Millisecond * 500) + + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port)) + if err != nil { + t.Errorf("failed to dial server: %v", err) + } + client, err := NewClient(conn, hostname) + if err != nil { + t.Errorf("failed to create client: %v", err) + } + if err = client.Hello(hostname); err != nil { + t.Errorf("failed to send HELO: %v", err) + } + if err = client.Auth(ScramSHA256Auth("username", "password")); err == nil { + t.Errorf("expected auth error, got nil") + } +} + // Issue 17794: don't send a trailing space on AUTH command when there's no password. func TestClientAuthTrimSpace(t *testing.T) { server := "220 hello world\r\n" + @@ -1541,3 +1642,190 @@ func SkipFlaky(t testing.TB, issue int) { t.Skipf("skipping known flaky test without the -flaky flag; see golang.org/issue/%d", issue) } } + +// testSCRAMSMTPServer represents a test server for SCRAM-based SMTP authentication. +// It does not do any acutal computation of the challanges but verifies that the expected +// fields are present. We have actual real authentication tests for all SCRAM modes in the +// go-mail client_test.go +type testSCRAMSMTPServer struct { + authMechanism string + nonce string + hostname string + port string + shouldFail bool +} + +func (s *testSCRAMSMTPServer) handleConnection(conn net.Conn) { + defer func() { + _ = conn.Close() + }() + + reader := bufio.NewReader(conn) + writer := bufio.NewWriter(conn) + writeLine := func(data string) error { + _, err := writer.WriteString(data + "\r\n") + if err != nil { + return fmt.Errorf("unable to write line: %w", err) + } + return writer.Flush() + } + writeOK := func() { + _ = writeLine("250 2.0.0 OK") + } + + if err := writeLine("220 go-mail test server ready ESMTP"); err != nil { + return + } + + data, err := reader.ReadString('\n') + if err != nil { + return + } + data = strings.TrimSpace(data) + if strings.HasPrefix(data, "EHLO") { + _ = writeLine(fmt.Sprintf("250-%s", s.hostname)) + _ = writeLine("250-AUTH SCRAM-SHA-1 SCRAM-SHA-256") + writeOK() + } else { + _ = writeLine("500 Invalid command") + return + } + + for { + data, err = reader.ReadString('\n') + if err != nil { + fmt.Printf("failed to read data: %v", err) + } + data = strings.TrimSpace(data) + if strings.HasPrefix(data, "AUTH") { + parts := strings.Split(data, " ") + if len(parts) < 2 { + _ = writeLine("500 Syntax error") + return + } + + authMechanism := parts[1] + if authMechanism != "SCRAM-SHA-1" && authMechanism != "SCRAM-SHA-256" { + _ = writeLine("504 Unrecognized authentication mechanism") + return + } + s.authMechanism = authMechanism + _ = writeLine("334 ") + s.handleSCRAMAuth(conn) + return + } else { + _ = writeLine("500 Invalid command") + } + } +} + +func (s *testSCRAMSMTPServer) handleSCRAMAuth(conn net.Conn) { + reader := bufio.NewReader(conn) + writer := bufio.NewWriter(conn) + writeLine := func(data string) error { + _, err := writer.WriteString(data + "\r\n") + if err != nil { + return fmt.Errorf("unable to write line: %w", err) + } + return writer.Flush() + } + + data, err := reader.ReadString('\n') + clientMessage := strings.TrimSpace(data) + decodedMessage, err := base64.StdEncoding.DecodeString(clientMessage) + if err != nil { + _ = writeLine("535 Authentication failed") + return + } + splits := strings.Split(string(decodedMessage), ",") + if len(splits) != 4 { + _ = writeLine("535 Authentication failed - expected 4 parts") + return + } + if splits[0] != "n" { + _ = writeLine("535 Authentication failed - expected n to be in the first part") + return + } + if splits[2] != "n=username" { + _ = writeLine("535 Authentication failed - expected n=username to be in the third part") + return + } + if !strings.HasPrefix(splits[3], "r=") { + _ = writeLine("535 Authentication failed - expected r= to be in the fourth part") + return + } + clientNonce := s.extractNonce(string(decodedMessage)) + if clientNonce == "" { + _ = writeLine("535 Authentication failed") + return + } + + s.nonce = clientNonce + "server_nonce" + serverFirstMessage := fmt.Sprintf("r=%s,s=%s,i=0", s.nonce, "salt") + _ = writeLine(fmt.Sprintf("334 %s", base64.StdEncoding.EncodeToString([]byte(serverFirstMessage)))) + + data, err = reader.ReadString('\n') + clientFinalMessage := strings.TrimSpace(data) + decodedFinalMessage, err := base64.StdEncoding.DecodeString(clientFinalMessage) + if err != nil { + _ = writeLine("535 Authentication failed") + return + } + splits = strings.Split(string(decodedFinalMessage), ",") + if splits[0] != "c=biws" { + _ = writeLine("535 Authentication failed - expected c=biws to be in the first part") + return + } + if !strings.HasPrefix(splits[1], "r=") { + _ = writeLine("535 Authentication failed - expected r to be in the second part") + return + } + if !strings.Contains(splits[1], "server_nonce") { + _ = writeLine("535 Authentication failed - expected server_nonce to be in the second part") + return + } + if !strings.HasPrefix(splits[2], "p=") { + _ = writeLine("535 Authentication failed - expected p to be in the third part") + return + } + + if s.shouldFail { + _ = writeLine("535 Authentication failed") + return + } + _ = writeLine("235 Authentication successful") + return +} + +func (s *testSCRAMSMTPServer) extractNonce(message string) string { + parts := strings.Split(message, ",") + for _, part := range parts { + if strings.HasPrefix(part, "r=") { + return part[2:] + } + } + return "" +} + +func startSMTPServer(shouldFail bool, hostname, port string) { + server := &testSCRAMSMTPServer{ + hostname: hostname, + port: port, + shouldFail: shouldFail, + } + listener, err := net.Listen("tcp", fmt.Sprintf("%s:%s", hostname, port)) + if err != nil { + fmt.Printf("Failed to start SMTP server: %v", err) + } + defer func() { + _ = listener.Close() + }() + for { + conn, err := listener.Accept() + if err != nil { + fmt.Printf("Failed to accept connection: %v", err) + continue + } + go server.handleConnection(conn) + } +} From 4c8c0d855e206ea3135960f8fc403e093763205d Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Thu, 3 Oct 2024 12:38:39 +0200 Subject: [PATCH 5/5] Handle read errors in SMTP authentication flow Add checks to handle errors when reading client messages. This ensures that an appropriate error message is sent back to the client if reading fails, improving the robustness of the SMTP authentication process. --- smtp/smtp_test.go | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index dc8dbdb..0d47760 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -1731,8 +1731,12 @@ func (s *testSCRAMSMTPServer) handleSCRAMAuth(conn net.Conn) { } data, err := reader.ReadString('\n') - clientMessage := strings.TrimSpace(data) - decodedMessage, err := base64.StdEncoding.DecodeString(clientMessage) + if err != nil { + _ = writeLine("535 Authentication failed") + return + } + data = strings.TrimSpace(data) + decodedMessage, err := base64.StdEncoding.DecodeString(data) if err != nil { _ = writeLine("535 Authentication failed") return @@ -1765,8 +1769,12 @@ func (s *testSCRAMSMTPServer) handleSCRAMAuth(conn net.Conn) { _ = writeLine(fmt.Sprintf("334 %s", base64.StdEncoding.EncodeToString([]byte(serverFirstMessage)))) data, err = reader.ReadString('\n') - clientFinalMessage := strings.TrimSpace(data) - decodedFinalMessage, err := base64.StdEncoding.DecodeString(clientFinalMessage) + if err != nil { + _ = writeLine("535 Authentication failed") + return + } + data = strings.TrimSpace(data) + decodedFinalMessage, err := base64.StdEncoding.DecodeString(data) if err != nil { _ = writeLine("535 Authentication failed") return @@ -1794,7 +1802,6 @@ func (s *testSCRAMSMTPServer) handleSCRAMAuth(conn net.Conn) { return } _ = writeLine("235 Authentication successful") - return } func (s *testSCRAMSMTPServer) extractNonce(message string) string {