Add SCRAM authentication tests to smtp package

Added comprehensive unit tests for SCRAM-SHA-1, SCRAM-SHA-256, and their PLUS variants. Implemented a test server to simulate various SCRAM authentication scenarios and validate both success and failure cases.
This commit is contained in:
Winni Neessen 2024-11-08 16:53:09 +01:00
parent c656226fd3
commit d4c6cb506c
Signed by: wneessen
GPG key ID: 385AC9889632126E

View file

@ -17,15 +17,22 @@ import (
"bufio"
"bytes"
"context"
"crypto/hmac"
"crypto/sha1"
"crypto/sha256"
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
"hash"
"io"
"net"
"strings"
"sync/atomic"
"testing"
"time"
"golang.org/x/crypto/pbkdf2"
)
const (
@ -1038,6 +1045,178 @@ func TestXOAuth2Auth(t *testing.T) {
})
}
func TestScramAuth(t *testing.T) {
tests := []struct {
name string
tls bool
authString string
hash func() hash.Hash
isPlus bool
}{
{"SCRAM-SHA-1 (no TLS)", false, "SCRAM-SHA-1", sha1.New, false},
{"SCRAM-SHA-256 (no TLS)", false, "SCRAM-SHA-256", sha256.New, false},
{"SCRAM-SHA-1 (with TLS)", true, "SCRAM-SHA-1", sha1.New, false},
{"SCRAM-SHA-256 (with TLS)", true, "SCRAM-SHA-256", sha256.New, false},
{"SCRAM-SHA-1-PLUS", true, "SCRAM-SHA-1-PLUS", sha1.New, true},
{"SCRAM-SHA-256-PLUS", true, "SCRAM-SHA-256-PLUS", sha256.New, true},
}
for _, tt := range tests {
t.Run(tt.name+" succeeds on test server", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
PortAdder.Add(1)
serverPort := int(TestServerPortBase + PortAdder.Load())
featureSet := fmt.Sprintf("250-AUTH %s\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8", tt.authString)
go func() {
if err := simpleSMTPServer(ctx, t, &serverProps{
TestSCRAM: true,
HashFunc: tt.hash,
FeatureSet: featureSet,
ListenPort: serverPort,
SSLListener: tt.tls,
IsSCRAMPlus: tt.isPlus,
},
); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 30)
var client *Client
switch tt.tls {
case true:
cert, err := tls.X509KeyPair(localhostCert, localhostKey)
if err != nil {
fmt.Printf("error creating TLS cert: %s", err)
return
}
tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true}
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", TestServerAddr, serverPort), &tlsConfig)
if err != nil {
t.Fatalf("failed to dial TLS server: %v", err)
}
client, err = NewClient(conn, TestServerAddr)
case false:
var err error
client, err = Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort))
if err != nil {
t.Fatalf("failed to connect to test server: %s", err)
}
}
t.Cleanup(func() {
if err := client.Close(); err != nil {
t.Errorf("failed to close client connection: %s", err)
}
})
var auth Auth
switch tt.authString {
case "SCRAM-SHA-1":
auth = ScramSHA1Auth("username", "password")
case "SCRAM-SHA-256":
auth = ScramSHA256Auth("username", "password")
case "SCRAM-SHA-1-PLUS":
tlsConnState, err := client.GetTLSConnectionState()
if err != nil {
t.Fatalf("failed to get TLS connection state: %s", err)
}
auth = ScramSHA1PlusAuth("username", "password", tlsConnState)
case "SCRAM-SHA-256-PLUS":
tlsConnState, err := client.GetTLSConnectionState()
if err != nil {
t.Fatalf("failed to get TLS connection state: %s", err)
}
auth = ScramSHA256PlusAuth("username", "password", tlsConnState)
default:
t.Fatalf("unexpected auth string: %s", tt.authString)
}
if err := client.Auth(auth); err != nil {
t.Errorf("failed to authenticate to test server: %s", err)
}
})
t.Run(tt.name+" fails on test server", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
PortAdder.Add(1)
serverPort := int(TestServerPortBase + PortAdder.Load())
featureSet := fmt.Sprintf("250-AUTH %s\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8", tt.authString)
go func() {
if err := simpleSMTPServer(ctx, t, &serverProps{
TestSCRAM: true,
HashFunc: tt.hash,
FeatureSet: featureSet,
ListenPort: serverPort,
SSLListener: tt.tls,
IsSCRAMPlus: tt.isPlus,
},
); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 30)
var client *Client
switch tt.tls {
case true:
cert, err := tls.X509KeyPair(localhostCert, localhostKey)
if err != nil {
fmt.Printf("error creating TLS cert: %s", err)
return
}
tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true}
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", TestServerAddr, serverPort), &tlsConfig)
if err != nil {
t.Fatalf("failed to dial TLS server: %v", err)
}
client, err = NewClient(conn, TestServerAddr)
case false:
var err error
client, err = Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort))
if err != nil {
t.Fatalf("failed to connect to test server: %s", err)
}
}
var auth Auth
switch tt.authString {
case "SCRAM-SHA-1":
auth = ScramSHA1Auth("invalid", "password")
case "SCRAM-SHA-256":
auth = ScramSHA256Auth("invalid", "password")
case "SCRAM-SHA-1-PLUS":
tlsConnState, err := client.GetTLSConnectionState()
if err != nil {
t.Fatalf("failed to get TLS connection state: %s", err)
}
auth = ScramSHA1PlusAuth("invalid", "password", tlsConnState)
case "SCRAM-SHA-256-PLUS":
tlsConnState, err := client.GetTLSConnectionState()
if err != nil {
t.Fatalf("failed to get TLS connection state: %s", err)
}
auth = ScramSHA256PlusAuth("invalid", "password", tlsConnState)
default:
t.Fatalf("unexpected auth string: %s", tt.authString)
}
if err := client.Auth(auth); err == nil {
t.Error("expected authentication to fail")
}
})
}
t.Run("ScramAuth_Next with nonsense parameter", func(t *testing.T) {
auth := ScramSHA1Auth("username", "password")
_, err := auth.Next([]byte("x=nonsense"), true)
if err == nil {
t.Fatal("expected authentication to fail")
}
if !errors.Is(err, ErrUnexpectedServerResponse) {
t.Errorf("expected ErrUnexpectedServerResponse, got %s", err)
}
})
}
/*
@ -3140,8 +3319,11 @@ type serverProps struct {
FeatureSet string
ListenPort int
SSLListener bool
IsSCRAMPlus bool
IsTLS bool
SupportDSN bool
TestSCRAM bool
HashFunc func() hash.Hash
}
// simpleSMTPServer starts a simple TCP server that resonds to SMTP commands.
@ -3279,6 +3461,22 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server
writeLine("535 5.7.8 Error: authentication failed")
break
}
if props.TestSCRAM {
parts := strings.Split(data, " ")
authMechanism := parts[1]
if authMechanism != "SCRAM-SHA-1" && authMechanism != "SCRAM-SHA-256" &&
authMechanism != "SCRAM-SHA-1-PLUS" && authMechanism != "SCRAM-SHA-256-PLUS" {
writeLine("504 Unrecognized authentication mechanism")
break
}
scram := &testSCRAMSMTP{
tlsServer: props.IsSCRAMPlus,
h: props.HashFunc,
}
writeLine("334 ")
scram.handleSCRAMAuth(connection)
break
}
writeLine("235 2.7.0 Authentication successful")
case strings.EqualFold(data, "DATA"):
if props.FailOnDataInit {
@ -3348,3 +3546,157 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server
}
}
}
// testSCRAMSMTP represents a part of the test server for SCRAM-based SMTP authentication.
// It does not do any acutal computation of the challenges but verifies that the expected
// fields are present. We have actual real authentication tests for all SCRAM modes in the
// go-mail client_test.go
type testSCRAMSMTP struct {
authMechanism string
nonce string
h func() hash.Hash
tlsServer bool
}
func (s *testSCRAMSMTP) handleSCRAMAuth(conn net.Conn) {
reader := bufio.NewReader(conn)
writer := bufio.NewWriter(conn)
writeLine := func(data string) error {
_, err := writer.WriteString(data + "\r\n")
if err != nil {
return fmt.Errorf("unable to write line: %w", err)
}
return writer.Flush()
}
var authMsg string
data, err := reader.ReadString('\n')
if err != nil {
_ = writeLine("535 Authentication failed")
return
}
data = strings.TrimSpace(data)
decodedMessage, err := base64.StdEncoding.DecodeString(data)
if err != nil {
_ = writeLine("535 Authentication failed")
return
}
splits := strings.Split(string(decodedMessage), ",")
if len(splits) != 4 {
_ = writeLine("535 Authentication failed - expected 4 parts")
return
}
if !s.tlsServer && splits[0] != "n" {
_ = writeLine("535 Authentication failed - expected n to be in the first part")
return
}
if s.tlsServer && !strings.HasPrefix(splits[0], "p=") {
_ = writeLine("535 Authentication failed - expected p= to be in the first part")
return
}
if splits[2] != "n=username" {
_ = writeLine("535 Authentication failed - expected n=username to be in the third part")
return
}
if !strings.HasPrefix(splits[3], "r=") {
_ = writeLine("535 Authentication failed - expected r= to be in the fourth part")
return
}
authMsg = splits[2] + "," + splits[3]
clientNonce := s.extractNonce(string(decodedMessage))
if clientNonce == "" {
_ = writeLine("535 Authentication failed")
return
}
s.nonce = clientNonce + "server_nonce"
serverFirstMessage := fmt.Sprintf("r=%s,s=%s,i=4096", s.nonce,
base64.StdEncoding.EncodeToString([]byte("salt")))
_ = writeLine(fmt.Sprintf("334 %s", base64.StdEncoding.EncodeToString([]byte(serverFirstMessage))))
authMsg = authMsg + "," + serverFirstMessage
data, err = reader.ReadString('\n')
if err != nil {
_ = writeLine("535 Authentication failed")
return
}
data = strings.TrimSpace(data)
decodedFinalMessage, err := base64.StdEncoding.DecodeString(data)
if err != nil {
_ = writeLine("535 Authentication failed")
return
}
splits = strings.Split(string(decodedFinalMessage), ",")
if !s.tlsServer && splits[0] != "c=biws" {
_ = writeLine("535 Authentication failed - expected c=biws to be in the first part")
return
}
if s.tlsServer {
if !strings.HasPrefix(splits[0], "c=") {
_ = writeLine("535 Authentication failed - expected c= to be in the first part")
return
}
channelBind, err := base64.StdEncoding.DecodeString(splits[0][2:])
if err != nil {
_ = writeLine("535 Authentication failed - base64 channel bind is not valid - " + err.Error())
return
}
if !strings.HasPrefix(string(channelBind), "p=") {
_ = writeLine("535 Authentication failed - expected channel binding to start with p=-")
return
}
cbType := string(channelBind[2:])
if !strings.HasPrefix(cbType, "tls-unique") && !strings.HasPrefix(cbType, "tls-exporter") {
_ = writeLine("535 Authentication failed - expected channel binding type tls-unique or tls-exporter")
return
}
}
if !strings.HasPrefix(splits[1], "r=") {
_ = writeLine("535 Authentication failed - expected r to be in the second part")
return
}
if !strings.Contains(splits[1], "server_nonce") {
_ = writeLine("535 Authentication failed - expected server_nonce to be in the second part")
return
}
if !strings.HasPrefix(splits[2], "p=") {
_ = writeLine("535 Authentication failed - expected p to be in the third part")
return
}
authMsg = authMsg + "," + splits[0] + "," + splits[1]
saltedPwd := pbkdf2.Key([]byte("password"), []byte("salt"), 4096, s.h().Size(), s.h)
mac := hmac.New(s.h, saltedPwd)
mac.Write([]byte("Server Key"))
skey := mac.Sum(nil)
mac.Reset()
mac = hmac.New(s.h, skey)
mac.Write([]byte(authMsg))
ssig := mac.Sum(nil)
mac.Reset()
serverFinalMessage := fmt.Sprintf("v=%s", base64.StdEncoding.EncodeToString(ssig))
_ = writeLine(fmt.Sprintf("334 %s", base64.StdEncoding.EncodeToString([]byte(serverFinalMessage))))
_, err = reader.ReadString('\n')
if err != nil {
_ = writeLine("535 Authentication failed")
return
}
_ = writeLine("235 Authentication successful")
}
func (s *testSCRAMSMTP) extractNonce(message string) string {
parts := strings.Split(message, ",")
for _, part := range parts {
if strings.HasPrefix(part, "r=") {
return part[2:]
}
}
return ""
}