diff --git a/smtp/auth.go b/smtp/auth.go index 30948e1..a62e74d 100644 --- a/smtp/auth.go +++ b/smtp/auth.go @@ -13,6 +13,19 @@ package smtp +import "errors" + +var ( + // ErrUnencrypted is an error indicating that the connection is not encrypted. + ErrUnencrypted = errors.New("unencrypted connection") + // ErrUnexpectedServerChallange is an error indicating that the server issued an unexpected challenge. + ErrUnexpectedServerChallange = errors.New("unexpected server challenge") + // ErrUnexpectedServerResponse is an error indicating that the server issued an unexpected response. + ErrUnexpectedServerResponse = errors.New("unexpected server response") + // ErrWrongHostname is an error indicating that the provided hostname does not match the expected value. + ErrWrongHostname = errors.New("wrong host name") +) + // Auth is implemented by an SMTP authentication mechanism. type Auth interface { // Start begins an authentication with a server. diff --git a/smtp/auth_login.go b/smtp/auth_login.go index 715861c..847ad62 100644 --- a/smtp/auth_login.go +++ b/smtp/auth_login.go @@ -5,13 +5,9 @@ package smtp import ( - "errors" "fmt" ) -// ErrUnencrypted is an error indicating that the connection is not encrypted. -var ErrUnencrypted = errors.New("unencrypted connection") - // loginAuth is the type that satisfies the Auth interface for the "SMTP LOGIN" auth type loginAuth struct { username, password string @@ -55,7 +51,7 @@ func (a *loginAuth) Start(server *ServerInfo) (string, []byte, error) { return "", nil, ErrUnencrypted } if server.Name != a.host { - return "", nil, errors.New("wrong host name") + return "", nil, ErrWrongHostname } a.respStep = 0 return "LOGIN", nil, nil @@ -73,7 +69,7 @@ func (a *loginAuth) Next(fromServer []byte, more bool) ([]byte, error) { a.respStep++ return []byte(a.password), nil default: - return nil, fmt.Errorf("unexpected server response: %s", string(fromServer)) + return nil, fmt.Errorf("%w: %s", ErrUnexpectedServerResponse, string(fromServer)) } } return nil, nil diff --git a/smtp/auth_plain.go b/smtp/auth_plain.go index 2430c96..e6e0ad9 100644 --- a/smtp/auth_plain.go +++ b/smtp/auth_plain.go @@ -13,10 +13,6 @@ package smtp -import ( - "errors" -) - // plainAuth is the type that satisfies the Auth interface for the "SMTP PLAIN" auth type plainAuth struct { identity, username, password string @@ -42,10 +38,10 @@ func (a *plainAuth) Start(server *ServerInfo) (string, []byte, error) { // That might just be the attacker saying // "it's ok, you can trust me with your password." if !server.TLS && !isLocalhost(server.Name) { - return "", nil, errors.New("unencrypted connection") + return "", nil, ErrUnencrypted } if server.Name != a.host { - return "", nil, errors.New("wrong host name") + return "", nil, ErrWrongHostname } resp := []byte(a.identity + "\x00" + a.username + "\x00" + a.password) return "PLAIN", resp, nil @@ -54,7 +50,7 @@ func (a *plainAuth) Start(server *ServerInfo) (string, []byte, error) { func (a *plainAuth) Next(_ []byte, more bool) ([]byte, error) { if more { // We've already sent everything. - return nil, errors.New("unexpected server challenge") + return nil, ErrUnexpectedServerChallange } return nil, nil } diff --git a/smtp/auth_scram.go b/smtp/auth_scram.go index c70b210..a21aef5 100644 --- a/smtp/auth_scram.go +++ b/smtp/auth_scram.go @@ -112,7 +112,7 @@ func (a *scramAuth) Next(fromServer []byte, more bool) ([]byte, error) { return resp, nil default: a.reset() - return nil, errors.New("unexpected server response") + return nil, fmt.Errorf("%w: %s", ErrUnexpectedServerResponse, string(fromServer)) } } return nil, nil @@ -147,6 +147,9 @@ func (a *scramAuth) initialClientMessage() ([]byte, error) { // SCRAM-SHA-X-PLUS auth requires channel binding if a.isPlus { + if a.tlsConnState == nil { + return nil, errors.New("tls connection state is required for SCRAM-SHA-X-PLUS") + } bindType := "tls-unique" connState := a.tlsConnState bindData := connState.TLSUnique diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index d5b02a7..0d47760 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -18,6 +18,7 @@ import ( "bytes" "crypto/tls" "crypto/x509" + "encoding/base64" "flag" "fmt" "io" @@ -38,6 +39,7 @@ type authTest struct { name string responses []string sf []bool + hasNonce bool } var authTests = []authTest{ @@ -47,6 +49,7 @@ var authTests = []authTest{ "PLAIN", []string{"\x00user\x00pass"}, []bool{false, false}, + false, }, { PlainAuth("foo", "bar", "baz", "testserver"), @@ -54,6 +57,15 @@ var authTests = []authTest{ "PLAIN", []string{"foo\x00bar\x00baz"}, []bool{false, false}, + false, + }, + { + PlainAuth("foo", "bar", "baz", "testserver"), + []string{"foo"}, + "PLAIN", + []string{"foo\x00bar\x00baz", ""}, + []bool{true}, + false, }, { LoginAuth("user", "pass", "testserver"), @@ -61,6 +73,7 @@ var authTests = []authTest{ "LOGIN", []string{"", "user", "pass"}, []bool{false, false}, + false, }, { LoginAuth("user", "pass", "testserver"), @@ -68,6 +81,7 @@ var authTests = []authTest{ "LOGIN", []string{"", "user", "pass"}, []bool{false, false}, + false, }, { LoginAuth("user", "pass", "testserver"), @@ -75,6 +89,7 @@ var authTests = []authTest{ "LOGIN", []string{"", "user", "pass"}, []bool{false, false}, + false, }, { LoginAuth("user", "pass", "testserver"), @@ -82,6 +97,7 @@ var authTests = []authTest{ "LOGIN", []string{"", "user", "pass", ""}, []bool{false, false, true}, + false, }, { CRAMMD5Auth("user", "pass"), @@ -89,6 +105,7 @@ var authTests = []authTest{ "CRAM-MD5", []string{"", "user 287eb355114cf5c471c26a875f1ca4ae"}, []bool{false, false}, + false, }, { XOAuth2Auth("username", "token"), @@ -96,6 +113,47 @@ var authTests = []authTest{ "XOAUTH2", []string{"user=username\x01auth=Bearer token\x01\x01", ""}, []bool{false}, + false, + }, + { + ScramSHA1Auth("username", "password"), + []string{"", "r=foo"}, + "SCRAM-SHA-1", + []string{"", "n,,n=username,r=", ""}, + []bool{false, true}, + true, + }, + { + ScramSHA1Auth("username", "password"), + []string{"", "v=foo"}, + "SCRAM-SHA-1", + []string{"", "n,,n=username,r=", ""}, + []bool{false, true}, + true, + }, + { + ScramSHA256Auth("username", "password"), + []string{""}, + "SCRAM-SHA-256", + []string{"", "n,,n=username,r=", ""}, + []bool{false}, + true, + }, + { + ScramSHA1PlusAuth("username", "password", nil), + []string{""}, + "SCRAM-SHA-1-PLUS", + []string{"", "", ""}, + []bool{true}, + true, + }, + { + ScramSHA256PlusAuth("username", "password", nil), + []string{""}, + "SCRAM-SHA-256-PLUS", + []string{"", "", ""}, + []bool{true}, + true, }, } @@ -121,10 +179,20 @@ testLoop: t.Errorf("#%d error: %s", i, err) continue testLoop } + if test.hasNonce { + if !bytes.HasPrefix(resp, expected) { + t.Errorf("#%d got response: %s, expected response to start with: %s", i, resp, expected) + } + continue testLoop + } if !bytes.Equal(resp, expected) { t.Errorf("#%d got %s, expected %s", i, resp, expected) continue testLoop } + _, err = test.auth.Next([]byte("2.7.0 Authentication successful"), false) + if err != nil { + t.Errorf("#%d success message error: %s", i, err) + } } } } @@ -301,6 +369,106 @@ func TestXOAuth2Error(t *testing.T) { } } +func TestAuthSCRAMSHA1_OK(t *testing.T) { + hostname := "127.0.0.1" + port := "2585" + + go func() { + startSMTPServer(false, hostname, port) + }() + time.Sleep(time.Millisecond * 500) + + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port)) + if err != nil { + t.Errorf("failed to dial server: %v", err) + } + client, err := NewClient(conn, hostname) + if err != nil { + t.Errorf("failed to create client: %v", err) + } + if err = client.Hello(hostname); err != nil { + t.Errorf("failed to send HELO: %v", err) + } + if err = client.Auth(ScramSHA1Auth("username", "password")); err != nil { + t.Errorf("failed to authenticate: %v", err) + } +} + +func TestAuthSCRAMSHA256_OK(t *testing.T) { + hostname := "127.0.0.1" + port := "2586" + + go func() { + startSMTPServer(false, hostname, port) + }() + time.Sleep(time.Millisecond * 500) + + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port)) + if err != nil { + t.Errorf("failed to dial server: %v", err) + } + client, err := NewClient(conn, hostname) + if err != nil { + t.Errorf("failed to create client: %v", err) + } + if err = client.Hello(hostname); err != nil { + t.Errorf("failed to send HELO: %v", err) + } + if err = client.Auth(ScramSHA256Auth("username", "password")); err != nil { + t.Errorf("failed to authenticate: %v", err) + } +} + +func TestAuthSCRAMSHA1_fail(t *testing.T) { + hostname := "127.0.0.1" + port := "2587" + + go func() { + startSMTPServer(true, hostname, port) + }() + time.Sleep(time.Millisecond * 500) + + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port)) + if err != nil { + t.Errorf("failed to dial server: %v", err) + } + client, err := NewClient(conn, hostname) + if err != nil { + t.Errorf("failed to create client: %v", err) + } + if err = client.Hello(hostname); err != nil { + t.Errorf("failed to send HELO: %v", err) + } + if err = client.Auth(ScramSHA1Auth("username", "password")); err == nil { + t.Errorf("expected auth error, got nil") + } +} + +func TestAuthSCRAMSHA256_fail(t *testing.T) { + hostname := "127.0.0.1" + port := "2588" + + go func() { + startSMTPServer(true, hostname, port) + }() + time.Sleep(time.Millisecond * 500) + + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port)) + if err != nil { + t.Errorf("failed to dial server: %v", err) + } + client, err := NewClient(conn, hostname) + if err != nil { + t.Errorf("failed to create client: %v", err) + } + if err = client.Hello(hostname); err != nil { + t.Errorf("failed to send HELO: %v", err) + } + if err = client.Auth(ScramSHA256Auth("username", "password")); err == nil { + t.Errorf("expected auth error, got nil") + } +} + // Issue 17794: don't send a trailing space on AUTH command when there's no password. func TestClientAuthTrimSpace(t *testing.T) { server := "220 hello world\r\n" + @@ -1474,3 +1642,197 @@ func SkipFlaky(t testing.TB, issue int) { t.Skipf("skipping known flaky test without the -flaky flag; see golang.org/issue/%d", issue) } } + +// testSCRAMSMTPServer represents a test server for SCRAM-based SMTP authentication. +// It does not do any acutal computation of the challanges but verifies that the expected +// fields are present. We have actual real authentication tests for all SCRAM modes in the +// go-mail client_test.go +type testSCRAMSMTPServer struct { + authMechanism string + nonce string + hostname string + port string + shouldFail bool +} + +func (s *testSCRAMSMTPServer) handleConnection(conn net.Conn) { + defer func() { + _ = conn.Close() + }() + + reader := bufio.NewReader(conn) + writer := bufio.NewWriter(conn) + writeLine := func(data string) error { + _, err := writer.WriteString(data + "\r\n") + if err != nil { + return fmt.Errorf("unable to write line: %w", err) + } + return writer.Flush() + } + writeOK := func() { + _ = writeLine("250 2.0.0 OK") + } + + if err := writeLine("220 go-mail test server ready ESMTP"); err != nil { + return + } + + data, err := reader.ReadString('\n') + if err != nil { + return + } + data = strings.TrimSpace(data) + if strings.HasPrefix(data, "EHLO") { + _ = writeLine(fmt.Sprintf("250-%s", s.hostname)) + _ = writeLine("250-AUTH SCRAM-SHA-1 SCRAM-SHA-256") + writeOK() + } else { + _ = writeLine("500 Invalid command") + return + } + + for { + data, err = reader.ReadString('\n') + if err != nil { + fmt.Printf("failed to read data: %v", err) + } + data = strings.TrimSpace(data) + if strings.HasPrefix(data, "AUTH") { + parts := strings.Split(data, " ") + if len(parts) < 2 { + _ = writeLine("500 Syntax error") + return + } + + authMechanism := parts[1] + if authMechanism != "SCRAM-SHA-1" && authMechanism != "SCRAM-SHA-256" { + _ = writeLine("504 Unrecognized authentication mechanism") + return + } + s.authMechanism = authMechanism + _ = writeLine("334 ") + s.handleSCRAMAuth(conn) + return + } else { + _ = writeLine("500 Invalid command") + } + } +} + +func (s *testSCRAMSMTPServer) handleSCRAMAuth(conn net.Conn) { + reader := bufio.NewReader(conn) + writer := bufio.NewWriter(conn) + writeLine := func(data string) error { + _, err := writer.WriteString(data + "\r\n") + if err != nil { + return fmt.Errorf("unable to write line: %w", err) + } + return writer.Flush() + } + + data, err := reader.ReadString('\n') + if err != nil { + _ = writeLine("535 Authentication failed") + return + } + data = strings.TrimSpace(data) + decodedMessage, err := base64.StdEncoding.DecodeString(data) + if err != nil { + _ = writeLine("535 Authentication failed") + return + } + splits := strings.Split(string(decodedMessage), ",") + if len(splits) != 4 { + _ = writeLine("535 Authentication failed - expected 4 parts") + return + } + if splits[0] != "n" { + _ = writeLine("535 Authentication failed - expected n to be in the first part") + return + } + if splits[2] != "n=username" { + _ = writeLine("535 Authentication failed - expected n=username to be in the third part") + return + } + if !strings.HasPrefix(splits[3], "r=") { + _ = writeLine("535 Authentication failed - expected r= to be in the fourth part") + return + } + clientNonce := s.extractNonce(string(decodedMessage)) + if clientNonce == "" { + _ = writeLine("535 Authentication failed") + return + } + + s.nonce = clientNonce + "server_nonce" + serverFirstMessage := fmt.Sprintf("r=%s,s=%s,i=0", s.nonce, "salt") + _ = writeLine(fmt.Sprintf("334 %s", base64.StdEncoding.EncodeToString([]byte(serverFirstMessage)))) + + data, err = reader.ReadString('\n') + if err != nil { + _ = writeLine("535 Authentication failed") + return + } + data = strings.TrimSpace(data) + decodedFinalMessage, err := base64.StdEncoding.DecodeString(data) + if err != nil { + _ = writeLine("535 Authentication failed") + return + } + splits = strings.Split(string(decodedFinalMessage), ",") + if splits[0] != "c=biws" { + _ = writeLine("535 Authentication failed - expected c=biws to be in the first part") + return + } + if !strings.HasPrefix(splits[1], "r=") { + _ = writeLine("535 Authentication failed - expected r to be in the second part") + return + } + if !strings.Contains(splits[1], "server_nonce") { + _ = writeLine("535 Authentication failed - expected server_nonce to be in the second part") + return + } + if !strings.HasPrefix(splits[2], "p=") { + _ = writeLine("535 Authentication failed - expected p to be in the third part") + return + } + + if s.shouldFail { + _ = writeLine("535 Authentication failed") + return + } + _ = writeLine("235 Authentication successful") +} + +func (s *testSCRAMSMTPServer) extractNonce(message string) string { + parts := strings.Split(message, ",") + for _, part := range parts { + if strings.HasPrefix(part, "r=") { + return part[2:] + } + } + return "" +} + +func startSMTPServer(shouldFail bool, hostname, port string) { + server := &testSCRAMSMTPServer{ + hostname: hostname, + port: port, + shouldFail: shouldFail, + } + listener, err := net.Listen("tcp", fmt.Sprintf("%s:%s", hostname, port)) + if err != nil { + fmt.Printf("Failed to start SMTP server: %v", err) + } + defer func() { + _ = listener.Close() + }() + for { + conn, err := listener.Accept() + if err != nil { + fmt.Printf("Failed to accept connection: %v", err) + continue + } + go server.handleConnection(conn) + } +}