mirror of
https://github.com/wneessen/go-mail.git
synced 2024-12-18 17:00:38 +01:00
Add SCRAM-SHA authentication tests for SMTP
Introduce new unit tests to verify SCRAM-SHA-1 and SCRAM-SHA-256 authentication for the SMTP client. These tests cover both successful and failing authentication cases, and include a mock SMTP server to facilitate testing.
This commit is contained in:
parent
a8e89a1258
commit
03062c5183
1 changed files with 288 additions and 0 deletions
|
@ -18,6 +18,7 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
"encoding/base64"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
@ -368,6 +369,106 @@ func TestXOAuth2Error(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAuthSCRAMSHA1_OK(t *testing.T) {
|
||||||
|
hostname := "127.0.0.1"
|
||||||
|
port := "2585"
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
startSMTPServer(false, hostname, port)
|
||||||
|
}()
|
||||||
|
time.Sleep(time.Millisecond * 500)
|
||||||
|
|
||||||
|
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to dial server: %v", err)
|
||||||
|
}
|
||||||
|
client, err := NewClient(conn, hostname)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to create client: %v", err)
|
||||||
|
}
|
||||||
|
if err = client.Hello(hostname); err != nil {
|
||||||
|
t.Errorf("failed to send HELO: %v", err)
|
||||||
|
}
|
||||||
|
if err = client.Auth(ScramSHA1Auth("username", "password")); err != nil {
|
||||||
|
t.Errorf("failed to authenticate: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthSCRAMSHA256_OK(t *testing.T) {
|
||||||
|
hostname := "127.0.0.1"
|
||||||
|
port := "2586"
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
startSMTPServer(false, hostname, port)
|
||||||
|
}()
|
||||||
|
time.Sleep(time.Millisecond * 500)
|
||||||
|
|
||||||
|
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to dial server: %v", err)
|
||||||
|
}
|
||||||
|
client, err := NewClient(conn, hostname)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to create client: %v", err)
|
||||||
|
}
|
||||||
|
if err = client.Hello(hostname); err != nil {
|
||||||
|
t.Errorf("failed to send HELO: %v", err)
|
||||||
|
}
|
||||||
|
if err = client.Auth(ScramSHA256Auth("username", "password")); err != nil {
|
||||||
|
t.Errorf("failed to authenticate: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthSCRAMSHA1_fail(t *testing.T) {
|
||||||
|
hostname := "127.0.0.1"
|
||||||
|
port := "2587"
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
startSMTPServer(true, hostname, port)
|
||||||
|
}()
|
||||||
|
time.Sleep(time.Millisecond * 500)
|
||||||
|
|
||||||
|
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to dial server: %v", err)
|
||||||
|
}
|
||||||
|
client, err := NewClient(conn, hostname)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to create client: %v", err)
|
||||||
|
}
|
||||||
|
if err = client.Hello(hostname); err != nil {
|
||||||
|
t.Errorf("failed to send HELO: %v", err)
|
||||||
|
}
|
||||||
|
if err = client.Auth(ScramSHA1Auth("username", "password")); err == nil {
|
||||||
|
t.Errorf("expected auth error, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthSCRAMSHA256_fail(t *testing.T) {
|
||||||
|
hostname := "127.0.0.1"
|
||||||
|
port := "2588"
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
startSMTPServer(true, hostname, port)
|
||||||
|
}()
|
||||||
|
time.Sleep(time.Millisecond * 500)
|
||||||
|
|
||||||
|
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to dial server: %v", err)
|
||||||
|
}
|
||||||
|
client, err := NewClient(conn, hostname)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to create client: %v", err)
|
||||||
|
}
|
||||||
|
if err = client.Hello(hostname); err != nil {
|
||||||
|
t.Errorf("failed to send HELO: %v", err)
|
||||||
|
}
|
||||||
|
if err = client.Auth(ScramSHA256Auth("username", "password")); err == nil {
|
||||||
|
t.Errorf("expected auth error, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Issue 17794: don't send a trailing space on AUTH command when there's no password.
|
// Issue 17794: don't send a trailing space on AUTH command when there's no password.
|
||||||
func TestClientAuthTrimSpace(t *testing.T) {
|
func TestClientAuthTrimSpace(t *testing.T) {
|
||||||
server := "220 hello world\r\n" +
|
server := "220 hello world\r\n" +
|
||||||
|
@ -1541,3 +1642,190 @@ func SkipFlaky(t testing.TB, issue int) {
|
||||||
t.Skipf("skipping known flaky test without the -flaky flag; see golang.org/issue/%d", issue)
|
t.Skipf("skipping known flaky test without the -flaky flag; see golang.org/issue/%d", issue)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// testSCRAMSMTPServer represents a test server for SCRAM-based SMTP authentication.
|
||||||
|
// It does not do any acutal computation of the challanges but verifies that the expected
|
||||||
|
// fields are present. We have actual real authentication tests for all SCRAM modes in the
|
||||||
|
// go-mail client_test.go
|
||||||
|
type testSCRAMSMTPServer struct {
|
||||||
|
authMechanism string
|
||||||
|
nonce string
|
||||||
|
hostname string
|
||||||
|
port string
|
||||||
|
shouldFail bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testSCRAMSMTPServer) handleConnection(conn net.Conn) {
|
||||||
|
defer func() {
|
||||||
|
_ = conn.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
reader := bufio.NewReader(conn)
|
||||||
|
writer := bufio.NewWriter(conn)
|
||||||
|
writeLine := func(data string) error {
|
||||||
|
_, err := writer.WriteString(data + "\r\n")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to write line: %w", err)
|
||||||
|
}
|
||||||
|
return writer.Flush()
|
||||||
|
}
|
||||||
|
writeOK := func() {
|
||||||
|
_ = writeLine("250 2.0.0 OK")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := writeLine("220 go-mail test server ready ESMTP"); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := reader.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data = strings.TrimSpace(data)
|
||||||
|
if strings.HasPrefix(data, "EHLO") {
|
||||||
|
_ = writeLine(fmt.Sprintf("250-%s", s.hostname))
|
||||||
|
_ = writeLine("250-AUTH SCRAM-SHA-1 SCRAM-SHA-256")
|
||||||
|
writeOK()
|
||||||
|
} else {
|
||||||
|
_ = writeLine("500 Invalid command")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
data, err = reader.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("failed to read data: %v", err)
|
||||||
|
}
|
||||||
|
data = strings.TrimSpace(data)
|
||||||
|
if strings.HasPrefix(data, "AUTH") {
|
||||||
|
parts := strings.Split(data, " ")
|
||||||
|
if len(parts) < 2 {
|
||||||
|
_ = writeLine("500 Syntax error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
authMechanism := parts[1]
|
||||||
|
if authMechanism != "SCRAM-SHA-1" && authMechanism != "SCRAM-SHA-256" {
|
||||||
|
_ = writeLine("504 Unrecognized authentication mechanism")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.authMechanism = authMechanism
|
||||||
|
_ = writeLine("334 ")
|
||||||
|
s.handleSCRAMAuth(conn)
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
_ = writeLine("500 Invalid command")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testSCRAMSMTPServer) handleSCRAMAuth(conn net.Conn) {
|
||||||
|
reader := bufio.NewReader(conn)
|
||||||
|
writer := bufio.NewWriter(conn)
|
||||||
|
writeLine := func(data string) error {
|
||||||
|
_, err := writer.WriteString(data + "\r\n")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to write line: %w", err)
|
||||||
|
}
|
||||||
|
return writer.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := reader.ReadString('\n')
|
||||||
|
clientMessage := strings.TrimSpace(data)
|
||||||
|
decodedMessage, err := base64.StdEncoding.DecodeString(clientMessage)
|
||||||
|
if err != nil {
|
||||||
|
_ = writeLine("535 Authentication failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
splits := strings.Split(string(decodedMessage), ",")
|
||||||
|
if len(splits) != 4 {
|
||||||
|
_ = writeLine("535 Authentication failed - expected 4 parts")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if splits[0] != "n" {
|
||||||
|
_ = writeLine("535 Authentication failed - expected n to be in the first part")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if splits[2] != "n=username" {
|
||||||
|
_ = writeLine("535 Authentication failed - expected n=username to be in the third part")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(splits[3], "r=") {
|
||||||
|
_ = writeLine("535 Authentication failed - expected r= to be in the fourth part")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
clientNonce := s.extractNonce(string(decodedMessage))
|
||||||
|
if clientNonce == "" {
|
||||||
|
_ = writeLine("535 Authentication failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.nonce = clientNonce + "server_nonce"
|
||||||
|
serverFirstMessage := fmt.Sprintf("r=%s,s=%s,i=0", s.nonce, "salt")
|
||||||
|
_ = writeLine(fmt.Sprintf("334 %s", base64.StdEncoding.EncodeToString([]byte(serverFirstMessage))))
|
||||||
|
|
||||||
|
data, err = reader.ReadString('\n')
|
||||||
|
clientFinalMessage := strings.TrimSpace(data)
|
||||||
|
decodedFinalMessage, err := base64.StdEncoding.DecodeString(clientFinalMessage)
|
||||||
|
if err != nil {
|
||||||
|
_ = writeLine("535 Authentication failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
splits = strings.Split(string(decodedFinalMessage), ",")
|
||||||
|
if splits[0] != "c=biws" {
|
||||||
|
_ = writeLine("535 Authentication failed - expected c=biws to be in the first part")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(splits[1], "r=") {
|
||||||
|
_ = writeLine("535 Authentication failed - expected r to be in the second part")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !strings.Contains(splits[1], "server_nonce") {
|
||||||
|
_ = writeLine("535 Authentication failed - expected server_nonce to be in the second part")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(splits[2], "p=") {
|
||||||
|
_ = writeLine("535 Authentication failed - expected p to be in the third part")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.shouldFail {
|
||||||
|
_ = writeLine("535 Authentication failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = writeLine("235 Authentication successful")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testSCRAMSMTPServer) extractNonce(message string) string {
|
||||||
|
parts := strings.Split(message, ",")
|
||||||
|
for _, part := range parts {
|
||||||
|
if strings.HasPrefix(part, "r=") {
|
||||||
|
return part[2:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func startSMTPServer(shouldFail bool, hostname, port string) {
|
||||||
|
server := &testSCRAMSMTPServer{
|
||||||
|
hostname: hostname,
|
||||||
|
port: port,
|
||||||
|
shouldFail: shouldFail,
|
||||||
|
}
|
||||||
|
listener, err := net.Listen("tcp", fmt.Sprintf("%s:%s", hostname, port))
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to start SMTP server: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = listener.Close()
|
||||||
|
}()
|
||||||
|
for {
|
||||||
|
conn, err := listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Failed to accept connection: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
go server.handleConnection(conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue