Add support for SCRAM-SHA-1-PLUS and SCRAM-SHA-256-PLUS

Extended SMTP tests to include SCRAM-SHA-1-PLUS and SCRAM-SHA-256-PLUS authentication mechanisms. Adjusted the `startSMTPServer` function to accept a hashing function and modified the server logic to handle TLS channel binding.
This commit is contained in:
Winni Neessen 2024-10-04 18:31:58 +02:00
parent 84f562554c
commit 711ce2ac65
Signed by: wneessen
GPG key ID: 385AC9889632126E

View file

@ -16,11 +16,15 @@ package smtp
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"crypto/hmac"
"crypto/sha1"
"crypto/sha256"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/base64" "encoding/base64"
"flag" "flag"
"fmt" "fmt"
"hash"
"io" "io"
"net" "net"
"net/textproto" "net/textproto"
@ -30,6 +34,8 @@ import (
"testing" "testing"
"time" "time"
"golang.org/x/crypto/pbkdf2"
"github.com/wneessen/go-mail/log" "github.com/wneessen/go-mail/log"
) )
@ -374,7 +380,7 @@ func TestAuthSCRAMSHA1_OK(t *testing.T) {
port := "2585" port := "2585"
go func() { go func() {
startSMTPServer(false, hostname, port) startSMTPServer(false, hostname, port, sha1.New)
}() }()
time.Sleep(time.Millisecond * 500) time.Sleep(time.Millisecond * 500)
@ -399,7 +405,7 @@ func TestAuthSCRAMSHA256_OK(t *testing.T) {
port := "2586" port := "2586"
go func() { go func() {
startSMTPServer(false, hostname, port) startSMTPServer(false, hostname, port, sha256.New)
}() }()
time.Sleep(time.Millisecond * 500) time.Sleep(time.Millisecond * 500)
@ -419,12 +425,80 @@ func TestAuthSCRAMSHA256_OK(t *testing.T) {
} }
} }
func TestAuthSCRAMSHA1PLUS_OK(t *testing.T) {
hostname := "127.0.0.1"
port := "2590"
go func() {
startSMTPServer(true, hostname, port, sha1.New)
}()
time.Sleep(time.Millisecond * 500)
cert, err := tls.X509KeyPair(localhostCert, localhostKey)
if err != nil {
fmt.Printf("error creating TLS cert: %s", err)
return
}
tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true}
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port), &tlsConfig)
if err != nil {
t.Errorf("failed to dial server: %v", err)
}
client, err := NewClient(conn, hostname)
if err != nil {
t.Errorf("failed to create client: %v", err)
}
if err = client.Hello(hostname); err != nil {
t.Errorf("failed to send HELO: %v", err)
}
tlsConnState := conn.ConnectionState()
if err = client.Auth(ScramSHA1PlusAuth("username", "password", &tlsConnState)); err != nil {
t.Errorf("failed to authenticate: %v", err)
}
}
func TestAuthSCRAMSHA256PLUS_OK(t *testing.T) {
hostname := "127.0.0.1"
port := "2591"
go func() {
startSMTPServer(true, hostname, port, sha256.New)
}()
time.Sleep(time.Millisecond * 500)
cert, err := tls.X509KeyPair(localhostCert, localhostKey)
if err != nil {
fmt.Printf("error creating TLS cert: %s", err)
return
}
tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true}
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port), &tlsConfig)
if err != nil {
t.Errorf("failed to dial server: %v", err)
}
client, err := NewClient(conn, hostname)
if err != nil {
t.Errorf("failed to create client: %v", err)
}
if err = client.Hello(hostname); err != nil {
t.Errorf("failed to send HELO: %v", err)
}
tlsConnState := conn.ConnectionState()
if err = client.Auth(ScramSHA256PlusAuth("username", "password", &tlsConnState)); err != nil {
t.Errorf("failed to authenticate: %v", err)
}
}
func TestAuthSCRAMSHA1_fail(t *testing.T) { func TestAuthSCRAMSHA1_fail(t *testing.T) {
hostname := "127.0.0.1" hostname := "127.0.0.1"
port := "2587" port := "2587"
go func() { go func() {
startSMTPServer(true, hostname, port) startSMTPServer(false, hostname, port, sha1.New)
}() }()
time.Sleep(time.Millisecond * 500) time.Sleep(time.Millisecond * 500)
@ -439,7 +513,7 @@ func TestAuthSCRAMSHA1_fail(t *testing.T) {
if err = client.Hello(hostname); err != nil { if err = client.Hello(hostname); err != nil {
t.Errorf("failed to send HELO: %v", err) t.Errorf("failed to send HELO: %v", err)
} }
if err = client.Auth(ScramSHA1Auth("username", "password")); err == nil { if err = client.Auth(ScramSHA1Auth("username", "invalid")); err == nil {
t.Errorf("expected auth error, got nil") t.Errorf("expected auth error, got nil")
} }
} }
@ -449,7 +523,7 @@ func TestAuthSCRAMSHA256_fail(t *testing.T) {
port := "2588" port := "2588"
go func() { go func() {
startSMTPServer(true, hostname, port) startSMTPServer(false, hostname, port, sha256.New)
}() }()
time.Sleep(time.Millisecond * 500) time.Sleep(time.Millisecond * 500)
@ -464,7 +538,73 @@ func TestAuthSCRAMSHA256_fail(t *testing.T) {
if err = client.Hello(hostname); err != nil { if err = client.Hello(hostname); err != nil {
t.Errorf("failed to send HELO: %v", err) t.Errorf("failed to send HELO: %v", err)
} }
if err = client.Auth(ScramSHA256Auth("username", "password")); err == nil { if err = client.Auth(ScramSHA256Auth("username", "invalid")); err == nil {
t.Errorf("expected auth error, got nil")
}
}
func TestAuthSCRAMSHA1PLUS_fail(t *testing.T) {
hostname := "127.0.0.1"
port := "2592"
go func() {
startSMTPServer(true, hostname, port, sha1.New)
}()
time.Sleep(time.Millisecond * 500)
cert, err := tls.X509KeyPair(localhostCert, localhostKey)
if err != nil {
fmt.Printf("error creating TLS cert: %s", err)
return
}
tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true}
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port), &tlsConfig)
if err != nil {
t.Errorf("failed to dial server: %v", err)
}
client, err := NewClient(conn, hostname)
if err != nil {
t.Errorf("failed to create client: %v", err)
}
if err = client.Hello(hostname); err != nil {
t.Errorf("failed to send HELO: %v", err)
}
tlsConnState := conn.ConnectionState()
if err = client.Auth(ScramSHA1PlusAuth("username", "invalid", &tlsConnState)); err == nil {
t.Errorf("expected auth error, got nil")
}
}
func TestAuthSCRAMSHA256PLUS_fail(t *testing.T) {
hostname := "127.0.0.1"
port := "2593"
go func() {
startSMTPServer(true, hostname, port, sha1.New)
}()
time.Sleep(time.Millisecond * 500)
cert, err := tls.X509KeyPair(localhostCert, localhostKey)
if err != nil {
fmt.Printf("error creating TLS cert: %s", err)
return
}
tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true}
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port), &tlsConfig)
if err != nil {
t.Errorf("failed to dial server: %v", err)
}
client, err := NewClient(conn, hostname)
if err != nil {
t.Errorf("failed to create client: %v", err)
}
if err = client.Hello(hostname); err != nil {
t.Errorf("failed to send HELO: %v", err)
}
tlsConnState := conn.ConnectionState()
if err = client.Auth(ScramSHA256PlusAuth("username", "invalid", &tlsConnState)); err == nil {
t.Errorf("expected auth error, got nil") t.Errorf("expected auth error, got nil")
} }
} }
@ -1652,7 +1792,8 @@ type testSCRAMSMTPServer struct {
nonce string nonce string
hostname string hostname string
port string port string
shouldFail bool tlsServer bool
h func() hash.Hash
} }
func (s *testSCRAMSMTPServer) handleConnection(conn net.Conn) { func (s *testSCRAMSMTPServer) handleConnection(conn net.Conn) {
@ -1705,7 +1846,8 @@ func (s *testSCRAMSMTPServer) handleConnection(conn net.Conn) {
} }
authMechanism := parts[1] authMechanism := parts[1]
if authMechanism != "SCRAM-SHA-1" && authMechanism != "SCRAM-SHA-256" { if authMechanism != "SCRAM-SHA-1" && authMechanism != "SCRAM-SHA-256" &&
authMechanism != "SCRAM-SHA-1-PLUS" && authMechanism != "SCRAM-SHA-256-PLUS" {
_ = writeLine("504 Unrecognized authentication mechanism") _ = writeLine("504 Unrecognized authentication mechanism")
return return
} }
@ -1729,6 +1871,7 @@ func (s *testSCRAMSMTPServer) handleSCRAMAuth(conn net.Conn) {
} }
return writer.Flush() return writer.Flush()
} }
var authMsg string
data, err := reader.ReadString('\n') data, err := reader.ReadString('\n')
if err != nil { if err != nil {
@ -1746,10 +1889,14 @@ func (s *testSCRAMSMTPServer) handleSCRAMAuth(conn net.Conn) {
_ = writeLine("535 Authentication failed - expected 4 parts") _ = writeLine("535 Authentication failed - expected 4 parts")
return return
} }
if splits[0] != "n" { if !s.tlsServer && splits[0] != "n" {
_ = writeLine("535 Authentication failed - expected n to be in the first part") _ = writeLine("535 Authentication failed - expected n to be in the first part")
return return
} }
if s.tlsServer && !strings.HasPrefix(splits[0], "p=") {
_ = writeLine("535 Authentication failed - expected p= to be in the first part")
return
}
if splits[2] != "n=username" { if splits[2] != "n=username" {
_ = writeLine("535 Authentication failed - expected n=username to be in the third part") _ = writeLine("535 Authentication failed - expected n=username to be in the third part")
return return
@ -1758,6 +1905,8 @@ func (s *testSCRAMSMTPServer) handleSCRAMAuth(conn net.Conn) {
_ = writeLine("535 Authentication failed - expected r= to be in the fourth part") _ = writeLine("535 Authentication failed - expected r= to be in the fourth part")
return return
} }
authMsg = splits[2] + "," + splits[3]
clientNonce := s.extractNonce(string(decodedMessage)) clientNonce := s.extractNonce(string(decodedMessage))
if clientNonce == "" { if clientNonce == "" {
_ = writeLine("535 Authentication failed") _ = writeLine("535 Authentication failed")
@ -1765,8 +1914,10 @@ func (s *testSCRAMSMTPServer) handleSCRAMAuth(conn net.Conn) {
} }
s.nonce = clientNonce + "server_nonce" s.nonce = clientNonce + "server_nonce"
serverFirstMessage := fmt.Sprintf("r=%s,s=%s,i=0", s.nonce, "salt") serverFirstMessage := fmt.Sprintf("r=%s,s=%s,i=4096", s.nonce,
base64.StdEncoding.EncodeToString([]byte("salt")))
_ = writeLine(fmt.Sprintf("334 %s", base64.StdEncoding.EncodeToString([]byte(serverFirstMessage)))) _ = writeLine(fmt.Sprintf("334 %s", base64.StdEncoding.EncodeToString([]byte(serverFirstMessage))))
authMsg = authMsg + "," + serverFirstMessage
data, err = reader.ReadString('\n') data, err = reader.ReadString('\n')
if err != nil { if err != nil {
@ -1780,10 +1931,32 @@ func (s *testSCRAMSMTPServer) handleSCRAMAuth(conn net.Conn) {
return return
} }
splits = strings.Split(string(decodedFinalMessage), ",") splits = strings.Split(string(decodedFinalMessage), ",")
if splits[0] != "c=biws" {
if !s.tlsServer && splits[0] != "c=biws" {
_ = writeLine("535 Authentication failed - expected c=biws to be in the first part") _ = writeLine("535 Authentication failed - expected c=biws to be in the first part")
return return
} }
if s.tlsServer {
if !strings.HasPrefix(splits[0], "c=") {
_ = writeLine("535 Authentication failed - expected c= to be in the first part")
return
}
channelBind, err := base64.StdEncoding.DecodeString(splits[0][2:])
if err != nil {
_ = writeLine("535 Authentication failed - base64 channel bind is not valid - " + err.Error())
return
}
if !strings.HasPrefix(string(channelBind), "p=") {
_ = writeLine("535 Authentication failed - expected channel binding to start with p=-")
return
}
cbType := string(channelBind[2:])
if !strings.HasPrefix(cbType, "tls-unique") && !strings.HasPrefix(cbType, "tls-exporter") {
_ = writeLine("535 Authentication failed - expected channel binding type tls-unique or tls-exporter")
return
}
}
if !strings.HasPrefix(splits[1], "r=") { if !strings.HasPrefix(splits[1], "r=") {
_ = writeLine("535 Authentication failed - expected r to be in the second part") _ = writeLine("535 Authentication failed - expected r to be in the second part")
return return
@ -1797,10 +1970,27 @@ func (s *testSCRAMSMTPServer) handleSCRAMAuth(conn net.Conn) {
return return
} }
if s.shouldFail { authMsg = authMsg + "," + splits[0] + "," + splits[1]
saltedPwd := pbkdf2.Key([]byte("password"), []byte("salt"), 4096, s.h().Size(), s.h)
mac := hmac.New(s.h, saltedPwd)
mac.Write([]byte("Server Key"))
skey := mac.Sum(nil)
mac.Reset()
mac = hmac.New(s.h, skey)
mac.Write([]byte(authMsg))
ssig := mac.Sum(nil)
mac.Reset()
serverFinalMessage := fmt.Sprintf("v=%s", base64.StdEncoding.EncodeToString(ssig))
_ = writeLine(fmt.Sprintf("334 %s", base64.StdEncoding.EncodeToString([]byte(serverFinalMessage))))
_, err = reader.ReadString('\n')
if err != nil {
_ = writeLine("535 Authentication failed") _ = writeLine("535 Authentication failed")
return return
} }
_ = writeLine("235 Authentication successful") _ = writeLine("235 Authentication successful")
} }
@ -1814,11 +2004,12 @@ func (s *testSCRAMSMTPServer) extractNonce(message string) string {
return "" return ""
} }
func startSMTPServer(shouldFail bool, hostname, port string) { func startSMTPServer(tlsServer bool, hostname, port string, h func() hash.Hash) {
server := &testSCRAMSMTPServer{ server := &testSCRAMSMTPServer{
hostname: hostname, hostname: hostname,
port: port, port: port,
shouldFail: shouldFail, tlsServer: tlsServer,
h: h,
} }
listener, err := net.Listen("tcp", fmt.Sprintf("%s:%s", hostname, port)) listener, err := net.Listen("tcp", fmt.Sprintf("%s:%s", hostname, port))
if err != nil { if err != nil {
@ -1827,12 +2018,23 @@ func startSMTPServer(shouldFail bool, hostname, port string) {
defer func() { defer func() {
_ = listener.Close() _ = listener.Close()
}() }()
cert, err := tls.X509KeyPair(localhostCert, localhostKey)
if err != nil {
fmt.Printf("error creating TLS cert: %s", err)
return
}
tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}}
for { for {
conn, err := listener.Accept() conn, err := listener.Accept()
if err != nil { if err != nil {
fmt.Printf("Failed to accept connection: %v", err) fmt.Printf("Failed to accept connection: %v", err)
continue continue
} }
if server.tlsServer {
conn = tls.Server(conn, &tlsConfig)
}
go server.handleConnection(conn) go server.handleConnection(conn)
} }
} }