diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index 4fe0481..a7f3236 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -14,29 +14,9 @@ package smtp import ( - "bufio" "bytes" - "crypto/hmac" - "crypto/sha1" - "crypto/sha256" - "crypto/tls" - "crypto/x509" - "encoding/base64" - "flag" - "fmt" - "hash" - "io" - "net" - "net/textproto" - "os" - "runtime" - "strings" + "errors" "testing" - "time" - - "golang.org/x/crypto/pbkdf2" - - "github.com/wneessen/go-mail/log" ) type authTest struct { @@ -180,91 +160,152 @@ var authTests = []authTest{ } func TestAuth(t *testing.T) { -testLoop: - for i, test := range authTests { - name, resp, err := test.auth.Start(&ServerInfo{"testserver", true, nil}) - if name != test.name { - t.Errorf("#%d got name %s, expected %s", i, name, test.name) - } - if !bytes.Equal(resp, []byte(test.responses[0])) { - t.Errorf("#%d got response %s, expected %s", i, resp, test.responses[0]) - } - if err != nil { - t.Errorf("#%d error: %s", i, err) - } - for j := range test.challenges { - challenge := []byte(test.challenges[j]) - expected := []byte(test.responses[j+1]) - sf := test.sf[j] - resp, err := test.auth.Next(challenge, true) - if err != nil && !sf { - t.Errorf("#%d error: %s", i, err) - continue testLoop - } - if test.hasNonce { - if !bytes.HasPrefix(resp, expected) { - t.Errorf("#%d got response: %s, expected response to start with: %s", i, resp, expected) + t.Run("Auth for all supported auth methods", func(t *testing.T) { + for i, tt := range authTests { + t.Run(tt.name, func(t *testing.T) { + name, resp, err := tt.auth.Start(&ServerInfo{"testserver", true, nil}) + if name != tt.name { + t.Errorf("test #%d got name %s, expected %s", i, name, tt.name) } - continue testLoop - } - if !bytes.Equal(resp, expected) { - t.Errorf("#%d got %s, expected %s", i, resp, expected) - continue testLoop - } - _, err = test.auth.Next([]byte("2.7.0 Authentication successful"), false) - if err != nil { - t.Errorf("#%d success message error: %s", i, err) - } + if len(tt.responses) <= 0 { + t.Fatalf("test #%d got no responses, expected at least one", i) + } + if !bytes.Equal(resp, []byte(tt.responses[0])) { + t.Errorf("#%d got response %s, expected %s", i, resp, tt.responses[0]) + } + if err != nil { + t.Errorf("#%d error: %s", i, err) + } + testLoop: + for j := range tt.challenges { + challenge := []byte(tt.challenges[j]) + expected := []byte(tt.responses[j+1]) + sf := tt.sf[j] + resp, err := tt.auth.Next(challenge, true) + if err != nil && !sf { + t.Errorf("#%d error: %s", i, err) + continue testLoop + } + if tt.hasNonce { + if !bytes.HasPrefix(resp, expected) { + t.Errorf("#%d got response: %s, expected response to start with: %s", i, resp, expected) + } + continue testLoop + } + if !bytes.Equal(resp, expected) { + t.Errorf("#%d got %s, expected %s", i, resp, expected) + continue testLoop + } + _, err = tt.auth.Next([]byte("2.7.0 Authentication successful"), false) + if err != nil { + t.Errorf("#%d success message error: %s", i, err) + } + } + }) } - } + }) } -func TestAuthPlain(t *testing.T) { +func TestPlainAuth(t *testing.T) { tests := []struct { - authName string - server *ServerInfo - err string + name string + authName string + server *ServerInfo + shouldFail bool + wantErr error }{ { - authName: "servername", - server: &ServerInfo{Name: "servername", TLS: true}, + name: "PLAIN auth succeeds", + authName: "servername", + server: &ServerInfo{Name: "servername", TLS: true}, + shouldFail: false, }, { // OK to use PlainAuth on localhost without TLS - authName: "localhost", - server: &ServerInfo{Name: "localhost", TLS: false}, + name: "PLAIN on localhost is allowed to go unencrypted", + authName: "localhost", + server: &ServerInfo{Name: "localhost", TLS: false}, + shouldFail: false, }, { // NOT OK on non-localhost, even if server says PLAIN is OK. // (We don't know that the server is the real server.) - authName: "servername", - server: &ServerInfo{Name: "servername", Auth: []string{"PLAIN"}}, - err: "unencrypted connection", + name: "PLAIN on non-localhost is not allowed to go unencrypted", + authName: "servername", + server: &ServerInfo{Name: "servername", Auth: []string{"PLAIN"}}, + shouldFail: true, + wantErr: ErrUnencrypted, }, { - authName: "servername", - server: &ServerInfo{Name: "servername", Auth: []string{"CRAM-MD5"}}, - err: "unencrypted connection", + name: "PLAIN on non-localhost with no PLAIN announcement, is not allowed to go unencrypted", + authName: "servername", + server: &ServerInfo{Name: "servername", Auth: []string{"CRAM-MD5"}}, + shouldFail: true, + wantErr: ErrUnencrypted, }, { - authName: "servername", - server: &ServerInfo{Name: "attacker", TLS: true}, - err: "wrong host name", + name: "PLAIN with wrong hostname", + authName: "servername", + server: &ServerInfo{Name: "attacker", TLS: true}, + shouldFail: true, + wantErr: ErrWrongHostname, }, } - for i, tt := range tests { - auth := PlainAuth("foo", "bar", "baz", tt.authName, false) - _, _, err := auth.Start(tt.server) - got := "" + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + identity := "foo" + user := "toni.tester@example.com" + pass := "v3ryS3Cur3P4ssw0rd" + auth := PlainAuth(identity, user, pass, tt.authName, false) + method, resp, err := auth.Start(tt.server) + if err != nil && !tt.shouldFail { + t.Errorf("plain authentication failed: %s", err) + } + if err == nil && tt.shouldFail { + t.Error("plain authentication was expected to fail") + } + if tt.wantErr != nil { + if !errors.Is(err, tt.wantErr) { + t.Errorf("expected error to be: %s, got: %s", tt.wantErr, err) + } + return + } + if method != "PLAIN" { + t.Errorf("expected method return to be: %q, got: %q", "PLAIN", method) + } + if !bytes.Equal([]byte(identity+"\x00"+user+"\x00"+pass), resp) { + t.Errorf("expected response to be: %q, got: %q", identity+"\x00"+user+"\x00"+pass, resp) + } + }) + } + t.Run("PLAIN sends second server response should fail", func(t *testing.T) { + identity := "foo" + user := "toni.tester@example.com" + pass := "v3ryS3Cur3P4ssw0rd" + server := &ServerInfo{Name: "servername", TLS: true} + auth := PlainAuth(identity, user, pass, "servername", false) + method, resp, err := auth.Start(server) if err != nil { - got = err.Error() + t.Fatalf("plain authentication failed: %s", err) } - if got != tt.err { - t.Errorf("%d. got error = %q; want %q", i, got, tt.err) + if method != "PLAIN" { + t.Errorf("expected method return to be: %q, got: %q", "PLAIN", method) } - } + if !bytes.Equal([]byte(identity+"\x00"+user+"\x00"+pass), resp) { + t.Errorf("expected response to be: %q, got: %q", identity+"\x00"+user+"\x00"+pass, resp) + } + _, err = auth.Next([]byte("nonsense"), true) + if err == nil { + t.Fatal("expected second server challange to fail") + } + if !errors.Is(err, ErrUnexpectedServerChallange) { + t.Errorf("expected error to be: %s, got: %s", ErrUnexpectedServerChallange, err) + } + }) } +/* + func TestAuthPlainNoEnc(t *testing.T) { tests := []struct { authName string @@ -2555,3 +2596,6 @@ func startSMTPServer(tlsServer bool, hostname, port string, h func() hash.Hash) go server.handleConnection(conn) } } + + +*/