From d4c6cb506c8ef8fc3828b16e107e5535a3fe860b Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Fri, 8 Nov 2024 16:53:09 +0100 Subject: [PATCH] Add SCRAM authentication tests to smtp package Added comprehensive unit tests for SCRAM-SHA-1, SCRAM-SHA-256, and their PLUS variants. Implemented a test server to simulate various SCRAM authentication scenarios and validate both success and failure cases. --- smtp/smtp_test.go | 352 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 352 insertions(+) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index 5846e7f..70f8d04 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -17,15 +17,22 @@ import ( "bufio" "bytes" "context" + "crypto/hmac" + "crypto/sha1" + "crypto/sha256" "crypto/tls" + "encoding/base64" "errors" "fmt" + "hash" "io" "net" "strings" "sync/atomic" "testing" "time" + + "golang.org/x/crypto/pbkdf2" ) const ( @@ -1038,6 +1045,178 @@ func TestXOAuth2Auth(t *testing.T) { }) } +func TestScramAuth(t *testing.T) { + tests := []struct { + name string + tls bool + authString string + hash func() hash.Hash + isPlus bool + }{ + {"SCRAM-SHA-1 (no TLS)", false, "SCRAM-SHA-1", sha1.New, false}, + {"SCRAM-SHA-256 (no TLS)", false, "SCRAM-SHA-256", sha256.New, false}, + {"SCRAM-SHA-1 (with TLS)", true, "SCRAM-SHA-1", sha1.New, false}, + {"SCRAM-SHA-256 (with TLS)", true, "SCRAM-SHA-256", sha256.New, false}, + {"SCRAM-SHA-1-PLUS", true, "SCRAM-SHA-1-PLUS", sha1.New, true}, + {"SCRAM-SHA-256-PLUS", true, "SCRAM-SHA-256-PLUS", sha256.New, true}, + } + for _, tt := range tests { + t.Run(tt.name+" succeeds on test server", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := fmt.Sprintf("250-AUTH %s\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8", tt.authString) + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + TestSCRAM: true, + HashFunc: tt.hash, + FeatureSet: featureSet, + ListenPort: serverPort, + SSLListener: tt.tls, + IsSCRAMPlus: tt.isPlus, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + var client *Client + switch tt.tls { + case true: + cert, err := tls.X509KeyPair(localhostCert, localhostKey) + if err != nil { + fmt.Printf("error creating TLS cert: %s", err) + return + } + tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true} + conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", TestServerAddr, serverPort), &tlsConfig) + if err != nil { + t.Fatalf("failed to dial TLS server: %v", err) + } + client, err = NewClient(conn, TestServerAddr) + case false: + var err error + client, err = Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to connect to test server: %s", err) + } + } + t.Cleanup(func() { + if err := client.Close(); err != nil { + t.Errorf("failed to close client connection: %s", err) + } + }) + + var auth Auth + switch tt.authString { + case "SCRAM-SHA-1": + auth = ScramSHA1Auth("username", "password") + case "SCRAM-SHA-256": + auth = ScramSHA256Auth("username", "password") + case "SCRAM-SHA-1-PLUS": + tlsConnState, err := client.GetTLSConnectionState() + if err != nil { + t.Fatalf("failed to get TLS connection state: %s", err) + } + auth = ScramSHA1PlusAuth("username", "password", tlsConnState) + case "SCRAM-SHA-256-PLUS": + tlsConnState, err := client.GetTLSConnectionState() + if err != nil { + t.Fatalf("failed to get TLS connection state: %s", err) + } + auth = ScramSHA256PlusAuth("username", "password", tlsConnState) + default: + t.Fatalf("unexpected auth string: %s", tt.authString) + } + if err := client.Auth(auth); err != nil { + t.Errorf("failed to authenticate to test server: %s", err) + } + }) + t.Run(tt.name+" fails on test server", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := fmt.Sprintf("250-AUTH %s\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8", tt.authString) + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + TestSCRAM: true, + HashFunc: tt.hash, + FeatureSet: featureSet, + ListenPort: serverPort, + SSLListener: tt.tls, + IsSCRAMPlus: tt.isPlus, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + var client *Client + switch tt.tls { + case true: + cert, err := tls.X509KeyPair(localhostCert, localhostKey) + if err != nil { + fmt.Printf("error creating TLS cert: %s", err) + return + } + tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true} + conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", TestServerAddr, serverPort), &tlsConfig) + if err != nil { + t.Fatalf("failed to dial TLS server: %v", err) + } + client, err = NewClient(conn, TestServerAddr) + case false: + var err error + client, err = Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to connect to test server: %s", err) + } + } + + var auth Auth + switch tt.authString { + case "SCRAM-SHA-1": + auth = ScramSHA1Auth("invalid", "password") + case "SCRAM-SHA-256": + auth = ScramSHA256Auth("invalid", "password") + case "SCRAM-SHA-1-PLUS": + tlsConnState, err := client.GetTLSConnectionState() + if err != nil { + t.Fatalf("failed to get TLS connection state: %s", err) + } + auth = ScramSHA1PlusAuth("invalid", "password", tlsConnState) + case "SCRAM-SHA-256-PLUS": + tlsConnState, err := client.GetTLSConnectionState() + if err != nil { + t.Fatalf("failed to get TLS connection state: %s", err) + } + auth = ScramSHA256PlusAuth("invalid", "password", tlsConnState) + default: + t.Fatalf("unexpected auth string: %s", tt.authString) + } + if err := client.Auth(auth); err == nil { + t.Error("expected authentication to fail") + } + }) + } + t.Run("ScramAuth_Next with nonsense parameter", func(t *testing.T) { + auth := ScramSHA1Auth("username", "password") + _, err := auth.Next([]byte("x=nonsense"), true) + if err == nil { + t.Fatal("expected authentication to fail") + } + if !errors.Is(err, ErrUnexpectedServerResponse) { + t.Errorf("expected ErrUnexpectedServerResponse, got %s", err) + } + }) +} + /* @@ -3140,8 +3319,11 @@ type serverProps struct { FeatureSet string ListenPort int SSLListener bool + IsSCRAMPlus bool IsTLS bool SupportDSN bool + TestSCRAM bool + HashFunc func() hash.Hash } // simpleSMTPServer starts a simple TCP server that resonds to SMTP commands. @@ -3279,6 +3461,22 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server writeLine("535 5.7.8 Error: authentication failed") break } + if props.TestSCRAM { + parts := strings.Split(data, " ") + authMechanism := parts[1] + if authMechanism != "SCRAM-SHA-1" && authMechanism != "SCRAM-SHA-256" && + authMechanism != "SCRAM-SHA-1-PLUS" && authMechanism != "SCRAM-SHA-256-PLUS" { + writeLine("504 Unrecognized authentication mechanism") + break + } + scram := &testSCRAMSMTP{ + tlsServer: props.IsSCRAMPlus, + h: props.HashFunc, + } + writeLine("334 ") + scram.handleSCRAMAuth(connection) + break + } writeLine("235 2.7.0 Authentication successful") case strings.EqualFold(data, "DATA"): if props.FailOnDataInit { @@ -3348,3 +3546,157 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server } } } + +// testSCRAMSMTP represents a part of the test server for SCRAM-based SMTP authentication. +// It does not do any acutal computation of the challenges but verifies that the expected +// fields are present. We have actual real authentication tests for all SCRAM modes in the +// go-mail client_test.go +type testSCRAMSMTP struct { + authMechanism string + nonce string + h func() hash.Hash + tlsServer bool +} + +func (s *testSCRAMSMTP) handleSCRAMAuth(conn net.Conn) { + reader := bufio.NewReader(conn) + writer := bufio.NewWriter(conn) + writeLine := func(data string) error { + _, err := writer.WriteString(data + "\r\n") + if err != nil { + return fmt.Errorf("unable to write line: %w", err) + } + return writer.Flush() + } + var authMsg string + + data, err := reader.ReadString('\n') + if err != nil { + _ = writeLine("535 Authentication failed") + return + } + data = strings.TrimSpace(data) + decodedMessage, err := base64.StdEncoding.DecodeString(data) + if err != nil { + _ = writeLine("535 Authentication failed") + return + } + splits := strings.Split(string(decodedMessage), ",") + if len(splits) != 4 { + _ = writeLine("535 Authentication failed - expected 4 parts") + return + } + if !s.tlsServer && splits[0] != "n" { + _ = writeLine("535 Authentication failed - expected n to be in the first part") + return + } + if s.tlsServer && !strings.HasPrefix(splits[0], "p=") { + _ = writeLine("535 Authentication failed - expected p= to be in the first part") + return + } + if splits[2] != "n=username" { + _ = writeLine("535 Authentication failed - expected n=username to be in the third part") + return + } + if !strings.HasPrefix(splits[3], "r=") { + _ = writeLine("535 Authentication failed - expected r= to be in the fourth part") + return + } + authMsg = splits[2] + "," + splits[3] + + clientNonce := s.extractNonce(string(decodedMessage)) + if clientNonce == "" { + _ = writeLine("535 Authentication failed") + return + } + + s.nonce = clientNonce + "server_nonce" + serverFirstMessage := fmt.Sprintf("r=%s,s=%s,i=4096", s.nonce, + base64.StdEncoding.EncodeToString([]byte("salt"))) + _ = writeLine(fmt.Sprintf("334 %s", base64.StdEncoding.EncodeToString([]byte(serverFirstMessage)))) + authMsg = authMsg + "," + serverFirstMessage + + data, err = reader.ReadString('\n') + if err != nil { + _ = writeLine("535 Authentication failed") + return + } + data = strings.TrimSpace(data) + decodedFinalMessage, err := base64.StdEncoding.DecodeString(data) + if err != nil { + _ = writeLine("535 Authentication failed") + return + } + splits = strings.Split(string(decodedFinalMessage), ",") + + if !s.tlsServer && splits[0] != "c=biws" { + _ = writeLine("535 Authentication failed - expected c=biws to be in the first part") + return + } + if s.tlsServer { + if !strings.HasPrefix(splits[0], "c=") { + _ = writeLine("535 Authentication failed - expected c= to be in the first part") + return + } + channelBind, err := base64.StdEncoding.DecodeString(splits[0][2:]) + if err != nil { + _ = writeLine("535 Authentication failed - base64 channel bind is not valid - " + err.Error()) + return + } + if !strings.HasPrefix(string(channelBind), "p=") { + _ = writeLine("535 Authentication failed - expected channel binding to start with p=-") + return + } + cbType := string(channelBind[2:]) + if !strings.HasPrefix(cbType, "tls-unique") && !strings.HasPrefix(cbType, "tls-exporter") { + _ = writeLine("535 Authentication failed - expected channel binding type tls-unique or tls-exporter") + return + } + } + + if !strings.HasPrefix(splits[1], "r=") { + _ = writeLine("535 Authentication failed - expected r to be in the second part") + return + } + if !strings.Contains(splits[1], "server_nonce") { + _ = writeLine("535 Authentication failed - expected server_nonce to be in the second part") + return + } + if !strings.HasPrefix(splits[2], "p=") { + _ = writeLine("535 Authentication failed - expected p to be in the third part") + return + } + + authMsg = authMsg + "," + splits[0] + "," + splits[1] + saltedPwd := pbkdf2.Key([]byte("password"), []byte("salt"), 4096, s.h().Size(), s.h) + mac := hmac.New(s.h, saltedPwd) + mac.Write([]byte("Server Key")) + skey := mac.Sum(nil) + mac.Reset() + + mac = hmac.New(s.h, skey) + mac.Write([]byte(authMsg)) + ssig := mac.Sum(nil) + mac.Reset() + + serverFinalMessage := fmt.Sprintf("v=%s", base64.StdEncoding.EncodeToString(ssig)) + _ = writeLine(fmt.Sprintf("334 %s", base64.StdEncoding.EncodeToString([]byte(serverFinalMessage)))) + + _, err = reader.ReadString('\n') + if err != nil { + _ = writeLine("535 Authentication failed") + return + } + + _ = writeLine("235 Authentication successful") +} + +func (s *testSCRAMSMTP) extractNonce(message string) string { + parts := strings.Split(message, ",") + for _, part := range parts { + if strings.HasPrefix(part, "r=") { + return part[2:] + } + } + return "" +}