mirror of
https://github.com/wneessen/go-mail.git
synced 2024-11-15 02:12:55 +01:00
Refactor and streamline authentication tests
Improved the structure and readability of the authentication tests by using subtests for each scenario, ensuring better isolation and clearer failure reporting. Removed unnecessary imports and redundant code, reducing complexity and enhancing maintainability.
This commit is contained in:
parent
a3fe2f88d5
commit
99c4378107
1 changed files with 125 additions and 81 deletions
|
@ -14,29 +14,9 @@
|
|||
package smtp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/hmac"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"flag"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"net"
|
||||
"net/textproto"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/pbkdf2"
|
||||
|
||||
"github.com/wneessen/go-mail/log"
|
||||
)
|
||||
|
||||
type authTest struct {
|
||||
|
@ -180,28 +160,33 @@ var authTests = []authTest{
|
|||
}
|
||||
|
||||
func TestAuth(t *testing.T) {
|
||||
testLoop:
|
||||
for i, test := range authTests {
|
||||
name, resp, err := test.auth.Start(&ServerInfo{"testserver", true, nil})
|
||||
if name != test.name {
|
||||
t.Errorf("#%d got name %s, expected %s", i, name, test.name)
|
||||
t.Run("Auth for all supported auth methods", func(t *testing.T) {
|
||||
for i, tt := range authTests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
name, resp, err := tt.auth.Start(&ServerInfo{"testserver", true, nil})
|
||||
if name != tt.name {
|
||||
t.Errorf("test #%d got name %s, expected %s", i, name, tt.name)
|
||||
}
|
||||
if !bytes.Equal(resp, []byte(test.responses[0])) {
|
||||
t.Errorf("#%d got response %s, expected %s", i, resp, test.responses[0])
|
||||
if len(tt.responses) <= 0 {
|
||||
t.Fatalf("test #%d got no responses, expected at least one", i)
|
||||
}
|
||||
if !bytes.Equal(resp, []byte(tt.responses[0])) {
|
||||
t.Errorf("#%d got response %s, expected %s", i, resp, tt.responses[0])
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("#%d error: %s", i, err)
|
||||
}
|
||||
for j := range test.challenges {
|
||||
challenge := []byte(test.challenges[j])
|
||||
expected := []byte(test.responses[j+1])
|
||||
sf := test.sf[j]
|
||||
resp, err := test.auth.Next(challenge, true)
|
||||
testLoop:
|
||||
for j := range tt.challenges {
|
||||
challenge := []byte(tt.challenges[j])
|
||||
expected := []byte(tt.responses[j+1])
|
||||
sf := tt.sf[j]
|
||||
resp, err := tt.auth.Next(challenge, true)
|
||||
if err != nil && !sf {
|
||||
t.Errorf("#%d error: %s", i, err)
|
||||
continue testLoop
|
||||
}
|
||||
if test.hasNonce {
|
||||
if tt.hasNonce {
|
||||
if !bytes.HasPrefix(resp, expected) {
|
||||
t.Errorf("#%d got response: %s, expected response to start with: %s", i, resp, expected)
|
||||
}
|
||||
|
@ -211,60 +196,116 @@ testLoop:
|
|||
t.Errorf("#%d got %s, expected %s", i, resp, expected)
|
||||
continue testLoop
|
||||
}
|
||||
_, err = test.auth.Next([]byte("2.7.0 Authentication successful"), false)
|
||||
_, err = tt.auth.Next([]byte("2.7.0 Authentication successful"), false)
|
||||
if err != nil {
|
||||
t.Errorf("#%d success message error: %s", i, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthPlain(t *testing.T) {
|
||||
func TestPlainAuth(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
authName string
|
||||
server *ServerInfo
|
||||
err string
|
||||
shouldFail bool
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "PLAIN auth succeeds",
|
||||
authName: "servername",
|
||||
server: &ServerInfo{Name: "servername", TLS: true},
|
||||
shouldFail: false,
|
||||
},
|
||||
{
|
||||
// OK to use PlainAuth on localhost without TLS
|
||||
name: "PLAIN on localhost is allowed to go unencrypted",
|
||||
authName: "localhost",
|
||||
server: &ServerInfo{Name: "localhost", TLS: false},
|
||||
shouldFail: false,
|
||||
},
|
||||
{
|
||||
// NOT OK on non-localhost, even if server says PLAIN is OK.
|
||||
// (We don't know that the server is the real server.)
|
||||
name: "PLAIN on non-localhost is not allowed to go unencrypted",
|
||||
authName: "servername",
|
||||
server: &ServerInfo{Name: "servername", Auth: []string{"PLAIN"}},
|
||||
err: "unencrypted connection",
|
||||
shouldFail: true,
|
||||
wantErr: ErrUnencrypted,
|
||||
},
|
||||
{
|
||||
name: "PLAIN on non-localhost with no PLAIN announcement, is not allowed to go unencrypted",
|
||||
authName: "servername",
|
||||
server: &ServerInfo{Name: "servername", Auth: []string{"CRAM-MD5"}},
|
||||
err: "unencrypted connection",
|
||||
shouldFail: true,
|
||||
wantErr: ErrUnencrypted,
|
||||
},
|
||||
{
|
||||
name: "PLAIN with wrong hostname",
|
||||
authName: "servername",
|
||||
server: &ServerInfo{Name: "attacker", TLS: true},
|
||||
err: "wrong host name",
|
||||
shouldFail: true,
|
||||
wantErr: ErrWrongHostname,
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
auth := PlainAuth("foo", "bar", "baz", tt.authName, false)
|
||||
_, _, err := auth.Start(tt.server)
|
||||
got := ""
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
identity := "foo"
|
||||
user := "toni.tester@example.com"
|
||||
pass := "v3ryS3Cur3P4ssw0rd"
|
||||
auth := PlainAuth(identity, user, pass, tt.authName, false)
|
||||
method, resp, err := auth.Start(tt.server)
|
||||
if err != nil && !tt.shouldFail {
|
||||
t.Errorf("plain authentication failed: %s", err)
|
||||
}
|
||||
if err == nil && tt.shouldFail {
|
||||
t.Error("plain authentication was expected to fail")
|
||||
}
|
||||
if tt.wantErr != nil {
|
||||
if !errors.Is(err, tt.wantErr) {
|
||||
t.Errorf("expected error to be: %s, got: %s", tt.wantErr, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if method != "PLAIN" {
|
||||
t.Errorf("expected method return to be: %q, got: %q", "PLAIN", method)
|
||||
}
|
||||
if !bytes.Equal([]byte(identity+"\x00"+user+"\x00"+pass), resp) {
|
||||
t.Errorf("expected response to be: %q, got: %q", identity+"\x00"+user+"\x00"+pass, resp)
|
||||
}
|
||||
})
|
||||
}
|
||||
t.Run("PLAIN sends second server response should fail", func(t *testing.T) {
|
||||
identity := "foo"
|
||||
user := "toni.tester@example.com"
|
||||
pass := "v3ryS3Cur3P4ssw0rd"
|
||||
server := &ServerInfo{Name: "servername", TLS: true}
|
||||
auth := PlainAuth(identity, user, pass, "servername", false)
|
||||
method, resp, err := auth.Start(server)
|
||||
if err != nil {
|
||||
got = err.Error()
|
||||
t.Fatalf("plain authentication failed: %s", err)
|
||||
}
|
||||
if got != tt.err {
|
||||
t.Errorf("%d. got error = %q; want %q", i, got, tt.err)
|
||||
if method != "PLAIN" {
|
||||
t.Errorf("expected method return to be: %q, got: %q", "PLAIN", method)
|
||||
}
|
||||
if !bytes.Equal([]byte(identity+"\x00"+user+"\x00"+pass), resp) {
|
||||
t.Errorf("expected response to be: %q, got: %q", identity+"\x00"+user+"\x00"+pass, resp)
|
||||
}
|
||||
_, err = auth.Next([]byte("nonsense"), true)
|
||||
if err == nil {
|
||||
t.Fatal("expected second server challange to fail")
|
||||
}
|
||||
if !errors.Is(err, ErrUnexpectedServerChallange) {
|
||||
t.Errorf("expected error to be: %s, got: %s", ErrUnexpectedServerChallange, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/*
|
||||
|
||||
func TestAuthPlainNoEnc(t *testing.T) {
|
||||
tests := []struct {
|
||||
authName string
|
||||
|
@ -2555,3 +2596,6 @@ func startSMTPServer(tlsServer bool, hostname, port string, h func() hash.Hash)
|
|||
go server.handleConnection(conn)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
*/
|
||||
|
|
Loading…
Reference in a new issue