mirror of
https://github.com/wneessen/go-mail.git
synced 2024-12-22 18:50:37 +01:00
Add SCRAM authentication tests to smtp package
Added comprehensive unit tests for SCRAM-SHA-1, SCRAM-SHA-256, and their PLUS variants. Implemented a test server to simulate various SCRAM authentication scenarios and validate both success and failure cases.
This commit is contained in:
parent
c656226fd3
commit
d4c6cb506c
1 changed files with 352 additions and 0 deletions
|
@ -17,15 +17,22 @@ import (
|
|||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/pbkdf2"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -1038,6 +1045,178 @@ func TestXOAuth2Auth(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestScramAuth(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tls bool
|
||||
authString string
|
||||
hash func() hash.Hash
|
||||
isPlus bool
|
||||
}{
|
||||
{"SCRAM-SHA-1 (no TLS)", false, "SCRAM-SHA-1", sha1.New, false},
|
||||
{"SCRAM-SHA-256 (no TLS)", false, "SCRAM-SHA-256", sha256.New, false},
|
||||
{"SCRAM-SHA-1 (with TLS)", true, "SCRAM-SHA-1", sha1.New, false},
|
||||
{"SCRAM-SHA-256 (with TLS)", true, "SCRAM-SHA-256", sha256.New, false},
|
||||
{"SCRAM-SHA-1-PLUS", true, "SCRAM-SHA-1-PLUS", sha1.New, true},
|
||||
{"SCRAM-SHA-256-PLUS", true, "SCRAM-SHA-256-PLUS", sha256.New, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name+" succeeds on test server", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
PortAdder.Add(1)
|
||||
serverPort := int(TestServerPortBase + PortAdder.Load())
|
||||
featureSet := fmt.Sprintf("250-AUTH %s\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8", tt.authString)
|
||||
go func() {
|
||||
if err := simpleSMTPServer(ctx, t, &serverProps{
|
||||
TestSCRAM: true,
|
||||
HashFunc: tt.hash,
|
||||
FeatureSet: featureSet,
|
||||
ListenPort: serverPort,
|
||||
SSLListener: tt.tls,
|
||||
IsSCRAMPlus: tt.isPlus,
|
||||
},
|
||||
); err != nil {
|
||||
t.Errorf("failed to start test server: %s", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
time.Sleep(time.Millisecond * 30)
|
||||
|
||||
var client *Client
|
||||
switch tt.tls {
|
||||
case true:
|
||||
cert, err := tls.X509KeyPair(localhostCert, localhostKey)
|
||||
if err != nil {
|
||||
fmt.Printf("error creating TLS cert: %s", err)
|
||||
return
|
||||
}
|
||||
tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true}
|
||||
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", TestServerAddr, serverPort), &tlsConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to dial TLS server: %v", err)
|
||||
}
|
||||
client, err = NewClient(conn, TestServerAddr)
|
||||
case false:
|
||||
var err error
|
||||
client, err = Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to test server: %s", err)
|
||||
}
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if err := client.Close(); err != nil {
|
||||
t.Errorf("failed to close client connection: %s", err)
|
||||
}
|
||||
})
|
||||
|
||||
var auth Auth
|
||||
switch tt.authString {
|
||||
case "SCRAM-SHA-1":
|
||||
auth = ScramSHA1Auth("username", "password")
|
||||
case "SCRAM-SHA-256":
|
||||
auth = ScramSHA256Auth("username", "password")
|
||||
case "SCRAM-SHA-1-PLUS":
|
||||
tlsConnState, err := client.GetTLSConnectionState()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get TLS connection state: %s", err)
|
||||
}
|
||||
auth = ScramSHA1PlusAuth("username", "password", tlsConnState)
|
||||
case "SCRAM-SHA-256-PLUS":
|
||||
tlsConnState, err := client.GetTLSConnectionState()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get TLS connection state: %s", err)
|
||||
}
|
||||
auth = ScramSHA256PlusAuth("username", "password", tlsConnState)
|
||||
default:
|
||||
t.Fatalf("unexpected auth string: %s", tt.authString)
|
||||
}
|
||||
if err := client.Auth(auth); err != nil {
|
||||
t.Errorf("failed to authenticate to test server: %s", err)
|
||||
}
|
||||
})
|
||||
t.Run(tt.name+" fails on test server", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
PortAdder.Add(1)
|
||||
serverPort := int(TestServerPortBase + PortAdder.Load())
|
||||
featureSet := fmt.Sprintf("250-AUTH %s\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8", tt.authString)
|
||||
go func() {
|
||||
if err := simpleSMTPServer(ctx, t, &serverProps{
|
||||
TestSCRAM: true,
|
||||
HashFunc: tt.hash,
|
||||
FeatureSet: featureSet,
|
||||
ListenPort: serverPort,
|
||||
SSLListener: tt.tls,
|
||||
IsSCRAMPlus: tt.isPlus,
|
||||
},
|
||||
); err != nil {
|
||||
t.Errorf("failed to start test server: %s", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
time.Sleep(time.Millisecond * 30)
|
||||
|
||||
var client *Client
|
||||
switch tt.tls {
|
||||
case true:
|
||||
cert, err := tls.X509KeyPair(localhostCert, localhostKey)
|
||||
if err != nil {
|
||||
fmt.Printf("error creating TLS cert: %s", err)
|
||||
return
|
||||
}
|
||||
tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true}
|
||||
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", TestServerAddr, serverPort), &tlsConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to dial TLS server: %v", err)
|
||||
}
|
||||
client, err = NewClient(conn, TestServerAddr)
|
||||
case false:
|
||||
var err error
|
||||
client, err = Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to test server: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
var auth Auth
|
||||
switch tt.authString {
|
||||
case "SCRAM-SHA-1":
|
||||
auth = ScramSHA1Auth("invalid", "password")
|
||||
case "SCRAM-SHA-256":
|
||||
auth = ScramSHA256Auth("invalid", "password")
|
||||
case "SCRAM-SHA-1-PLUS":
|
||||
tlsConnState, err := client.GetTLSConnectionState()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get TLS connection state: %s", err)
|
||||
}
|
||||
auth = ScramSHA1PlusAuth("invalid", "password", tlsConnState)
|
||||
case "SCRAM-SHA-256-PLUS":
|
||||
tlsConnState, err := client.GetTLSConnectionState()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get TLS connection state: %s", err)
|
||||
}
|
||||
auth = ScramSHA256PlusAuth("invalid", "password", tlsConnState)
|
||||
default:
|
||||
t.Fatalf("unexpected auth string: %s", tt.authString)
|
||||
}
|
||||
if err := client.Auth(auth); err == nil {
|
||||
t.Error("expected authentication to fail")
|
||||
}
|
||||
})
|
||||
}
|
||||
t.Run("ScramAuth_Next with nonsense parameter", func(t *testing.T) {
|
||||
auth := ScramSHA1Auth("username", "password")
|
||||
_, err := auth.Next([]byte("x=nonsense"), true)
|
||||
if err == nil {
|
||||
t.Fatal("expected authentication to fail")
|
||||
}
|
||||
if !errors.Is(err, ErrUnexpectedServerResponse) {
|
||||
t.Errorf("expected ErrUnexpectedServerResponse, got %s", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/*
|
||||
|
||||
|
||||
|
@ -3140,8 +3319,11 @@ type serverProps struct {
|
|||
FeatureSet string
|
||||
ListenPort int
|
||||
SSLListener bool
|
||||
IsSCRAMPlus bool
|
||||
IsTLS bool
|
||||
SupportDSN bool
|
||||
TestSCRAM bool
|
||||
HashFunc func() hash.Hash
|
||||
}
|
||||
|
||||
// simpleSMTPServer starts a simple TCP server that resonds to SMTP commands.
|
||||
|
@ -3279,6 +3461,22 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server
|
|||
writeLine("535 5.7.8 Error: authentication failed")
|
||||
break
|
||||
}
|
||||
if props.TestSCRAM {
|
||||
parts := strings.Split(data, " ")
|
||||
authMechanism := parts[1]
|
||||
if authMechanism != "SCRAM-SHA-1" && authMechanism != "SCRAM-SHA-256" &&
|
||||
authMechanism != "SCRAM-SHA-1-PLUS" && authMechanism != "SCRAM-SHA-256-PLUS" {
|
||||
writeLine("504 Unrecognized authentication mechanism")
|
||||
break
|
||||
}
|
||||
scram := &testSCRAMSMTP{
|
||||
tlsServer: props.IsSCRAMPlus,
|
||||
h: props.HashFunc,
|
||||
}
|
||||
writeLine("334 ")
|
||||
scram.handleSCRAMAuth(connection)
|
||||
break
|
||||
}
|
||||
writeLine("235 2.7.0 Authentication successful")
|
||||
case strings.EqualFold(data, "DATA"):
|
||||
if props.FailOnDataInit {
|
||||
|
@ -3348,3 +3546,157 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// testSCRAMSMTP represents a part of the test server for SCRAM-based SMTP authentication.
|
||||
// It does not do any acutal computation of the challenges but verifies that the expected
|
||||
// fields are present. We have actual real authentication tests for all SCRAM modes in the
|
||||
// go-mail client_test.go
|
||||
type testSCRAMSMTP struct {
|
||||
authMechanism string
|
||||
nonce string
|
||||
h func() hash.Hash
|
||||
tlsServer bool
|
||||
}
|
||||
|
||||
func (s *testSCRAMSMTP) handleSCRAMAuth(conn net.Conn) {
|
||||
reader := bufio.NewReader(conn)
|
||||
writer := bufio.NewWriter(conn)
|
||||
writeLine := func(data string) error {
|
||||
_, err := writer.WriteString(data + "\r\n")
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to write line: %w", err)
|
||||
}
|
||||
return writer.Flush()
|
||||
}
|
||||
var authMsg string
|
||||
|
||||
data, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
_ = writeLine("535 Authentication failed")
|
||||
return
|
||||
}
|
||||
data = strings.TrimSpace(data)
|
||||
decodedMessage, err := base64.StdEncoding.DecodeString(data)
|
||||
if err != nil {
|
||||
_ = writeLine("535 Authentication failed")
|
||||
return
|
||||
}
|
||||
splits := strings.Split(string(decodedMessage), ",")
|
||||
if len(splits) != 4 {
|
||||
_ = writeLine("535 Authentication failed - expected 4 parts")
|
||||
return
|
||||
}
|
||||
if !s.tlsServer && splits[0] != "n" {
|
||||
_ = writeLine("535 Authentication failed - expected n to be in the first part")
|
||||
return
|
||||
}
|
||||
if s.tlsServer && !strings.HasPrefix(splits[0], "p=") {
|
||||
_ = writeLine("535 Authentication failed - expected p= to be in the first part")
|
||||
return
|
||||
}
|
||||
if splits[2] != "n=username" {
|
||||
_ = writeLine("535 Authentication failed - expected n=username to be in the third part")
|
||||
return
|
||||
}
|
||||
if !strings.HasPrefix(splits[3], "r=") {
|
||||
_ = writeLine("535 Authentication failed - expected r= to be in the fourth part")
|
||||
return
|
||||
}
|
||||
authMsg = splits[2] + "," + splits[3]
|
||||
|
||||
clientNonce := s.extractNonce(string(decodedMessage))
|
||||
if clientNonce == "" {
|
||||
_ = writeLine("535 Authentication failed")
|
||||
return
|
||||
}
|
||||
|
||||
s.nonce = clientNonce + "server_nonce"
|
||||
serverFirstMessage := fmt.Sprintf("r=%s,s=%s,i=4096", s.nonce,
|
||||
base64.StdEncoding.EncodeToString([]byte("salt")))
|
||||
_ = writeLine(fmt.Sprintf("334 %s", base64.StdEncoding.EncodeToString([]byte(serverFirstMessage))))
|
||||
authMsg = authMsg + "," + serverFirstMessage
|
||||
|
||||
data, err = reader.ReadString('\n')
|
||||
if err != nil {
|
||||
_ = writeLine("535 Authentication failed")
|
||||
return
|
||||
}
|
||||
data = strings.TrimSpace(data)
|
||||
decodedFinalMessage, err := base64.StdEncoding.DecodeString(data)
|
||||
if err != nil {
|
||||
_ = writeLine("535 Authentication failed")
|
||||
return
|
||||
}
|
||||
splits = strings.Split(string(decodedFinalMessage), ",")
|
||||
|
||||
if !s.tlsServer && splits[0] != "c=biws" {
|
||||
_ = writeLine("535 Authentication failed - expected c=biws to be in the first part")
|
||||
return
|
||||
}
|
||||
if s.tlsServer {
|
||||
if !strings.HasPrefix(splits[0], "c=") {
|
||||
_ = writeLine("535 Authentication failed - expected c= to be in the first part")
|
||||
return
|
||||
}
|
||||
channelBind, err := base64.StdEncoding.DecodeString(splits[0][2:])
|
||||
if err != nil {
|
||||
_ = writeLine("535 Authentication failed - base64 channel bind is not valid - " + err.Error())
|
||||
return
|
||||
}
|
||||
if !strings.HasPrefix(string(channelBind), "p=") {
|
||||
_ = writeLine("535 Authentication failed - expected channel binding to start with p=-")
|
||||
return
|
||||
}
|
||||
cbType := string(channelBind[2:])
|
||||
if !strings.HasPrefix(cbType, "tls-unique") && !strings.HasPrefix(cbType, "tls-exporter") {
|
||||
_ = writeLine("535 Authentication failed - expected channel binding type tls-unique or tls-exporter")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(splits[1], "r=") {
|
||||
_ = writeLine("535 Authentication failed - expected r to be in the second part")
|
||||
return
|
||||
}
|
||||
if !strings.Contains(splits[1], "server_nonce") {
|
||||
_ = writeLine("535 Authentication failed - expected server_nonce to be in the second part")
|
||||
return
|
||||
}
|
||||
if !strings.HasPrefix(splits[2], "p=") {
|
||||
_ = writeLine("535 Authentication failed - expected p to be in the third part")
|
||||
return
|
||||
}
|
||||
|
||||
authMsg = authMsg + "," + splits[0] + "," + splits[1]
|
||||
saltedPwd := pbkdf2.Key([]byte("password"), []byte("salt"), 4096, s.h().Size(), s.h)
|
||||
mac := hmac.New(s.h, saltedPwd)
|
||||
mac.Write([]byte("Server Key"))
|
||||
skey := mac.Sum(nil)
|
||||
mac.Reset()
|
||||
|
||||
mac = hmac.New(s.h, skey)
|
||||
mac.Write([]byte(authMsg))
|
||||
ssig := mac.Sum(nil)
|
||||
mac.Reset()
|
||||
|
||||
serverFinalMessage := fmt.Sprintf("v=%s", base64.StdEncoding.EncodeToString(ssig))
|
||||
_ = writeLine(fmt.Sprintf("334 %s", base64.StdEncoding.EncodeToString([]byte(serverFinalMessage))))
|
||||
|
||||
_, err = reader.ReadString('\n')
|
||||
if err != nil {
|
||||
_ = writeLine("535 Authentication failed")
|
||||
return
|
||||
}
|
||||
|
||||
_ = writeLine("235 Authentication successful")
|
||||
}
|
||||
|
||||
func (s *testSCRAMSMTP) extractNonce(message string) string {
|
||||
parts := strings.Split(message, ",")
|
||||
for _, part := range parts {
|
||||
if strings.HasPrefix(part, "r=") {
|
||||
return part[2:]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue