Refactor and streamline authentication tests

Improved the structure and readability of the authentication tests by using subtests for each scenario, ensuring better isolation and clearer failure reporting. Removed unnecessary imports and redundant code, reducing complexity and enhancing maintainability.
This commit is contained in:
Winni Neessen 2024-11-01 19:22:28 +01:00
parent a3fe2f88d5
commit 99c4378107
Signed by: wneessen
GPG key ID: 385AC9889632126E

View file

@ -14,29 +14,9 @@
package smtp
import (
"bufio"
"bytes"
"crypto/hmac"
"crypto/sha1"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"flag"
"fmt"
"hash"
"io"
"net"
"net/textproto"
"os"
"runtime"
"strings"
"errors"
"testing"
"time"
"golang.org/x/crypto/pbkdf2"
"github.com/wneessen/go-mail/log"
)
type authTest struct {
@ -180,28 +160,33 @@ var authTests = []authTest{
}
func TestAuth(t *testing.T) {
testLoop:
for i, test := range authTests {
name, resp, err := test.auth.Start(&ServerInfo{"testserver", true, nil})
if name != test.name {
t.Errorf("#%d got name %s, expected %s", i, name, test.name)
t.Run("Auth for all supported auth methods", func(t *testing.T) {
for i, tt := range authTests {
t.Run(tt.name, func(t *testing.T) {
name, resp, err := tt.auth.Start(&ServerInfo{"testserver", true, nil})
if name != tt.name {
t.Errorf("test #%d got name %s, expected %s", i, name, tt.name)
}
if !bytes.Equal(resp, []byte(test.responses[0])) {
t.Errorf("#%d got response %s, expected %s", i, resp, test.responses[0])
if len(tt.responses) <= 0 {
t.Fatalf("test #%d got no responses, expected at least one", i)
}
if !bytes.Equal(resp, []byte(tt.responses[0])) {
t.Errorf("#%d got response %s, expected %s", i, resp, tt.responses[0])
}
if err != nil {
t.Errorf("#%d error: %s", i, err)
}
for j := range test.challenges {
challenge := []byte(test.challenges[j])
expected := []byte(test.responses[j+1])
sf := test.sf[j]
resp, err := test.auth.Next(challenge, true)
testLoop:
for j := range tt.challenges {
challenge := []byte(tt.challenges[j])
expected := []byte(tt.responses[j+1])
sf := tt.sf[j]
resp, err := tt.auth.Next(challenge, true)
if err != nil && !sf {
t.Errorf("#%d error: %s", i, err)
continue testLoop
}
if test.hasNonce {
if tt.hasNonce {
if !bytes.HasPrefix(resp, expected) {
t.Errorf("#%d got response: %s, expected response to start with: %s", i, resp, expected)
}
@ -211,60 +196,116 @@ testLoop:
t.Errorf("#%d got %s, expected %s", i, resp, expected)
continue testLoop
}
_, err = test.auth.Next([]byte("2.7.0 Authentication successful"), false)
_, err = tt.auth.Next([]byte("2.7.0 Authentication successful"), false)
if err != nil {
t.Errorf("#%d success message error: %s", i, err)
}
}
})
}
})
}
func TestAuthPlain(t *testing.T) {
func TestPlainAuth(t *testing.T) {
tests := []struct {
name string
authName string
server *ServerInfo
err string
shouldFail bool
wantErr error
}{
{
name: "PLAIN auth succeeds",
authName: "servername",
server: &ServerInfo{Name: "servername", TLS: true},
shouldFail: false,
},
{
// OK to use PlainAuth on localhost without TLS
name: "PLAIN on localhost is allowed to go unencrypted",
authName: "localhost",
server: &ServerInfo{Name: "localhost", TLS: false},
shouldFail: false,
},
{
// NOT OK on non-localhost, even if server says PLAIN is OK.
// (We don't know that the server is the real server.)
name: "PLAIN on non-localhost is not allowed to go unencrypted",
authName: "servername",
server: &ServerInfo{Name: "servername", Auth: []string{"PLAIN"}},
err: "unencrypted connection",
shouldFail: true,
wantErr: ErrUnencrypted,
},
{
name: "PLAIN on non-localhost with no PLAIN announcement, is not allowed to go unencrypted",
authName: "servername",
server: &ServerInfo{Name: "servername", Auth: []string{"CRAM-MD5"}},
err: "unencrypted connection",
shouldFail: true,
wantErr: ErrUnencrypted,
},
{
name: "PLAIN with wrong hostname",
authName: "servername",
server: &ServerInfo{Name: "attacker", TLS: true},
err: "wrong host name",
shouldFail: true,
wantErr: ErrWrongHostname,
},
}
for i, tt := range tests {
auth := PlainAuth("foo", "bar", "baz", tt.authName, false)
_, _, err := auth.Start(tt.server)
got := ""
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
identity := "foo"
user := "toni.tester@example.com"
pass := "v3ryS3Cur3P4ssw0rd"
auth := PlainAuth(identity, user, pass, tt.authName, false)
method, resp, err := auth.Start(tt.server)
if err != nil && !tt.shouldFail {
t.Errorf("plain authentication failed: %s", err)
}
if err == nil && tt.shouldFail {
t.Error("plain authentication was expected to fail")
}
if tt.wantErr != nil {
if !errors.Is(err, tt.wantErr) {
t.Errorf("expected error to be: %s, got: %s", tt.wantErr, err)
}
return
}
if method != "PLAIN" {
t.Errorf("expected method return to be: %q, got: %q", "PLAIN", method)
}
if !bytes.Equal([]byte(identity+"\x00"+user+"\x00"+pass), resp) {
t.Errorf("expected response to be: %q, got: %q", identity+"\x00"+user+"\x00"+pass, resp)
}
})
}
t.Run("PLAIN sends second server response should fail", func(t *testing.T) {
identity := "foo"
user := "toni.tester@example.com"
pass := "v3ryS3Cur3P4ssw0rd"
server := &ServerInfo{Name: "servername", TLS: true}
auth := PlainAuth(identity, user, pass, "servername", false)
method, resp, err := auth.Start(server)
if err != nil {
got = err.Error()
t.Fatalf("plain authentication failed: %s", err)
}
if got != tt.err {
t.Errorf("%d. got error = %q; want %q", i, got, tt.err)
if method != "PLAIN" {
t.Errorf("expected method return to be: %q, got: %q", "PLAIN", method)
}
if !bytes.Equal([]byte(identity+"\x00"+user+"\x00"+pass), resp) {
t.Errorf("expected response to be: %q, got: %q", identity+"\x00"+user+"\x00"+pass, resp)
}
_, err = auth.Next([]byte("nonsense"), true)
if err == nil {
t.Fatal("expected second server challange to fail")
}
if !errors.Is(err, ErrUnexpectedServerChallange) {
t.Errorf("expected error to be: %s, got: %s", ErrUnexpectedServerChallange, err)
}
})
}
/*
func TestAuthPlainNoEnc(t *testing.T) {
tests := []struct {
authName string
@ -2555,3 +2596,6 @@ func startSMTPServer(tlsServer bool, hostname, port string, h func() hash.Hash)
go server.handleConnection(conn)
}
}
*/