From a3fe2f88d58b7107bf88bbd1ca05086fb4369bdd Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Fri, 1 Nov 2024 18:47:21 +0100 Subject: [PATCH 01/45] Add test for MessageID on nil SendError This test ensures that when MessageID is called on a nil SendError, it returns an empty string. This additional check helps verify the correct behavior of the MessageID method under nil conditions. --- senderror_test.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/senderror_test.go b/senderror_test.go index 3cd3a76..8584e32 100644 --- a/senderror_test.go +++ b/senderror_test.go @@ -174,6 +174,12 @@ func TestSendError_MessageID(t *testing.T) { t.Errorf("sendError expected empty message-id, got: %s", sendErr.MessageID()) } }) + t.Run("TestSendError_MessageID on nil error should return empty", func(t *testing.T) { + var sendErr *SendError + if sendErr.MessageID() != "" { + t.Error("expected empty message-id on nil-senderror") + } + }) } func TestSendError_Msg(t *testing.T) { From 99c4378107d42a6fb9ae9e41446b24d14bee1088 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Fri, 1 Nov 2024 19:22:28 +0100 Subject: [PATCH 02/45] Refactor and streamline authentication tests Improved the structure and readability of the authentication tests by using subtests for each scenario, ensuring better isolation and clearer failure reporting. Removed unnecessary imports and redundant code, reducing complexity and enhancing maintainability. --- smtp/smtp_test.go | 206 ++++++++++++++++++++++++++++------------------ 1 file changed, 125 insertions(+), 81 deletions(-) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index 4fe0481..a7f3236 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -14,29 +14,9 @@ package smtp import ( - "bufio" "bytes" - "crypto/hmac" - "crypto/sha1" - "crypto/sha256" - "crypto/tls" - "crypto/x509" - "encoding/base64" - "flag" - "fmt" - "hash" - "io" - "net" - "net/textproto" - "os" - "runtime" - "strings" + "errors" "testing" - "time" - - "golang.org/x/crypto/pbkdf2" - - "github.com/wneessen/go-mail/log" ) type authTest struct { @@ -180,91 +160,152 @@ var authTests = []authTest{ } func TestAuth(t *testing.T) { -testLoop: - for i, test := range authTests { - name, resp, err := test.auth.Start(&ServerInfo{"testserver", true, nil}) - if name != test.name { - t.Errorf("#%d got name %s, expected %s", i, name, test.name) - } - if !bytes.Equal(resp, []byte(test.responses[0])) { - t.Errorf("#%d got response %s, expected %s", i, resp, test.responses[0]) - } - if err != nil { - t.Errorf("#%d error: %s", i, err) - } - for j := range test.challenges { - challenge := []byte(test.challenges[j]) - expected := []byte(test.responses[j+1]) - sf := test.sf[j] - resp, err := test.auth.Next(challenge, true) - if err != nil && !sf { - t.Errorf("#%d error: %s", i, err) - continue testLoop - } - if test.hasNonce { - if !bytes.HasPrefix(resp, expected) { - t.Errorf("#%d got response: %s, expected response to start with: %s", i, resp, expected) + t.Run("Auth for all supported auth methods", func(t *testing.T) { + for i, tt := range authTests { + t.Run(tt.name, func(t *testing.T) { + name, resp, err := tt.auth.Start(&ServerInfo{"testserver", true, nil}) + if name != tt.name { + t.Errorf("test #%d got name %s, expected %s", i, name, tt.name) } - continue testLoop - } - if !bytes.Equal(resp, expected) { - t.Errorf("#%d got %s, expected %s", i, resp, expected) - continue testLoop - } - _, err = test.auth.Next([]byte("2.7.0 Authentication successful"), false) - if err != nil { - t.Errorf("#%d success message error: %s", i, err) - } + if len(tt.responses) <= 0 { + t.Fatalf("test #%d got no responses, expected at least one", i) + } + if !bytes.Equal(resp, []byte(tt.responses[0])) { + t.Errorf("#%d got response %s, expected %s", i, resp, tt.responses[0]) + } + if err != nil { + t.Errorf("#%d error: %s", i, err) + } + testLoop: + for j := range tt.challenges { + challenge := []byte(tt.challenges[j]) + expected := []byte(tt.responses[j+1]) + sf := tt.sf[j] + resp, err := tt.auth.Next(challenge, true) + if err != nil && !sf { + t.Errorf("#%d error: %s", i, err) + continue testLoop + } + if tt.hasNonce { + if !bytes.HasPrefix(resp, expected) { + t.Errorf("#%d got response: %s, expected response to start with: %s", i, resp, expected) + } + continue testLoop + } + if !bytes.Equal(resp, expected) { + t.Errorf("#%d got %s, expected %s", i, resp, expected) + continue testLoop + } + _, err = tt.auth.Next([]byte("2.7.0 Authentication successful"), false) + if err != nil { + t.Errorf("#%d success message error: %s", i, err) + } + } + }) } - } + }) } -func TestAuthPlain(t *testing.T) { +func TestPlainAuth(t *testing.T) { tests := []struct { - authName string - server *ServerInfo - err string + name string + authName string + server *ServerInfo + shouldFail bool + wantErr error }{ { - authName: "servername", - server: &ServerInfo{Name: "servername", TLS: true}, + name: "PLAIN auth succeeds", + authName: "servername", + server: &ServerInfo{Name: "servername", TLS: true}, + shouldFail: false, }, { // OK to use PlainAuth on localhost without TLS - authName: "localhost", - server: &ServerInfo{Name: "localhost", TLS: false}, + name: "PLAIN on localhost is allowed to go unencrypted", + authName: "localhost", + server: &ServerInfo{Name: "localhost", TLS: false}, + shouldFail: false, }, { // NOT OK on non-localhost, even if server says PLAIN is OK. // (We don't know that the server is the real server.) - authName: "servername", - server: &ServerInfo{Name: "servername", Auth: []string{"PLAIN"}}, - err: "unencrypted connection", + name: "PLAIN on non-localhost is not allowed to go unencrypted", + authName: "servername", + server: &ServerInfo{Name: "servername", Auth: []string{"PLAIN"}}, + shouldFail: true, + wantErr: ErrUnencrypted, }, { - authName: "servername", - server: &ServerInfo{Name: "servername", Auth: []string{"CRAM-MD5"}}, - err: "unencrypted connection", + name: "PLAIN on non-localhost with no PLAIN announcement, is not allowed to go unencrypted", + authName: "servername", + server: &ServerInfo{Name: "servername", Auth: []string{"CRAM-MD5"}}, + shouldFail: true, + wantErr: ErrUnencrypted, }, { - authName: "servername", - server: &ServerInfo{Name: "attacker", TLS: true}, - err: "wrong host name", + name: "PLAIN with wrong hostname", + authName: "servername", + server: &ServerInfo{Name: "attacker", TLS: true}, + shouldFail: true, + wantErr: ErrWrongHostname, }, } - for i, tt := range tests { - auth := PlainAuth("foo", "bar", "baz", tt.authName, false) - _, _, err := auth.Start(tt.server) - got := "" + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + identity := "foo" + user := "toni.tester@example.com" + pass := "v3ryS3Cur3P4ssw0rd" + auth := PlainAuth(identity, user, pass, tt.authName, false) + method, resp, err := auth.Start(tt.server) + if err != nil && !tt.shouldFail { + t.Errorf("plain authentication failed: %s", err) + } + if err == nil && tt.shouldFail { + t.Error("plain authentication was expected to fail") + } + if tt.wantErr != nil { + if !errors.Is(err, tt.wantErr) { + t.Errorf("expected error to be: %s, got: %s", tt.wantErr, err) + } + return + } + if method != "PLAIN" { + t.Errorf("expected method return to be: %q, got: %q", "PLAIN", method) + } + if !bytes.Equal([]byte(identity+"\x00"+user+"\x00"+pass), resp) { + t.Errorf("expected response to be: %q, got: %q", identity+"\x00"+user+"\x00"+pass, resp) + } + }) + } + t.Run("PLAIN sends second server response should fail", func(t *testing.T) { + identity := "foo" + user := "toni.tester@example.com" + pass := "v3ryS3Cur3P4ssw0rd" + server := &ServerInfo{Name: "servername", TLS: true} + auth := PlainAuth(identity, user, pass, "servername", false) + method, resp, err := auth.Start(server) if err != nil { - got = err.Error() + t.Fatalf("plain authentication failed: %s", err) } - if got != tt.err { - t.Errorf("%d. got error = %q; want %q", i, got, tt.err) + if method != "PLAIN" { + t.Errorf("expected method return to be: %q, got: %q", "PLAIN", method) } - } + if !bytes.Equal([]byte(identity+"\x00"+user+"\x00"+pass), resp) { + t.Errorf("expected response to be: %q, got: %q", identity+"\x00"+user+"\x00"+pass, resp) + } + _, err = auth.Next([]byte("nonsense"), true) + if err == nil { + t.Fatal("expected second server challange to fail") + } + if !errors.Is(err, ErrUnexpectedServerChallange) { + t.Errorf("expected error to be: %s, got: %s", ErrUnexpectedServerChallange, err) + } + }) } +/* + func TestAuthPlainNoEnc(t *testing.T) { tests := []struct { authName string @@ -2555,3 +2596,6 @@ func startSMTPServer(tlsServer bool, hostname, port string, h func() hash.Hash) go server.handleConnection(conn) } } + + +*/ From 3cfd20576d9d37bee26db07868f2b442a12eafbc Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Sun, 3 Nov 2024 16:13:54 +0100 Subject: [PATCH 03/45] Rename and expand `TestPlainAuth_noEnc` with additional checks Refactor the test function `TestPlainAuth_noEnc` to include subtests for better organization and add more comprehensive error handling. This improves clarity and robustness by verifying various authentication scenarios and expected outcomes. --- smtp/smtp_test.go | 106 +++++++++++++++++++++++++++++++++++----------- 1 file changed, 81 insertions(+), 25 deletions(-) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index a7f3236..3022018 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -304,48 +304,104 @@ func TestPlainAuth(t *testing.T) { }) } -/* - -func TestAuthPlainNoEnc(t *testing.T) { +func TestPlainAuth_noEnc(t *testing.T) { tests := []struct { - authName string - server *ServerInfo - err string + name string + authName string + server *ServerInfo + shouldFail bool + wantErr error }{ { - authName: "servername", - server: &ServerInfo{Name: "servername", TLS: true}, + name: "PLAIN-NOENC auth succeeds", + authName: "servername", + server: &ServerInfo{Name: "servername", TLS: true}, + shouldFail: false, }, { // OK to use PlainAuth on localhost without TLS - authName: "localhost", - server: &ServerInfo{Name: "localhost", TLS: false}, + name: "PLAIN-NOENC on localhost is allowed to go unencrypted", + authName: "localhost", + server: &ServerInfo{Name: "localhost", TLS: false}, + shouldFail: false, }, { - // Also OK on non-TLS secured connections. The NoEnc mechanism is meant to allow - // non-encrypted connections. - authName: "servername", - server: &ServerInfo{Name: "servername", Auth: []string{"PLAIN"}}, + // ALSO OK on non-localhost. This auth mode is specificly for that. + name: "PLAIN-NOENC on non-localhost is allowed to go unencrypted", + authName: "servername", + server: &ServerInfo{Name: "servername", Auth: []string{"PLAIN"}}, + shouldFail: false, }, { - authName: "servername", - server: &ServerInfo{Name: "attacker", TLS: true}, - err: "wrong host name", + name: "PLAIN-NOENC on non-localhost with no PLAIN announcement, is allowed to go unencrypted", + authName: "servername", + server: &ServerInfo{Name: "servername", Auth: []string{"CRAM-MD5"}}, + shouldFail: false, + }, + { + name: "PLAIN-NOENC with wrong hostname", + authName: "servername", + server: &ServerInfo{Name: "attacker", TLS: true}, + shouldFail: true, + wantErr: ErrWrongHostname, }, } - for i, tt := range tests { - auth := PlainAuth("foo", "bar", "baz", tt.authName, true) - _, _, err := auth.Start(tt.server) - got := "" + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + identity := "foo" + user := "toni.tester@example.com" + pass := "v3ryS3Cur3P4ssw0rd" + auth := PlainAuth(identity, user, pass, tt.authName, true) + method, resp, err := auth.Start(tt.server) + if err != nil && !tt.shouldFail { + t.Errorf("plain authentication failed: %s", err) + } + if err == nil && tt.shouldFail { + t.Error("plain authentication was expected to fail") + } + if tt.wantErr != nil { + if !errors.Is(err, tt.wantErr) { + t.Errorf("expected error to be: %s, got: %s", tt.wantErr, err) + } + return + } + if method != "PLAIN" { + t.Errorf("expected method return to be: %q, got: %q", "PLAIN", method) + } + if !bytes.Equal([]byte(identity+"\x00"+user+"\x00"+pass), resp) { + t.Errorf("expected response to be: %q, got: %q", identity+"\x00"+user+"\x00"+pass, resp) + } + }) + } + t.Run("PLAIN-NOENC sends second server response should fail", func(t *testing.T) { + identity := "foo" + user := "toni.tester@example.com" + pass := "v3ryS3Cur3P4ssw0rd" + server := &ServerInfo{Name: "servername", TLS: true} + auth := PlainAuth(identity, user, pass, "servername", true) + method, resp, err := auth.Start(server) if err != nil { - got = err.Error() + t.Fatalf("plain authentication failed: %s", err) } - if got != tt.err { - t.Errorf("%d. got error = %q; want %q", i, got, tt.err) + if method != "PLAIN" { + t.Errorf("expected method return to be: %q, got: %q", "PLAIN", method) } - } + if !bytes.Equal([]byte(identity+"\x00"+user+"\x00"+pass), resp) { + t.Errorf("expected response to be: %q, got: %q", identity+"\x00"+user+"\x00"+pass, resp) + } + _, err = auth.Next([]byte("nonsense"), true) + if err == nil { + t.Fatal("expected second server challange to fail") + } + if !errors.Is(err, ErrUnexpectedServerChallange) { + t.Errorf("expected error to be: %s, got: %s", ErrUnexpectedServerChallange, err) + } + }) } +/* + + func TestAuthLogin(t *testing.T) { tests := []struct { authName string From 2391010e3a77aef60666408e9eadba7675dc2758 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Thu, 7 Nov 2024 20:58:20 +0100 Subject: [PATCH 04/45] Rename parameter for consistency in auth functions Updated the parameter name `allowUnEnc` to `allowUnenc` in both `LoginAuth` and `PlainAuth` functions to maintain consistent naming conventions. This change improves code readability and follows standard naming practices. --- smtp/auth_login.go | 4 ++-- smtp/auth_plain.go | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/smtp/auth_login.go b/smtp/auth_login.go index b5f1065..c40e48c 100644 --- a/smtp/auth_login.go +++ b/smtp/auth_login.go @@ -36,8 +36,8 @@ type loginAuth struct { // LoginAuth will only send the credentials if the connection is using TLS // or is connected to localhost. Otherwise authentication will fail with an // error, without sending the credentials. -func LoginAuth(username, password, host string, allowUnEnc bool) Auth { - return &loginAuth{username, password, host, 0, allowUnEnc} +func LoginAuth(username, password, host string, allowUnenc bool) Auth { + return &loginAuth{username, password, host, 0, allowUnenc} } // Start begins the SMTP authentication process by validating server's TLS status and hostname. diff --git a/smtp/auth_plain.go b/smtp/auth_plain.go index f2ea8ac..39e2a2f 100644 --- a/smtp/auth_plain.go +++ b/smtp/auth_plain.go @@ -28,8 +28,8 @@ type plainAuth struct { // PlainAuth will only send the credentials if the connection is using TLS // or is connected to localhost. Otherwise authentication will fail with an // error, without sending the credentials. -func PlainAuth(identity, username, password, host string, allowUnEnc bool) Auth { - return &plainAuth{identity, username, password, host, allowUnEnc} +func PlainAuth(identity, username, password, host string, allowUnenc bool) Auth { + return &plainAuth{identity, username, password, host, allowUnenc} } func (a *plainAuth) Start(server *ServerInfo) (string, []byte, error) { From 410343496c138daab115fac155d8f0d6b74a10ad Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Thu, 7 Nov 2024 21:14:52 +0100 Subject: [PATCH 05/45] Refactor and expand TestLoginAuth Rename and uncomment TestLoginAuth with more test cases, ensuring coverage for successful and failed authentication scenarios, including checks for unencrypted logins and server response errors. This improves test robustness and coverage. --- smtp/smtp_test.go | 95 +++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 91 insertions(+), 4 deletions(-) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index 3022018..271ea5c 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -399,6 +399,97 @@ func TestPlainAuth_noEnc(t *testing.T) { }) } +func TestLoginAuth(t *testing.T) { + tests := []struct { + name string + authName string + server *ServerInfo + shouldFail bool + wantErr error + }{ + { + name: "LOGIN auth succeeds", + authName: "servername", + server: &ServerInfo{Name: "servername", TLS: true}, + shouldFail: false, + }, + { + // OK to use PlainAuth on localhost without TLS + name: "LOGIN on localhost is allowed to go unencrypted", + authName: "localhost", + server: &ServerInfo{Name: "localhost", TLS: false}, + shouldFail: false, + }, + { + // NOT OK on non-localhost, even if server says LOGIN is OK. + // (We don't know that the server is the real server.) + name: "LOGIN on non-localhost is not allowed to go unencrypted", + authName: "servername", + server: &ServerInfo{Name: "servername", Auth: []string{"LOGIN"}}, + shouldFail: true, + wantErr: ErrUnencrypted, + }, + { + name: "LOGIN on non-localhost with no LOGIN announcement, is not allowed to go unencrypted", + authName: "servername", + server: &ServerInfo{Name: "servername", Auth: []string{"CRAM-MD5"}}, + shouldFail: true, + wantErr: ErrUnencrypted, + }, + { + name: "LOGIN with wrong hostname", + authName: "servername", + server: &ServerInfo{Name: "attacker", TLS: true}, + shouldFail: true, + wantErr: ErrWrongHostname, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + user := "toni.tester@example.com" + pass := "v3ryS3Cur3P4ssw0rd" + auth := LoginAuth(user, pass, tt.authName, false) + method, _, err := auth.Start(tt.server) + if err != nil && !tt.shouldFail { + t.Errorf("plain authentication failed: %s", err) + } + if err == nil && tt.shouldFail { + t.Error("plain authentication was expected to fail") + } + if tt.wantErr != nil { + if !errors.Is(err, tt.wantErr) { + t.Errorf("expected error to be: %s, got: %s", tt.wantErr, err) + } + return + } + if method != "LOGIN" { + t.Errorf("expected method return to be: %q, got: %q", "LOGIN", method) + } + resp, err := auth.Next([]byte(user), true) + if err != nil { + t.Errorf("failed on first server challange: %s", err) + } + if !bytes.Equal([]byte(user), resp) { + t.Errorf("expected response to first challange to be: %q, got: %q", user, resp) + } + resp, err = auth.Next([]byte(pass), true) + if err != nil { + t.Errorf("failed on second server challange: %s", err) + } + if !bytes.Equal([]byte(pass), resp) { + t.Errorf("expected response to second challange to be: %q, got: %q", pass, resp) + } + resp, err = auth.Next([]byte("nonsense"), true) + if err == nil { + t.Error("expected third server challange to fail, but didn't") + } + if !errors.Is(err, ErrUnexpectedServerResponse) { + t.Errorf("expected error to be: %s, got: %s", ErrUnexpectedServerResponse, err) + } + }) + } +} + /* @@ -408,10 +499,6 @@ func TestAuthLogin(t *testing.T) { server *ServerInfo err string }{ - { - authName: "servername", - server: &ServerInfo{Name: "servername", TLS: true}, - }, { // OK to use LoginAuth on localhost without TLS authName: "localhost", From 4221d48644d85d0fd624875d2b08cf04c6b29fe4 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Thu, 7 Nov 2024 21:31:24 +0100 Subject: [PATCH 06/45] Update login authentication test cases in smtp_test.go Renamed the test functions and improved the test structure for login authentication checks. Added subtests to provide clear descriptions and enhance error checking. --- smtp/smtp_test.go | 176 ++++++++++++++++++++++++---------------------- 1 file changed, 90 insertions(+), 86 deletions(-) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index 271ea5c..be578fb 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -451,10 +451,98 @@ func TestLoginAuth(t *testing.T) { auth := LoginAuth(user, pass, tt.authName, false) method, _, err := auth.Start(tt.server) if err != nil && !tt.shouldFail { - t.Errorf("plain authentication failed: %s", err) + t.Errorf("login authentication failed: %s", err) } if err == nil && tt.shouldFail { - t.Error("plain authentication was expected to fail") + t.Error("login authentication was expected to fail") + } + if tt.wantErr != nil { + if !errors.Is(err, tt.wantErr) { + t.Errorf("expected error to be: %s, got: %s", tt.wantErr, err) + } + return + } + if method != "LOGIN" { + t.Errorf("expected method return to be: %q, got: %q", "LOGIN", method) + } + resp, err := auth.Next([]byte(user), true) + if err != nil { + t.Errorf("failed on first server challange: %s", err) + } + if !bytes.Equal([]byte(user), resp) { + t.Errorf("expected response to first challange to be: %q, got: %q", user, resp) + } + resp, err = auth.Next([]byte(pass), true) + if err != nil { + t.Errorf("failed on second server challange: %s", err) + } + if !bytes.Equal([]byte(pass), resp) { + t.Errorf("expected response to second challange to be: %q, got: %q", pass, resp) + } + resp, err = auth.Next([]byte("nonsense"), true) + if err == nil { + t.Error("expected third server challange to fail, but didn't") + } + if !errors.Is(err, ErrUnexpectedServerResponse) { + t.Errorf("expected error to be: %s, got: %s", ErrUnexpectedServerResponse, err) + } + }) + } +} + +func TestLoginAuth_noEnc(t *testing.T) { + tests := []struct { + name string + authName string + server *ServerInfo + shouldFail bool + wantErr error + }{ + { + name: "LOGIN-NOENC auth succeeds", + authName: "servername", + server: &ServerInfo{Name: "servername", TLS: true}, + shouldFail: false, + }, + { + // OK to use PlainAuth on localhost without TLS + name: "LOGIN-NOENC on localhost is allowed to go unencrypted", + authName: "localhost", + server: &ServerInfo{Name: "localhost", TLS: false}, + shouldFail: false, + }, + { + // ALSO OK on non-localhost. This auth mode is specificly for that. + name: "LOGIN-NOENC on non-localhost is allowed to go unencrypted", + authName: "servername", + server: &ServerInfo{Name: "servername", Auth: []string{"LOGIN"}}, + shouldFail: false, + }, + { + name: "LOGIN-NOENC on non-localhost with no LOGIN announcement, is not allowed to go unencrypted", + authName: "servername", + server: &ServerInfo{Name: "servername", Auth: []string{"CRAM-MD5"}}, + shouldFail: false, + }, + { + name: "LOGIN-NOENC with wrong hostname", + authName: "servername", + server: &ServerInfo{Name: "attacker", TLS: true}, + shouldFail: true, + wantErr: ErrWrongHostname, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + user := "toni.tester@example.com" + pass := "v3ryS3Cur3P4ssw0rd" + auth := LoginAuth(user, pass, tt.authName, true) + method, _, err := auth.Start(tt.server) + if err != nil && !tt.shouldFail { + t.Errorf("login authentication failed: %s", err) + } + if err == nil && tt.shouldFail { + t.Error("login authentication was expected to fail") } if tt.wantErr != nil { if !errors.Is(err, tt.wantErr) { @@ -493,91 +581,7 @@ func TestLoginAuth(t *testing.T) { /* -func TestAuthLogin(t *testing.T) { - tests := []struct { - authName string - server *ServerInfo - err string - }{ - { - // OK to use LoginAuth on localhost without TLS - authName: "localhost", - server: &ServerInfo{Name: "localhost", TLS: false}, - }, - { - // NOT OK on non-localhost, even if server says PLAIN is OK. - // (We don't know that the server is the real server.) - authName: "servername", - server: &ServerInfo{Name: "servername", Auth: []string{"LOGIN"}}, - err: "unencrypted connection", - }, - { - authName: "servername", - server: &ServerInfo{Name: "servername", Auth: []string{"CRAM-MD5"}}, - err: "unencrypted connection", - }, - { - authName: "servername", - server: &ServerInfo{Name: "attacker", TLS: true}, - err: "wrong host name", - }, - } - for i, tt := range tests { - auth := LoginAuth("foo", "bar", tt.authName, false) - _, _, err := auth.Start(tt.server) - got := "" - if err != nil { - got = err.Error() - } - if got != tt.err { - t.Errorf("%d. got error = %q; want %q", i, got, tt.err) - } - } -} -func TestAuthLoginNoEnc(t *testing.T) { - tests := []struct { - authName string - server *ServerInfo - err string - }{ - { - authName: "servername", - server: &ServerInfo{Name: "servername", TLS: true}, - }, - { - // OK to use LoginAuth on localhost without TLS - authName: "localhost", - server: &ServerInfo{Name: "localhost", TLS: false}, - }, - { - // Also OK on non-TLS secured connections. The NoEnc mechanism is meant to allow - // non-encrypted connections. - authName: "servername", - server: &ServerInfo{Name: "servername", Auth: []string{"LOGIN"}}, - }, - { - authName: "servername", - server: &ServerInfo{Name: "servername", Auth: []string{"CRAM-MD5"}}, - }, - { - authName: "servername", - server: &ServerInfo{Name: "attacker", TLS: true}, - err: "wrong host name", - }, - } - for i, tt := range tests { - auth := LoginAuth("foo", "bar", tt.authName, true) - _, _, err := auth.Start(tt.server) - got := "" - if err != nil { - got = err.Error() - } - if got != tt.err { - t.Errorf("%d. got error = %q; want %q", i, got, tt.err) - } - } -} func TestXOAuth2OK(t *testing.T) { server := []string{ From b03fbb4ae8675dcf7ffd12730918efd079189674 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Thu, 7 Nov 2024 22:42:23 +0100 Subject: [PATCH 07/45] Add test server for SMTP authentication Added a simple SMTP test server with basic features like PLAIN, LOGIN, and NOENC authentication. It can start, handle connections, and simulate authentication success or failure. Included support for TLS with a generated localhost certificate. --- smtp/smtp_test.go | 496 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 496 insertions(+) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index be578fb..efc209d 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -14,11 +14,69 @@ package smtp import ( + "bufio" "bytes" + "context" + "crypto/tls" "errors" + "fmt" + "net" + "strings" + "sync/atomic" "testing" + "time" ) +const ( + // TestServerProto is the protocol used for the simple SMTP test server + TestServerProto = "tcp" + // TestServerAddr is the address the simple SMTP test server listens on + TestServerAddr = "127.0.0.1" + // TestServerPortBase is the base port for the simple SMTP test server + TestServerPortBase = 12025 +) + +// PortAdder is an atomic counter used to increment port numbers for the test SMTP server instances. +var PortAdder atomic.Int32 + +// localhostCert is a PEM-encoded TLS cert generated from src/crypto/tls: +// +// go run generate_cert.go --rsa-bits 1024 --host 127.0.0.1,::1,example.com \ +// --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h +var localhostCert = []byte(` +-----BEGIN CERTIFICATE----- +MIICFDCCAX2gAwIBAgIRAK0xjnaPuNDSreeXb+z+0u4wDQYJKoZIhvcNAQELBQAw +EjEQMA4GA1UEChMHQWNtZSBDbzAgFw03MDAxMDEwMDAwMDBaGA8yMDg0MDEyOTE2 +MDAwMFowEjEQMA4GA1UEChMHQWNtZSBDbzCBnzANBgkqhkiG9w0BAQEFAAOBjQAw +gYkCgYEA0nFbQQuOWsjbGtejcpWz153OlziZM4bVjJ9jYruNw5n2Ry6uYQAffhqa +JOInCmmcVe2siJglsyH9aRh6vKiobBbIUXXUU1ABd56ebAzlt0LobLlx7pZEMy30 +LqIi9E6zmL3YvdGzpYlkFRnRrqwEtWYbGBf3znO250S56CCWH2UCAwEAAaNoMGYw +DgYDVR0PAQH/BAQDAgKkMBMGA1UdJQQMMAoGCCsGAQUFBwMBMA8GA1UdEwEB/wQF +MAMBAf8wLgYDVR0RBCcwJYILZXhhbXBsZS5jb22HBH8AAAGHEAAAAAAAAAAAAAAA +AAAAAAEwDQYJKoZIhvcNAQELBQADgYEAbZtDS2dVuBYvb+MnolWnCNqvw1w5Gtgi +NmvQQPOMgM3m+oQSCPRTNGSg25e1Qbo7bgQDv8ZTnq8FgOJ/rbkyERw2JckkHpD4 +n4qcK27WkEDBtQFlPihIM8hLIuzWoi/9wygiElTy/tVL3y7fGCvY2/k1KBthtZGF +tN8URjVmyEo= +-----END CERTIFICATE-----`) + +// localhostKey is the private key for localhostCert. +var localhostKey = []byte(testingKey(` +-----BEGIN RSA TESTING KEY----- +MIICXgIBAAKBgQDScVtBC45ayNsa16NylbPXnc6XOJkzhtWMn2Niu43DmfZHLq5h +AB9+Gpok4icKaZxV7ayImCWzIf1pGHq8qKhsFshRddRTUAF3np5sDOW3QuhsuXHu +lkQzLfQuoiL0TrOYvdi90bOliWQVGdGurAS1ZhsYF/fOc7bnRLnoIJYfZQIDAQAB +AoGBAMst7OgpKyFV6c3JwyI/jWqxDySL3caU+RuTTBaodKAUx2ZEmNJIlx9eudLA +kucHvoxsM/eRxlxkhdFxdBcwU6J+zqooTnhu/FE3jhrT1lPrbhfGhyKnUrB0KKMM +VY3IQZyiehpxaeXAwoAou6TbWoTpl9t8ImAqAMY8hlULCUqlAkEA+9+Ry5FSYK/m +542LujIcCaIGoG1/Te6Sxr3hsPagKC2rH20rDLqXwEedSFOpSS0vpzlPAzy/6Rbb +PHTJUhNdwwJBANXkA+TkMdbJI5do9/mn//U0LfrCR9NkcoYohxfKz8JuhgRQxzF2 +6jpo3q7CdTuuRixLWVfeJzcrAyNrVcBq87cCQFkTCtOMNC7fZnCTPUv+9q1tcJyB +vNjJu3yvoEZeIeuzouX9TJE21/33FaeDdsXbRhQEj23cqR38qFHsF1qAYNMCQQDP +QXLEiJoClkR2orAmqjPLVhR3t2oB3INcnEjLNSq8LHyQEfXyaFfu4U9l5+fRPL2i +jiC0k/9L5dHUsF0XZothAkEA23ddgRs+Id/HxtojqqUT27B8MT/IGNrYsp4DvS/c +qgkeluku4GjxRlDMBuXk94xOBEinUs+p/hwP1Alll80Tpg== +-----END RSA TESTING KEY-----`)) + type authTest struct { auth Auth challenges []string @@ -302,6 +360,59 @@ func TestPlainAuth(t *testing.T) { t.Errorf("expected error to be: %s, got: %s", ErrUnexpectedServerChallange, err) } }) + t.Run("PLAIN authentication on test server", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort}, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + auth := PlainAuth("", "user", "pass", TestServerAddr, false) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to connect to test server: %s", err) + } + if err = client.Auth(auth); err != nil { + t.Errorf("failed to authenticate to test server: %s", err) + } + }) + t.Run("PLAIN authentication on test server should fail", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnAuth: true, + FeatureSet: featureSet, + ListenPort: serverPort}, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + auth := PlainAuth("", "user", "pass", TestServerAddr, false) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to connect to test server: %s", err) + } + if err = client.Auth(auth); err == nil { + t.Errorf("expected authentication to fail") + } + }) } func TestPlainAuth_noEnc(t *testing.T) { @@ -397,6 +508,59 @@ func TestPlainAuth_noEnc(t *testing.T) { t.Errorf("expected error to be: %s, got: %s", ErrUnexpectedServerChallange, err) } }) + t.Run("PLAIN-NOENC authentication on test server", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort}, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + auth := PlainAuth("", "user", "pass", TestServerAddr, true) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to connect to test server: %s", err) + } + if err = client.Auth(auth); err != nil { + t.Errorf("failed to authenticate to test server: %s", err) + } + }) + t.Run("PLAIN-NOENC authentication on test server should fail", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnAuth: true, + FeatureSet: featureSet, + ListenPort: serverPort}, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + auth := PlainAuth("", "user", "pass", TestServerAddr, true) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to connect to test server: %s", err) + } + if err = client.Auth(auth); err == nil { + t.Errorf("expected authentication to fail") + } + }) } func TestLoginAuth(t *testing.T) { @@ -488,6 +652,59 @@ func TestLoginAuth(t *testing.T) { } }) } + t.Run("LOGIN authentication on test server", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-AUTH LOGIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort}, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + auth := LoginAuth("user", "pass", TestServerAddr, false) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to connect to test server: %s", err) + } + if err = client.Auth(auth); err != nil { + t.Errorf("failed to authenticate to test server: %s", err) + } + }) + t.Run("LOGIN authentication on test server should fail", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-AUTH LOGIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnAuth: true, + FeatureSet: featureSet, + ListenPort: serverPort}, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + auth := LoginAuth("user", "pass", TestServerAddr, false) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to connect to test server: %s", err) + } + if err = client.Auth(auth); err == nil { + t.Errorf("expected authentication to fail") + } + }) } func TestLoginAuth_noEnc(t *testing.T) { @@ -576,6 +793,59 @@ func TestLoginAuth_noEnc(t *testing.T) { } }) } + t.Run("LOGIN-NOENC authentication on test server", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-AUTH LOGIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort}, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + auth := LoginAuth("user", "pass", TestServerAddr, true) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to connect to test server: %s", err) + } + if err = client.Auth(auth); err != nil { + t.Errorf("failed to authenticate to test server: %s", err) + } + }) + t.Run("LOGIN-NOENC authentication on test server should fail", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-AUTH LOGIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnAuth: true, + FeatureSet: featureSet, + ListenPort: serverPort}, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + auth := LoginAuth("user", "pass", TestServerAddr, true) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to connect to test server: %s", err) + } + if err = client.Auth(auth); err == nil { + t.Errorf("expected authentication to fail") + } + }) } /* @@ -2746,3 +3016,229 @@ func startSMTPServer(tlsServer bool, hostname, port string, h func() hash.Hash) */ +// testingKey replaces the substring "TESTING KEY" with "PRIVATE KEY" in the given string s. +func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") } + +// serverProps represents the configuration properties for the SMTP server. +type serverProps struct { + FailOnAuth bool + FailOnDataInit bool + FailOnDataClose bool + FailOnHelo bool + FailOnMailFrom bool + FailOnNoop bool + FailOnQuit bool + FailOnReset bool + FailOnSTARTTLS bool + FailTemp bool + FeatureSet string + ListenPort int + SSLListener bool + IsTLS bool + SupportDSN bool +} + +// simpleSMTPServer starts a simple TCP server that resonds to SMTP commands. +// The provided featureSet represents in what the server responds to EHLO command +// failReset controls if a RSET succeeds +func simpleSMTPServer(ctx context.Context, t *testing.T, props *serverProps) error { + t.Helper() + if props == nil { + return fmt.Errorf("no server properties provided") + } + + var listener net.Listener + var err error + if props.SSLListener { + keypair, err := tls.X509KeyPair(localhostCert, localhostKey) + if err != nil { + return fmt.Errorf("failed to read TLS keypair: %w", err) + } + tlsConfig := &tls.Config{Certificates: []tls.Certificate{keypair}} + listener, err = tls.Listen(TestServerProto, fmt.Sprintf("%s:%d", TestServerAddr, props.ListenPort), + tlsConfig) + if err != nil { + t.Fatalf("failed to create TLS listener: %s", err) + } + } else { + listener, err = net.Listen(TestServerProto, fmt.Sprintf("%s:%d", TestServerAddr, props.ListenPort)) + } + if err != nil { + return fmt.Errorf("unable to listen on %s://%s: %w (SSL: %t)", TestServerProto, TestServerAddr, err, + props.SSLListener) + } + + defer func() { + if err := listener.Close(); err != nil { + t.Logf("failed to close listener: %s", err) + } + }() + + for { + select { + case <-ctx.Done(): + return nil + default: + connection, err := listener.Accept() + var opErr *net.OpError + if err != nil { + if errors.As(err, &opErr) && opErr.Temporary() { + continue + } + return fmt.Errorf("unable to accept connection: %w", err) + } + handleTestServerConnection(connection, t, props) + } + } +} + +func handleTestServerConnection(connection net.Conn, t *testing.T, props *serverProps) { + t.Helper() + if !props.IsTLS { + t.Cleanup(func() { + if err := connection.Close(); err != nil { + t.Logf("failed to close connection: %s", err) + } + }) + } + + reader := bufio.NewReader(connection) + writer := bufio.NewWriter(connection) + + writeLine := func(data string) { + _, err := writer.WriteString(data + "\r\n") + if err != nil { + t.Logf("failed to write line: %s", err) + } + _ = writer.Flush() + } + writeOK := func() { + writeLine("250 2.0.0 OK") + } + + if !props.IsTLS { + writeLine("220 go-mail test server ready ESMTP") + } + + for { + data, err := reader.ReadString('\n') + if err != nil { + break + } + time.Sleep(time.Millisecond) + + var datastring string + data = strings.TrimSpace(data) + switch { + case strings.HasPrefix(data, "EHLO"), strings.HasPrefix(data, "HELO"): + if len(strings.Split(data, " ")) != 2 { + writeLine("501 Syntax: EHLO hostname") + break + } + if props.FailOnHelo { + writeLine("500 5.5.2 Error: fail on HELO") + break + } + writeLine("250-localhost.localdomain\r\n" + props.FeatureSet) + case strings.HasPrefix(data, "MAIL FROM:"): + if props.FailOnMailFrom { + writeLine("500 5.5.2 Error: fail on MAIL FROM") + break + } + from := strings.TrimPrefix(data, "MAIL FROM:") + from = strings.ReplaceAll(from, "BODY=8BITMIME", "") + from = strings.ReplaceAll(from, "SMTPUTF8", "") + if props.SupportDSN { + from = strings.ReplaceAll(from, "RET=FULL", "") + } + from = strings.TrimSpace(from) + if !strings.EqualFold(from, "") { + writeLine(fmt.Sprintf("503 5.1.2 Invalid from: %s", from)) + break + } + writeOK() + case strings.HasPrefix(data, "RCPT TO:"): + to := strings.TrimPrefix(data, "RCPT TO:") + if props.SupportDSN { + to = strings.ReplaceAll(to, "NOTIFY=FAILURE,SUCCESS", "") + } + to = strings.TrimSpace(to) + if !strings.EqualFold(to, "") { + writeLine(fmt.Sprintf("500 5.1.2 Invalid to: %s", to)) + break + } + writeOK() + case strings.HasPrefix(data, "AUTH"): + if props.FailOnAuth { + writeLine("535 5.7.8 Error: authentication failed") + break + } + writeLine("235 2.7.0 Authentication successful") + case strings.EqualFold(data, "DATA"): + if props.FailOnDataInit { + writeLine("503 5.5.1 Error: fail on DATA init") + break + } + writeLine("354 End data with .") + for { + ddata, derr := reader.ReadString('\n') + if derr != nil { + t.Logf("failed to read data from connection: %s", derr) + break + } + ddata = strings.TrimSpace(ddata) + if ddata == "." { + if props.FailOnDataClose { + writeLine("500 5.0.0 Error during DATA transmission") + break + } + if props.FailTemp { + writeLine("451 4.3.0 Error: fail on DATA close") + break + } + writeLine("250 2.0.0 Ok: queued as 1234567890") + break + } + datastring += ddata + "\n" + } + case strings.EqualFold(data, "noop"): + if props.FailOnNoop { + writeLine("500 5.0.0 Error: fail on NOOP") + break + } + writeOK() + case strings.EqualFold(data, "vrfy"): + writeOK() + case strings.EqualFold(data, "rset"): + if props.FailOnReset { + writeLine("500 5.1.2 Error: reset failed") + break + } + writeOK() + case strings.EqualFold(data, "quit"): + if props.FailOnQuit { + writeLine("500 5.1.2 Error: quit failed") + break + } + writeLine("221 2.0.0 Bye") + return + case strings.EqualFold(data, "starttls"): + if props.FailOnSTARTTLS { + writeLine("500 5.1.2 Error: starttls failed") + break + } + keypair, err := tls.X509KeyPair(localhostCert, localhostKey) + if err != nil { + writeLine("500 5.1.2 Error: starttls failed - " + err.Error()) + break + } + writeLine("220 Ready to start TLS") + tlsConfig := &tls.Config{Certificates: []tls.Certificate{keypair}} + connection = tls.Server(connection, tlsConfig) + props.IsTLS = true + handleTestServerConnection(connection, t, props) + default: + writeLine("500 5.5.2 Error: bad syntax") + } + } +} From 1af17f14e12b3501226a69392a7df31cb64bcce6 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Fri, 8 Nov 2024 15:11:47 +0100 Subject: [PATCH 08/45] Add cleanup to close client connections in tests This commit enhances the cleanup process in the SMTP tests by adding t.Cleanup to close client connections. Additionally, it rewrites the TestXOAuth2 function to include more detailed sub-tests, enhancing test granularity and readability. --- smtp/smtp_test.go | 242 +++++++++++++++++++++++++++------------------- 1 file changed, 145 insertions(+), 97 deletions(-) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index efc209d..cf3760c 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -20,6 +20,7 @@ import ( "crypto/tls" "errors" "fmt" + "io" "net" "strings" "sync/atomic" @@ -382,6 +383,11 @@ func TestPlainAuth(t *testing.T) { if err != nil { t.Fatalf("failed to connect to test server: %s", err) } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client connection: %s", err) + } + }) if err = client.Auth(auth); err != nil { t.Errorf("failed to authenticate to test server: %s", err) } @@ -530,6 +536,11 @@ func TestPlainAuth_noEnc(t *testing.T) { if err != nil { t.Fatalf("failed to connect to test server: %s", err) } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client connection: %s", err) + } + }) if err = client.Auth(auth); err != nil { t.Errorf("failed to authenticate to test server: %s", err) } @@ -674,6 +685,11 @@ func TestLoginAuth(t *testing.T) { if err != nil { t.Fatalf("failed to connect to test server: %s", err) } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client connection: %s", err) + } + }) if err = client.Auth(auth); err != nil { t.Errorf("failed to authenticate to test server: %s", err) } @@ -815,6 +831,11 @@ func TestLoginAuth_noEnc(t *testing.T) { if err != nil { t.Fatalf("failed to connect to test server: %s", err) } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client connection: %s", err) + } + }) if err = client.Auth(auth); err != nil { t.Errorf("failed to authenticate to test server: %s", err) } @@ -848,98 +869,123 @@ func TestLoginAuth_noEnc(t *testing.T) { }) } +func TestXOAuth2Auth(t *testing.T) { + t.Run("XOAuth2 authentication all steps", func(t *testing.T) { + auth := XOAuth2Auth("user", "token") + proto, toserver, err := auth.Start(&ServerInfo{Name: "servername", TLS: true}) + if err != nil { + t.Fatalf("failed to start XOAuth2 authentication: %s", err) + } + if proto != "XOAUTH2" { + t.Errorf("expected protocol to be XOAUTH2, got: %q", proto) + } + expected := []byte("user=user\x01auth=Bearer token\x01\x01") + if !bytes.Equal(expected, toserver) { + t.Errorf("expected server response to be: %q, got: %q", expected, toserver) + } + resp, err := auth.Next([]byte("nonsense"), true) + if err != nil { + t.Errorf("failed on first server challange: %s", err) + } + if !bytes.Equal([]byte(""), resp) { + t.Errorf("expected server response to be empty, got: %q", resp) + } + resp, err = auth.Next([]byte("nonsense"), false) + if err != nil { + t.Errorf("failed on first server challange: %s", err) + } + }) + t.Run("XOAuth2 succeeds with faker", func(t *testing.T) { + server := []string{ + "220 Fake server ready ESMTP", + "250-fake.server", + "250-AUTH XOAUTH2", + "250 8BITMIME", + "235 2.7.0 Accepted", + } + var wrote strings.Builder + var fake faker + fake.ReadWriter = struct { + io.Reader + io.Writer + }{ + strings.NewReader(strings.Join(server, "\r\n")), + &wrote, + } + client, err := NewClient(fake, "fake.host") + if err != nil { + t.Fatalf("failed to create client on faker server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client connection: %s", err) + } + }) + + auth := XOAuth2Auth("user", "token") + if err = client.Auth(auth); err != nil { + t.Errorf("failed to authenticate to faker server: %s", err) + } + + // the Next method returns a nil response. It must not be sent. + // The client request must end with the authentication. + if !strings.HasSuffix(wrote.String(), "AUTH XOAUTH2 dXNlcj11c2VyAWF1dGg9QmVhcmVyIHRva2VuAQE=\r\n") { + t.Fatalf("got %q; want AUTH XOAUTH2 dXNlcj11c2VyAWF1dGg9QmVhcmVyIHRva2VuAQE=\r\n", wrote.String()) + } + }) + t.Run("XOAuth2 fails with faker", func(t *testing.T) { + serverResp := []string{ + "220 Fake server ready ESMTP", + "250-fake.server", + "250-AUTH XOAUTH2", + "250 8BITMIME", + "334 eyJzdGF0dXMiOiI0MDAiLCJzY2hlbWVzIjoiQmVhcmVyIiwic2NvcGUiOiJodHRwczovL21haWwuZ29vZ2xlLmNvbS8ifQ==", + "535 5.7.8 Username and Password not accepted", + "221 2.0.0 closing connection", + } + var wrote strings.Builder + var fake faker + fake.ReadWriter = struct { + io.Reader + io.Writer + }{ + strings.NewReader(strings.Join(serverResp, "\r\n")), + &wrote, + } + client, err := NewClient(fake, "fake.host") + if err != nil { + t.Fatalf("failed to create client on faker server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client connection: %s", err) + } + }) + + auth := XOAuth2Auth("user", "token") + if err = client.Auth(auth); err == nil { + t.Errorf("expected authentication to fail") + } + resp := strings.Split(wrote.String(), "\r\n") + if len(resp) != 5 { + t.Fatalf("unexpected number of client requests got %d; want 5", len(resp)) + } + if resp[1] != "AUTH XOAUTH2 dXNlcj11c2VyAWF1dGg9QmVhcmVyIHRva2VuAQE=" { + t.Fatalf("got %q; want AUTH XOAUTH2 dXNlcj11c2VyAWF1dGg9QmVhcmVyIHRva2VuAQE=", resp[1]) + } + // the Next method returns an empty response. It must be sent + if resp[2] != "" { + t.Fatalf("got %q; want empty response", resp[2]) + } + }) +} + /* -func TestXOAuth2OK(t *testing.T) { - server := []string{ - "220 Fake server ready ESMTP", - "250-fake.server", - "250-AUTH XOAUTH2", - "250 8BITMIME", - "235 2.7.0 Accepted", - } - var wrote strings.Builder - var fake faker - fake.ReadWriter = struct { - io.Reader - io.Writer - }{ - strings.NewReader(strings.Join(server, "\r\n")), - &wrote, - } - c, err := NewClient(fake, "fake.host") - if err != nil { - t.Fatalf("NewClient: %v", err) - } - defer func() { - if err := c.Close(); err != nil { - t.Errorf("failed to close client: %s", err) - } - }() - - auth := XOAuth2Auth("user", "token") - err = c.Auth(auth) - if err != nil { - t.Fatalf("XOAuth2 error: %v", err) - } - // the Next method returns a nil response. It must not be sent. - // The client request must end with the authentication. - if !strings.HasSuffix(wrote.String(), "AUTH XOAUTH2 dXNlcj11c2VyAWF1dGg9QmVhcmVyIHRva2VuAQE=\r\n") { - t.Fatalf("got %q; want AUTH XOAUTH2 dXNlcj11c2VyAWF1dGg9QmVhcmVyIHRva2VuAQE=\r\n", wrote.String()) - } -} - -func TestXOAuth2Error(t *testing.T) { - serverResp := []string{ - "220 Fake server ready ESMTP", - "250-fake.server", - "250-AUTH XOAUTH2", - "250 8BITMIME", - "334 eyJzdGF0dXMiOiI0MDAiLCJzY2hlbWVzIjoiQmVhcmVyIiwic2NvcGUiOiJodHRwczovL21haWwuZ29vZ2xlLmNvbS8ifQ==", - "535 5.7.8 Username and Password not accepted", - "221 2.0.0 closing connection", - } - var wrote strings.Builder - var fake faker - fake.ReadWriter = struct { - io.Reader - io.Writer - }{ - strings.NewReader(strings.Join(serverResp, "\r\n")), - &wrote, - } - - c, err := NewClient(fake, "fake.host") - if err != nil { - t.Fatalf("NewClient: %v", err) - } - defer func() { - if err := c.Close(); err != nil { - t.Errorf("failed to close client: %s", err) - } - }() - - auth := XOAuth2Auth("user", "token") - err = c.Auth(auth) - if err == nil { - t.Fatal("expected auth error, got nil") - } - client := strings.Split(wrote.String(), "\r\n") - if len(client) != 5 { - t.Fatalf("unexpected number of client requests got %d; want 5", len(client)) - } - if client[1] != "AUTH XOAUTH2 dXNlcj11c2VyAWF1dGg9QmVhcmVyIHRva2VuAQE=" { - t.Fatalf("got %q; want AUTH XOAUTH2 dXNlcj11c2VyAWF1dGg9QmVhcmVyIHRva2VuAQE=", client[1]) - } - // the Next method returns an empty response. It must be sent - if client[2] != "" { - t.Fatalf("got %q; want empty response", client[2]) - } -} func TestAuthSCRAMSHA1_OK(t *testing.T) { hostname := "127.0.0.1" @@ -1217,17 +1263,6 @@ func (toServerEmptyAuth) Next(_ []byte, _ bool) (toServer []byte, err error) { panic("unexpected call") } -type faker struct { - io.ReadWriter -} - -func (f faker) Close() error { return nil } -func (f faker) LocalAddr() net.Addr { return nil } -func (f faker) RemoteAddr() net.Addr { return nil } -func (f faker) SetDeadline(time.Time) error { return nil } -func (f faker) SetReadDeadline(time.Time) error { return nil } -func (f faker) SetWriteDeadline(time.Time) error { return nil } - func TestBasic(t *testing.T) { server := strings.Join(strings.Split(basicServer, "\n"), "\r\n") client := strings.Join(strings.Split(basicClient, "\n"), "\r\n") @@ -3016,6 +3051,19 @@ func startSMTPServer(tlsServer bool, hostname, port string, h func() hash.Hash) */ + +// faker is a struct embedding io.ReadWriter to simulate network connections for testing purposes. +type faker struct { + io.ReadWriter +} + +func (f faker) Close() error { return nil } +func (f faker) LocalAddr() net.Addr { return nil } +func (f faker) RemoteAddr() net.Addr { return nil } +func (f faker) SetDeadline(time.Time) error { return nil } +func (f faker) SetReadDeadline(time.Time) error { return nil } +func (f faker) SetWriteDeadline(time.Time) error { return nil } + // testingKey replaces the substring "TESTING KEY" with "PRIVATE KEY" in the given string s. func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") } From c656226fd3a0b786dbe4fe859f42215882952efd Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Fri, 8 Nov 2024 15:51:17 +0100 Subject: [PATCH 09/45] Add XOAuth2 authentication tests to SMTP package Introduces two tests for XOAuth2 authentication in the SMTP package. The first test ensures successful authentication with valid credentials, while the second test verifies that authentication fails with incorrect settings. --- smtp/smtp_test.go | 58 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index cf3760c..5846e7f 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -978,6 +978,64 @@ func TestXOAuth2Auth(t *testing.T) { t.Fatalf("got %q; want empty response", resp[2]) } }) + t.Run("XOAuth2 authentication on test server succeeds", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-AUTH XOAUTH2\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort}, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + auth := XOAuth2Auth("user", "token") + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to connect to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client connection: %s", err) + } + }) + if err = client.Auth(auth); err != nil { + t.Errorf("failed to authenticate to test server: %s", err) + } + }) + t.Run("XOAuth2 authentication on test server fails", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-AUTH XOAUTH2\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnAuth: true, + FeatureSet: featureSet, + ListenPort: serverPort}, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + auth := XOAuth2Auth("user", "token") + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to connect to test server: %s", err) + } + if err = client.Auth(auth); err == nil { + t.Errorf("expected authentication to fail") + } + }) } /* From d4c6cb506c8ef8fc3828b16e107e5535a3fe860b Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Fri, 8 Nov 2024 16:53:09 +0100 Subject: [PATCH 10/45] Add SCRAM authentication tests to smtp package Added comprehensive unit tests for SCRAM-SHA-1, SCRAM-SHA-256, and their PLUS variants. Implemented a test server to simulate various SCRAM authentication scenarios and validate both success and failure cases. --- smtp/smtp_test.go | 352 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 352 insertions(+) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index 5846e7f..70f8d04 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -17,15 +17,22 @@ import ( "bufio" "bytes" "context" + "crypto/hmac" + "crypto/sha1" + "crypto/sha256" "crypto/tls" + "encoding/base64" "errors" "fmt" + "hash" "io" "net" "strings" "sync/atomic" "testing" "time" + + "golang.org/x/crypto/pbkdf2" ) const ( @@ -1038,6 +1045,178 @@ func TestXOAuth2Auth(t *testing.T) { }) } +func TestScramAuth(t *testing.T) { + tests := []struct { + name string + tls bool + authString string + hash func() hash.Hash + isPlus bool + }{ + {"SCRAM-SHA-1 (no TLS)", false, "SCRAM-SHA-1", sha1.New, false}, + {"SCRAM-SHA-256 (no TLS)", false, "SCRAM-SHA-256", sha256.New, false}, + {"SCRAM-SHA-1 (with TLS)", true, "SCRAM-SHA-1", sha1.New, false}, + {"SCRAM-SHA-256 (with TLS)", true, "SCRAM-SHA-256", sha256.New, false}, + {"SCRAM-SHA-1-PLUS", true, "SCRAM-SHA-1-PLUS", sha1.New, true}, + {"SCRAM-SHA-256-PLUS", true, "SCRAM-SHA-256-PLUS", sha256.New, true}, + } + for _, tt := range tests { + t.Run(tt.name+" succeeds on test server", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := fmt.Sprintf("250-AUTH %s\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8", tt.authString) + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + TestSCRAM: true, + HashFunc: tt.hash, + FeatureSet: featureSet, + ListenPort: serverPort, + SSLListener: tt.tls, + IsSCRAMPlus: tt.isPlus, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + var client *Client + switch tt.tls { + case true: + cert, err := tls.X509KeyPair(localhostCert, localhostKey) + if err != nil { + fmt.Printf("error creating TLS cert: %s", err) + return + } + tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true} + conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", TestServerAddr, serverPort), &tlsConfig) + if err != nil { + t.Fatalf("failed to dial TLS server: %v", err) + } + client, err = NewClient(conn, TestServerAddr) + case false: + var err error + client, err = Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to connect to test server: %s", err) + } + } + t.Cleanup(func() { + if err := client.Close(); err != nil { + t.Errorf("failed to close client connection: %s", err) + } + }) + + var auth Auth + switch tt.authString { + case "SCRAM-SHA-1": + auth = ScramSHA1Auth("username", "password") + case "SCRAM-SHA-256": + auth = ScramSHA256Auth("username", "password") + case "SCRAM-SHA-1-PLUS": + tlsConnState, err := client.GetTLSConnectionState() + if err != nil { + t.Fatalf("failed to get TLS connection state: %s", err) + } + auth = ScramSHA1PlusAuth("username", "password", tlsConnState) + case "SCRAM-SHA-256-PLUS": + tlsConnState, err := client.GetTLSConnectionState() + if err != nil { + t.Fatalf("failed to get TLS connection state: %s", err) + } + auth = ScramSHA256PlusAuth("username", "password", tlsConnState) + default: + t.Fatalf("unexpected auth string: %s", tt.authString) + } + if err := client.Auth(auth); err != nil { + t.Errorf("failed to authenticate to test server: %s", err) + } + }) + t.Run(tt.name+" fails on test server", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := fmt.Sprintf("250-AUTH %s\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8", tt.authString) + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + TestSCRAM: true, + HashFunc: tt.hash, + FeatureSet: featureSet, + ListenPort: serverPort, + SSLListener: tt.tls, + IsSCRAMPlus: tt.isPlus, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + var client *Client + switch tt.tls { + case true: + cert, err := tls.X509KeyPair(localhostCert, localhostKey) + if err != nil { + fmt.Printf("error creating TLS cert: %s", err) + return + } + tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true} + conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", TestServerAddr, serverPort), &tlsConfig) + if err != nil { + t.Fatalf("failed to dial TLS server: %v", err) + } + client, err = NewClient(conn, TestServerAddr) + case false: + var err error + client, err = Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to connect to test server: %s", err) + } + } + + var auth Auth + switch tt.authString { + case "SCRAM-SHA-1": + auth = ScramSHA1Auth("invalid", "password") + case "SCRAM-SHA-256": + auth = ScramSHA256Auth("invalid", "password") + case "SCRAM-SHA-1-PLUS": + tlsConnState, err := client.GetTLSConnectionState() + if err != nil { + t.Fatalf("failed to get TLS connection state: %s", err) + } + auth = ScramSHA1PlusAuth("invalid", "password", tlsConnState) + case "SCRAM-SHA-256-PLUS": + tlsConnState, err := client.GetTLSConnectionState() + if err != nil { + t.Fatalf("failed to get TLS connection state: %s", err) + } + auth = ScramSHA256PlusAuth("invalid", "password", tlsConnState) + default: + t.Fatalf("unexpected auth string: %s", tt.authString) + } + if err := client.Auth(auth); err == nil { + t.Error("expected authentication to fail") + } + }) + } + t.Run("ScramAuth_Next with nonsense parameter", func(t *testing.T) { + auth := ScramSHA1Auth("username", "password") + _, err := auth.Next([]byte("x=nonsense"), true) + if err == nil { + t.Fatal("expected authentication to fail") + } + if !errors.Is(err, ErrUnexpectedServerResponse) { + t.Errorf("expected ErrUnexpectedServerResponse, got %s", err) + } + }) +} + /* @@ -3140,8 +3319,11 @@ type serverProps struct { FeatureSet string ListenPort int SSLListener bool + IsSCRAMPlus bool IsTLS bool SupportDSN bool + TestSCRAM bool + HashFunc func() hash.Hash } // simpleSMTPServer starts a simple TCP server that resonds to SMTP commands. @@ -3279,6 +3461,22 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server writeLine("535 5.7.8 Error: authentication failed") break } + if props.TestSCRAM { + parts := strings.Split(data, " ") + authMechanism := parts[1] + if authMechanism != "SCRAM-SHA-1" && authMechanism != "SCRAM-SHA-256" && + authMechanism != "SCRAM-SHA-1-PLUS" && authMechanism != "SCRAM-SHA-256-PLUS" { + writeLine("504 Unrecognized authentication mechanism") + break + } + scram := &testSCRAMSMTP{ + tlsServer: props.IsSCRAMPlus, + h: props.HashFunc, + } + writeLine("334 ") + scram.handleSCRAMAuth(connection) + break + } writeLine("235 2.7.0 Authentication successful") case strings.EqualFold(data, "DATA"): if props.FailOnDataInit { @@ -3348,3 +3546,157 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server } } } + +// testSCRAMSMTP represents a part of the test server for SCRAM-based SMTP authentication. +// It does not do any acutal computation of the challenges but verifies that the expected +// fields are present. We have actual real authentication tests for all SCRAM modes in the +// go-mail client_test.go +type testSCRAMSMTP struct { + authMechanism string + nonce string + h func() hash.Hash + tlsServer bool +} + +func (s *testSCRAMSMTP) handleSCRAMAuth(conn net.Conn) { + reader := bufio.NewReader(conn) + writer := bufio.NewWriter(conn) + writeLine := func(data string) error { + _, err := writer.WriteString(data + "\r\n") + if err != nil { + return fmt.Errorf("unable to write line: %w", err) + } + return writer.Flush() + } + var authMsg string + + data, err := reader.ReadString('\n') + if err != nil { + _ = writeLine("535 Authentication failed") + return + } + data = strings.TrimSpace(data) + decodedMessage, err := base64.StdEncoding.DecodeString(data) + if err != nil { + _ = writeLine("535 Authentication failed") + return + } + splits := strings.Split(string(decodedMessage), ",") + if len(splits) != 4 { + _ = writeLine("535 Authentication failed - expected 4 parts") + return + } + if !s.tlsServer && splits[0] != "n" { + _ = writeLine("535 Authentication failed - expected n to be in the first part") + return + } + if s.tlsServer && !strings.HasPrefix(splits[0], "p=") { + _ = writeLine("535 Authentication failed - expected p= to be in the first part") + return + } + if splits[2] != "n=username" { + _ = writeLine("535 Authentication failed - expected n=username to be in the third part") + return + } + if !strings.HasPrefix(splits[3], "r=") { + _ = writeLine("535 Authentication failed - expected r= to be in the fourth part") + return + } + authMsg = splits[2] + "," + splits[3] + + clientNonce := s.extractNonce(string(decodedMessage)) + if clientNonce == "" { + _ = writeLine("535 Authentication failed") + return + } + + s.nonce = clientNonce + "server_nonce" + serverFirstMessage := fmt.Sprintf("r=%s,s=%s,i=4096", s.nonce, + base64.StdEncoding.EncodeToString([]byte("salt"))) + _ = writeLine(fmt.Sprintf("334 %s", base64.StdEncoding.EncodeToString([]byte(serverFirstMessage)))) + authMsg = authMsg + "," + serverFirstMessage + + data, err = reader.ReadString('\n') + if err != nil { + _ = writeLine("535 Authentication failed") + return + } + data = strings.TrimSpace(data) + decodedFinalMessage, err := base64.StdEncoding.DecodeString(data) + if err != nil { + _ = writeLine("535 Authentication failed") + return + } + splits = strings.Split(string(decodedFinalMessage), ",") + + if !s.tlsServer && splits[0] != "c=biws" { + _ = writeLine("535 Authentication failed - expected c=biws to be in the first part") + return + } + if s.tlsServer { + if !strings.HasPrefix(splits[0], "c=") { + _ = writeLine("535 Authentication failed - expected c= to be in the first part") + return + } + channelBind, err := base64.StdEncoding.DecodeString(splits[0][2:]) + if err != nil { + _ = writeLine("535 Authentication failed - base64 channel bind is not valid - " + err.Error()) + return + } + if !strings.HasPrefix(string(channelBind), "p=") { + _ = writeLine("535 Authentication failed - expected channel binding to start with p=-") + return + } + cbType := string(channelBind[2:]) + if !strings.HasPrefix(cbType, "tls-unique") && !strings.HasPrefix(cbType, "tls-exporter") { + _ = writeLine("535 Authentication failed - expected channel binding type tls-unique or tls-exporter") + return + } + } + + if !strings.HasPrefix(splits[1], "r=") { + _ = writeLine("535 Authentication failed - expected r to be in the second part") + return + } + if !strings.Contains(splits[1], "server_nonce") { + _ = writeLine("535 Authentication failed - expected server_nonce to be in the second part") + return + } + if !strings.HasPrefix(splits[2], "p=") { + _ = writeLine("535 Authentication failed - expected p to be in the third part") + return + } + + authMsg = authMsg + "," + splits[0] + "," + splits[1] + saltedPwd := pbkdf2.Key([]byte("password"), []byte("salt"), 4096, s.h().Size(), s.h) + mac := hmac.New(s.h, saltedPwd) + mac.Write([]byte("Server Key")) + skey := mac.Sum(nil) + mac.Reset() + + mac = hmac.New(s.h, skey) + mac.Write([]byte(authMsg)) + ssig := mac.Sum(nil) + mac.Reset() + + serverFinalMessage := fmt.Sprintf("v=%s", base64.StdEncoding.EncodeToString(ssig)) + _ = writeLine(fmt.Sprintf("334 %s", base64.StdEncoding.EncodeToString([]byte(serverFinalMessage)))) + + _, err = reader.ReadString('\n') + if err != nil { + _ = writeLine("535 Authentication failed") + return + } + + _ = writeLine("235 Authentication successful") +} + +func (s *testSCRAMSMTP) extractNonce(message string) string { + parts := strings.Split(message, ",") + for _, part := range parts { + if strings.HasPrefix(part, "r=") { + return part[2:] + } + } + return "" +} From d6f256c29ed9e11a0adb89ed2509cb4927ffb673 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Fri, 8 Nov 2024 22:30:46 +0100 Subject: [PATCH 11/45] Fix typo in error message in normalizeString function Corrected the spelling of "failed" in the error handling branch of the normalizeString function within smtp/auth_scram.go. This change addresses a minor typographical error to ensure the error message is clear and accurate. --- smtp/auth_scram.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/smtp/auth_scram.go b/smtp/auth_scram.go index a21aef5..e27831b 100644 --- a/smtp/auth_scram.go +++ b/smtp/auth_scram.go @@ -308,7 +308,7 @@ func (a *scramAuth) normalizeUsername() (string, error) { func (a *scramAuth) normalizeString(s string) (string, error) { s, err := precis.OpaqueString.String(s) if err != nil { - return "", fmt.Errorf("failled to normalize string: %w", err) + return "", fmt.Errorf("failed to normalize string: %w", err) } if s == "" { return "", errors.New("normalized string is empty") From a1efa1a1caf81e78cda51b2a85ab377dcd5a812f Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Fri, 8 Nov 2024 22:44:10 +0100 Subject: [PATCH 12/45] Remove redundant empty string check in SCRAM normalization The existing check for an empty normalized string is unnecessary because the OpaqueString profile in precis already throws an error if an empty string is returned: https://cs.opensource.google/go/x/text/+/master:secure/precis/profiles.go;l=66 --- smtp/auth_scram.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/smtp/auth_scram.go b/smtp/auth_scram.go index e27831b..23915ae 100644 --- a/smtp/auth_scram.go +++ b/smtp/auth_scram.go @@ -310,8 +310,5 @@ func (a *scramAuth) normalizeString(s string) (string, error) { if err != nil { return "", fmt.Errorf("failed to normalize string: %w", err) } - if s == "" { - return "", errors.New("normalized string is empty") - } return s, nil } From c6d416f142e73c2ade50b467f48b4c99a883ad03 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Fri, 8 Nov 2024 23:05:31 +0100 Subject: [PATCH 13/45] Fix typo in the tls-unique channel binding comment Corrected "crypto/tl" to "crypto/tls" in the comment for better clarity and accuracy. This typo fix ensures that the code comments correctly reference the relevant Go package. --- smtp/auth_scram.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/smtp/auth_scram.go b/smtp/auth_scram.go index 23915ae..03de54c 100644 --- a/smtp/auth_scram.go +++ b/smtp/auth_scram.go @@ -154,7 +154,7 @@ func (a *scramAuth) initialClientMessage() ([]byte, error) { connState := a.tlsConnState bindData := connState.TLSUnique - // crypto/tl: no tls-unique channel binding value for this tls connection, possibly due to missing + // crypto/tls: no tls-unique channel binding value for this tls connection, possibly due to missing // extended master key support and/or resumed connection // RFC9266:122 tls-unique not defined for tls 1.3 and later if bindData == nil || connState.Version >= tls.VersionTLS13 { From 92ab51b13d28ad9266b9d7ecf55ceddf77cda49f Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Sat, 9 Nov 2024 00:06:33 +0100 Subject: [PATCH 14/45] Add comprehensive tests for scramAuth error handling Introduce new test cases to validate the error handling of the scramAuth methods. These tests cover scenarios such as invalid characters in usernames, empty inputs, and edge cases like broken rand.Reader. --- smtp/smtp_test.go | 186 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 186 insertions(+) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index 70f8d04..238176a 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -18,6 +18,7 @@ import ( "bytes" "context" "crypto/hmac" + "crypto/rand" "crypto/sha1" "crypto/sha256" "crypto/tls" @@ -1217,6 +1218,182 @@ func TestScramAuth(t *testing.T) { }) } +func TestScramAuth_normalizeString(t *testing.T) { + t.Run("normalizeString with invalid input should fail", func(t *testing.T) { + auth := scramAuth{} + value := "\u0000example\uFFFEstring\u001F" + _, err := auth.normalizeString(value) + if err == nil { + t.Fatal("normalizeString should fail on disallowed runes") + } + if !strings.Contains(err.Error(), "precis: disallowed rune encountered") { + t.Errorf("expected error to be %q, got %q", "precis: disallowed rune encountered", err) + } + }) + t.Run("normalizeString on empty string should fail", func(t *testing.T) { + auth := scramAuth{} + _, err := auth.normalizeString("") + if err == nil { + t.Error("normalizeString should fail on disallowed runes") + } + if !strings.Contains(err.Error(), "precis: transformation resulted in empty string") { + t.Errorf("expected error to be %q, got %q", "precis: transformation resulted in empty string", err) + } + }) + t.Run("normalizeUsername with invalid input should fail", func(t *testing.T) { + auth := scramAuth{username: "\u0000example\uFFFEstring\u001F"} + _, err := auth.normalizeUsername() + if err == nil { + t.Error("normalizeUsername should fail on disallowed runes") + } + if !strings.Contains(err.Error(), "precis: disallowed rune encountered") { + t.Errorf("expected error to be %q, got %q", "precis: disallowed rune encountered", err) + } + }) + t.Run("normalizeUsername with empty input should fail", func(t *testing.T) { + auth := scramAuth{username: ""} + _, err := auth.normalizeUsername() + if err == nil { + t.Error("normalizeUsername should fail on empty input") + } + if !strings.Contains(err.Error(), "precis: transformation resulted in empty string") { + t.Errorf("expected error to be %q, got %q", "precis: transformation resulted in empty string", err) + } + }) +} + +func TestScramAuth_initialClientMessage(t *testing.T) { + t.Run("initialClientMessage with invalid username should fail", func(t *testing.T) { + auth := scramAuth{username: "\u0000example\uFFFEstring\u001F"} + _, err := auth.initialClientMessage() + if err == nil { + t.Error("initialClientMessage should fail on disallowed runes") + } + if !strings.Contains(err.Error(), "precis: disallowed rune encountered") { + t.Errorf("expected error to be %q, got %q", "precis: disallowed rune encountered", err) + } + }) + t.Run("initialClientMessage with empty username should fail", func(t *testing.T) { + auth := scramAuth{username: ""} + _, err := auth.initialClientMessage() + if err == nil { + t.Error("initialClientMessage should fail on empty username") + } + if !strings.Contains(err.Error(), "precis: transformation resulted in empty string") { + t.Errorf("expected error to be %q, got %q", "precis: transformation resulted in empty string", err) + } + }) + t.Run("initialClientMessage fails on broken rand.Reader", func(t *testing.T) { + defaultRandReader := rand.Reader + t.Cleanup(func() { rand.Reader = defaultRandReader }) + rand.Reader = &randReader{} + auth := scramAuth{username: "username"} + _, err := auth.initialClientMessage() + if err == nil { + t.Error("initialClientMessage should fail with broken rand.Reader") + } + if !strings.Contains(err.Error(), "unable to generate client secret: broken reader") { + t.Errorf("expected error to be %q, got %q", "unable to generate client secret: broken reader", err) + } + }) +} + +func TestScramAuth_handleServerFirstResponse(t *testing.T) { + t.Run("handleServerFirstResponse fails if not at least 3 parts", func(t *testing.T) { + auth := scramAuth{} + _, err := auth.handleServerFirstResponse([]byte("r=0")) + if err == nil { + t.Error("handleServerFirstResponse should fail on invalid response") + } + expectedErr := "not enough fields in the first server response" + if !strings.EqualFold(err.Error(), expectedErr) { + t.Errorf("expected error to be %q, got %q", expectedErr, err) + } + }) + t.Run("handleServerFirstResponse fails with first part does not start with r=", func(t *testing.T) { + auth := scramAuth{} + _, err := auth.handleServerFirstResponse([]byte("x=0,y=0,z=0,r=0")) + if err == nil { + t.Error("handleServerFirstResponse should fail on invalid response") + } + expectedErr := "first part of the server response does not start with r=" + if !strings.EqualFold(err.Error(), expectedErr) { + t.Errorf("expected error to be %q, got %q", expectedErr, err) + } + }) + t.Run("handleServerFirstResponse fails with second part does not start with s=", func(t *testing.T) { + auth := scramAuth{} + _, err := auth.handleServerFirstResponse([]byte("r=0,x=0,y=0,z=0")) + if err == nil { + t.Error("handleServerFirstResponse should fail on invalid response") + } + expectedErr := "second part of the server response does not start with s=" + if !strings.EqualFold(err.Error(), expectedErr) { + t.Errorf("expected error to be %q, got %q", expectedErr, err) + } + }) + t.Run("handleServerFirstResponse fails with third part does not start with i=", func(t *testing.T) { + auth := scramAuth{} + _, err := auth.handleServerFirstResponse([]byte("r=0,s=0,y=0,z=0")) + if err == nil { + t.Error("handleServerFirstResponse should fail on invalid response") + } + expectedErr := "third part of the server response does not start with i=" + if !strings.EqualFold(err.Error(), expectedErr) { + t.Errorf("expected error to be %q, got %q", expectedErr, err) + } + }) + t.Run("handleServerFirstResponse fails with empty nonce", func(t *testing.T) { + auth := scramAuth{} + _, err := auth.handleServerFirstResponse([]byte("r=,s=0,i=0")) + if err == nil { + t.Error("handleServerFirstResponse should fail on invalid response") + } + expectedErr := "server nonce does not start with our nonce" + if !strings.EqualFold(err.Error(), expectedErr) { + t.Errorf("expected error to be %q, got %q", expectedErr, err) + } + }) + t.Run("handleServerFirstResponse fails with non-base64 nonce", func(t *testing.T) { + auth := scramAuth{nonce: []byte("Test123")} + _, err := auth.handleServerFirstResponse([]byte("r=Test123,s=0,i=0")) + if err == nil { + t.Error("handleServerFirstResponse should fail on invalid response") + } + expectedErr := "illegal base64 data at input byte 0" + if !strings.Contains(err.Error(), expectedErr) { + t.Errorf("expected error to be %q, got %q", expectedErr, err) + } + }) + t.Run("handleServerFirstResponse fails with non-number iterations", func(t *testing.T) { + auth := scramAuth{nonce: []byte("VGVzdDEyMw==")} + _, err := auth.handleServerFirstResponse([]byte("r=VGVzdDEyMw==,s=VGVzdDEyMw==,i=abc")) + if err == nil { + t.Error("handleServerFirstResponse should fail on invalid response") + } + expectedErr := `invalid iterations: strconv.Atoi: parsing "abc": invalid syntax` + if !strings.Contains(err.Error(), expectedErr) { + t.Errorf("expected error to be %q, got %q", expectedErr, err) + } + }) + t.Run("handleServerFirstResponse fails with invalid password runes", func(t *testing.T) { + auth := scramAuth{ + nonce: []byte("VGVzdDEyMw=="), + username: "username", + password: "\u0000example\uFFFEstring\u001F", + } + _, err := auth.handleServerFirstResponse([]byte("r=VGVzdDEyMw==,s=VGVzdDEyMw==,i=0")) + if err == nil { + t.Error("handleServerFirstResponse should fail on invalid response") + } + expectedErr := `unable to normalize password: failed to normalize string: precis: disallowed rune encountered` + if !strings.Contains(err.Error(), expectedErr) { + t.Errorf("expected error to be %q, got %q", expectedErr, err) + } + }) + +} + /* @@ -3700,3 +3877,12 @@ func (s *testSCRAMSMTP) extractNonce(message string) string { } return "" } + +// randReader is type that satisfies the io.Reader interface. It can fail on a specific read +// operations and is therefore useful to test consecutive reads with errors +type randReader struct{} + +// Read implements the io.Reader interface for the randReader type +func (r *randReader) Read([]byte) (int, error) { + return 0, errors.New("broken reader") +} From cefaa0d4eeb6fff849fba7771ce126c535981e04 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Sat, 9 Nov 2024 13:44:14 +0100 Subject: [PATCH 15/45] Refactor SMTP test cases for improved clarity and coverage Consolidate and organize SMTP test cases by removing obsolete tests and adding focused, detailed tests for CRAM-MD5 and new client behaviors. This ensures clearer test structure and better coverage of edge cases. --- smtp/smtp_test.go | 769 ++++++++++++++-------------------------------- 1 file changed, 229 insertions(+), 540 deletions(-) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index 238176a..da7ad8b 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -1391,252 +1391,204 @@ func TestScramAuth_handleServerFirstResponse(t *testing.T) { t.Errorf("expected error to be %q, got %q", expectedErr, err) } }) - } -/* +func TestCRAMMD5Auth(t *testing.T) { + t.Run("CRAM-MD5 on test server succeeds", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-AUTH CRAM-MD5\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + auth := CRAMMD5Auth("username", "password") + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + if err = client.Auth(auth); err != nil { + t.Errorf("failed to auth to test server: %s", err) + } + }) + t.Run("CRAM-MD5 on test server fails", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-AUTH CRAM-MD5\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnAuth: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) - - - - -func TestAuthSCRAMSHA1_OK(t *testing.T) { - hostname := "127.0.0.1" - port := "2585" - - go func() { - startSMTPServer(false, hostname, port, sha1.New) - }() - time.Sleep(time.Millisecond * 500) - - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port)) - if err != nil { - t.Errorf("failed to dial server: %v", err) - } - client, err := NewClient(conn, hostname) - if err != nil { - t.Errorf("failed to create client: %v", err) - } - if err = client.Hello(hostname); err != nil { - t.Errorf("failed to send HELO: %v", err) - } - if err = client.Auth(ScramSHA1Auth("username", "password")); err != nil { - t.Errorf("failed to authenticate: %v", err) - } + auth := CRAMMD5Auth("username", "password") + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + if err = client.Auth(auth); err == nil { + t.Error("auth should fail on test server") + } + }) } -func TestAuthSCRAMSHA256_OK(t *testing.T) { - hostname := "127.0.0.1" - port := "2586" +func TestNewClient(t *testing.T) { + t.Run("new client via Dial succeeds", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) - go func() { - startSMTPServer(false, hostname, port, sha256.New) - }() - time.Sleep(time.Millisecond * 500) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Errorf("failed to create client: %s", err) + } + if err := client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + t.Run("new client via Dial fails on server not started", func(t *testing.T) { + _, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, 64000)) + if err == nil { + t.Error("dial on non-existant server should fail") + } + }) + t.Run("new client fails on server not available", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnDial: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port)) - if err != nil { - t.Errorf("failed to dial server: %v", err) - } - client, err := NewClient(conn, hostname) - if err != nil { - t.Errorf("failed to create client: %v", err) - } - if err = client.Hello(hostname); err != nil { - t.Errorf("failed to send HELO: %v", err) - } - if err = client.Auth(ScramSHA256Auth("username", "password")); err != nil { - t.Errorf("failed to authenticate: %v", err) - } + _, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err == nil { + t.Error("connection to non-available server should fail") + } + }) + t.Run("new client fails on faker that fails on close", func(t *testing.T) { + server := "442 service not available\r\n" + var wrote strings.Builder + var fake faker + fake.failOnClose = true + fake.ReadWriter = struct { + io.Reader + io.Writer + }{ + strings.NewReader(server), + &wrote, + } + _, err := NewClient(fake, "faker.host") + if err == nil { + t.Error("connection to non-available server should fail on close") + } + }) } -func TestAuthSCRAMSHA1PLUS_OK(t *testing.T) { - hostname := "127.0.0.1" - port := "2590" +func TestClient_hello(t *testing.T) { + t.Run("client fails on EHLO but not on HELO", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnEhlo: true, + FailOnHelo: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) - go func() { - startSMTPServer(true, hostname, port, sha1.New) - }() - time.Sleep(time.Millisecond * 500) - - cert, err := tls.X509KeyPair(localhostCert, localhostKey) - if err != nil { - fmt.Printf("error creating TLS cert: %s", err) - return - } - tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true} - - conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port), &tlsConfig) - if err != nil { - t.Errorf("failed to dial server: %v", err) - } - client, err := NewClient(conn, hostname) - if err != nil { - t.Errorf("failed to create client: %v", err) - } - if err = client.Hello(hostname); err != nil { - t.Errorf("failed to send HELO: %v", err) - } - - tlsConnState := conn.ConnectionState() - if err = client.Auth(ScramSHA1PlusAuth("username", "password", &tlsConnState)); err != nil { - t.Errorf("failed to authenticate: %v", err) - } + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Errorf("failed to dial to test server: %s", err) + } + if err = client.hello(); err == nil { + t.Error("helo should fail on test server") + } + }) } -func TestAuthSCRAMSHA256PLUS_OK(t *testing.T) { - hostname := "127.0.0.1" - port := "2591" +func TestClient_Hello(t *testing.T) { + t.Run("normal client HELO/EHLO", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) - go func() { - startSMTPServer(true, hostname, port, sha256.New) - }() - time.Sleep(time.Millisecond * 500) - - cert, err := tls.X509KeyPair(localhostCert, localhostKey) - if err != nil { - fmt.Printf("error creating TLS cert: %s", err) - return - } - tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true} - - conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port), &tlsConfig) - if err != nil { - t.Errorf("failed to dial server: %v", err) - } - client, err := NewClient(conn, hostname) - if err != nil { - t.Errorf("failed to create client: %v", err) - } - if err = client.Hello(hostname); err != nil { - t.Errorf("failed to send HELO: %v", err) - } - - tlsConnState := conn.ConnectionState() - if err = client.Auth(ScramSHA256PlusAuth("username", "password", &tlsConnState)); err != nil { - t.Errorf("failed to authenticate: %v", err) - } -} - -func TestAuthSCRAMSHA1_fail(t *testing.T) { - hostname := "127.0.0.1" - port := "2587" - - go func() { - startSMTPServer(false, hostname, port, sha1.New) - }() - time.Sleep(time.Millisecond * 500) - - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port)) - if err != nil { - t.Errorf("failed to dial server: %v", err) - } - client, err := NewClient(conn, hostname) - if err != nil { - t.Errorf("failed to create client: %v", err) - } - if err = client.Hello(hostname); err != nil { - t.Errorf("failed to send HELO: %v", err) - } - if err = client.Auth(ScramSHA1Auth("username", "invalid")); err == nil { - t.Errorf("expected auth error, got nil") - } -} - -func TestAuthSCRAMSHA256_fail(t *testing.T) { - hostname := "127.0.0.1" - port := "2588" - - go func() { - startSMTPServer(false, hostname, port, sha256.New) - }() - time.Sleep(time.Millisecond * 500) - - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port)) - if err != nil { - t.Errorf("failed to dial server: %v", err) - } - client, err := NewClient(conn, hostname) - if err != nil { - t.Errorf("failed to create client: %v", err) - } - if err = client.Hello(hostname); err != nil { - t.Errorf("failed to send HELO: %v", err) - } - if err = client.Auth(ScramSHA256Auth("username", "invalid")); err == nil { - t.Errorf("expected auth error, got nil") - } -} - -func TestAuthSCRAMSHA1PLUS_fail(t *testing.T) { - hostname := "127.0.0.1" - port := "2592" - - go func() { - startSMTPServer(true, hostname, port, sha1.New) - }() - time.Sleep(time.Millisecond * 500) - - cert, err := tls.X509KeyPair(localhostCert, localhostKey) - if err != nil { - fmt.Printf("error creating TLS cert: %s", err) - return - } - tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true} - - conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port), &tlsConfig) - if err != nil { - t.Errorf("failed to dial server: %v", err) - } - client, err := NewClient(conn, hostname) - if err != nil { - t.Errorf("failed to create client: %v", err) - } - if err = client.Hello(hostname); err != nil { - t.Errorf("failed to send HELO: %v", err) - } - tlsConnState := conn.ConnectionState() - if err = client.Auth(ScramSHA1PlusAuth("username", "invalid", &tlsConnState)); err == nil { - t.Errorf("expected auth error, got nil") - } -} - -func TestAuthSCRAMSHA256PLUS_fail(t *testing.T) { - hostname := "127.0.0.1" - port := "2593" - - go func() { - startSMTPServer(true, hostname, port, sha1.New) - }() - time.Sleep(time.Millisecond * 500) - - cert, err := tls.X509KeyPair(localhostCert, localhostKey) - if err != nil { - fmt.Printf("error creating TLS cert: %s", err) - return - } - tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true} - - conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port), &tlsConfig) - if err != nil { - t.Errorf("failed to dial server: %v", err) - } - client, err := NewClient(conn, hostname) - if err != nil { - t.Errorf("failed to create client: %v", err) - } - if err = client.Hello(hostname); err != nil { - t.Errorf("failed to send HELO: %v", err) - } - tlsConnState := conn.ConnectionState() - if err = client.Auth(ScramSHA256PlusAuth("username", "invalid", &tlsConnState)); err == nil { - t.Errorf("expected auth error, got nil") - } + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Errorf("failed to dial to test server: %s", err) + } + if err = client.Hello(TestServerAddr); err != nil { + t.Errorf("failed to send HELO/EHLO to test server: %s", err) + } + }) } // Issue 17794: don't send a trailing space on AUTH command when there's no password. -func TestClientAuthTrimSpace(t *testing.T) { +func TestClient_Auth_trimSpace(t *testing.T) { server := "220 hello world\r\n" + "200 some more" var wrote strings.Builder @@ -1655,7 +1607,7 @@ func TestClientAuthTrimSpace(t *testing.T) { c.tls = true c.didHello = true _ = c.Auth(toServerEmptyAuth{}) - if err := c.Close(); err != nil { + if err = c.Close(); err != nil { t.Errorf("close failed: %s", err) } if got, want := wrote.String(), "AUTH FOOAUTH\r\n*\r\nQUIT\r\n"; got != want { @@ -1663,19 +1615,15 @@ func TestClientAuthTrimSpace(t *testing.T) { } } -// toServerEmptyAuth is an implementation of Auth that only implements -// the Start method, and returns "FOOAUTH", nil, nil. Notably, it returns -// zero bytes for "toServer" so we can test that we don't send spaces at -// the end of the line. See TestClientAuthTrimSpace. -type toServerEmptyAuth struct{} +/* + + + + + + -func (toServerEmptyAuth) Start(_ *ServerInfo) (proto string, toServer []byte, err error) { - return "FOOAUTH", nil, nil -} -func (toServerEmptyAuth) Next(_ []byte, _ bool) (toServer []byte, err error) { - panic("unexpected call") -} func TestBasic(t *testing.T) { server := strings.Join(strings.Split(basicServer, "\n"), "\r\n") @@ -3158,46 +3106,6 @@ func sendMail(hostPort string) error { return SendMail(hostPort, nil, from, to, []byte("Subject: test\n\nhowdy!")) } -// localhostCert is a PEM-encoded TLS cert generated from src/crypto/tls: -// -// go run generate_cert.go --rsa-bits 1024 --host 127.0.0.1,::1,example.com \ -// --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h -var localhostCert = []byte(` ------BEGIN CERTIFICATE----- -MIICFDCCAX2gAwIBAgIRAK0xjnaPuNDSreeXb+z+0u4wDQYJKoZIhvcNAQELBQAw -EjEQMA4GA1UEChMHQWNtZSBDbzAgFw03MDAxMDEwMDAwMDBaGA8yMDg0MDEyOTE2 -MDAwMFowEjEQMA4GA1UEChMHQWNtZSBDbzCBnzANBgkqhkiG9w0BAQEFAAOBjQAw -gYkCgYEA0nFbQQuOWsjbGtejcpWz153OlziZM4bVjJ9jYruNw5n2Ry6uYQAffhqa -JOInCmmcVe2siJglsyH9aRh6vKiobBbIUXXUU1ABd56ebAzlt0LobLlx7pZEMy30 -LqIi9E6zmL3YvdGzpYlkFRnRrqwEtWYbGBf3znO250S56CCWH2UCAwEAAaNoMGYw -DgYDVR0PAQH/BAQDAgKkMBMGA1UdJQQMMAoGCCsGAQUFBwMBMA8GA1UdEwEB/wQF -MAMBAf8wLgYDVR0RBCcwJYILZXhhbXBsZS5jb22HBH8AAAGHEAAAAAAAAAAAAAAA -AAAAAAEwDQYJKoZIhvcNAQELBQADgYEAbZtDS2dVuBYvb+MnolWnCNqvw1w5Gtgi -NmvQQPOMgM3m+oQSCPRTNGSg25e1Qbo7bgQDv8ZTnq8FgOJ/rbkyERw2JckkHpD4 -n4qcK27WkEDBtQFlPihIM8hLIuzWoi/9wygiElTy/tVL3y7fGCvY2/k1KBthtZGF -tN8URjVmyEo= ------END CERTIFICATE-----`) - -// localhostKey is the private key for localhostCert. -var localhostKey = []byte(testingKey(` ------BEGIN RSA TESTING KEY----- -MIICXgIBAAKBgQDScVtBC45ayNsa16NylbPXnc6XOJkzhtWMn2Niu43DmfZHLq5h -AB9+Gpok4icKaZxV7ayImCWzIf1pGHq8qKhsFshRddRTUAF3np5sDOW3QuhsuXHu -lkQzLfQuoiL0TrOYvdi90bOliWQVGdGurAS1ZhsYF/fOc7bnRLnoIJYfZQIDAQAB -AoGBAMst7OgpKyFV6c3JwyI/jWqxDySL3caU+RuTTBaodKAUx2ZEmNJIlx9eudLA -kucHvoxsM/eRxlxkhdFxdBcwU6J+zqooTnhu/FE3jhrT1lPrbhfGhyKnUrB0KKMM -VY3IQZyiehpxaeXAwoAou6TbWoTpl9t8ImAqAMY8hlULCUqlAkEA+9+Ry5FSYK/m -542LujIcCaIGoG1/Te6Sxr3hsPagKC2rH20rDLqXwEedSFOpSS0vpzlPAzy/6Rbb -PHTJUhNdwwJBANXkA+TkMdbJI5do9/mn//U0LfrCR9NkcoYohxfKz8JuhgRQxzF2 -6jpo3q7CdTuuRixLWVfeJzcrAyNrVcBq87cCQFkTCtOMNC7fZnCTPUv+9q1tcJyB -vNjJu3yvoEZeIeuzouX9TJE21/33FaeDdsXbRhQEj23cqR38qFHsF1qAYNMCQQDP -QXLEiJoClkR2orAmqjPLVhR3t2oB3INcnEjLNSq8LHyQEfXyaFfu4U9l5+fRPL2i -jiC0k/9L5dHUsF0XZothAkEA23ddgRs+Id/HxtojqqUT27B8MT/IGNrYsp4DvS/c -qgkeluku4GjxRlDMBuXk94xOBEinUs+p/hwP1Alll80Tpg== ------END RSA TESTING KEY-----`)) - -func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") } - var flaky = flag.Bool("flaky", false, "run known-flaky tests too") func SkipFlaky(t testing.TB, issue int) { @@ -3207,271 +3115,22 @@ func SkipFlaky(t testing.TB, issue int) { } } -// testSCRAMSMTPServer represents a test server for SCRAM-based SMTP authentication. -// It does not do any acutal computation of the challenges but verifies that the expected -// fields are present. We have actual real authentication tests for all SCRAM modes in the -// go-mail client_test.go -type testSCRAMSMTPServer struct { - authMechanism string - nonce string - hostname string - port string - tlsServer bool - h func() hash.Hash -} - -func (s *testSCRAMSMTPServer) handleConnection(conn net.Conn) { - defer func() { - _ = conn.Close() - }() - - reader := bufio.NewReader(conn) - writer := bufio.NewWriter(conn) - writeLine := func(data string) error { - _, err := writer.WriteString(data + "\r\n") - if err != nil { - return fmt.Errorf("unable to write line: %w", err) - } - return writer.Flush() - } - writeOK := func() { - _ = writeLine("250 2.0.0 OK") - } - - if err := writeLine("220 go-mail test server ready ESMTP"); err != nil { - return - } - - data, err := reader.ReadString('\n') - if err != nil { - return - } - data = strings.TrimSpace(data) - if strings.HasPrefix(data, "EHLO") { - _ = writeLine(fmt.Sprintf("250-%s", s.hostname)) - _ = writeLine("250-AUTH SCRAM-SHA-1 SCRAM-SHA-256") - writeOK() - } else { - _ = writeLine("500 Invalid command") - return - } - - for { - data, err = reader.ReadString('\n') - if err != nil { - fmt.Printf("failed to read data: %v", err) - } - data = strings.TrimSpace(data) - if strings.HasPrefix(data, "AUTH") { - parts := strings.Split(data, " ") - if len(parts) < 2 { - _ = writeLine("500 Syntax error") - return - } - - authMechanism := parts[1] - if authMechanism != "SCRAM-SHA-1" && authMechanism != "SCRAM-SHA-256" && - authMechanism != "SCRAM-SHA-1-PLUS" && authMechanism != "SCRAM-SHA-256-PLUS" { - _ = writeLine("504 Unrecognized authentication mechanism") - return - } - s.authMechanism = authMechanism - _ = writeLine("334 ") - s.handleSCRAMAuth(conn) - return - } else { - _ = writeLine("500 Invalid command") - } - } -} - -func (s *testSCRAMSMTPServer) handleSCRAMAuth(conn net.Conn) { - reader := bufio.NewReader(conn) - writer := bufio.NewWriter(conn) - writeLine := func(data string) error { - _, err := writer.WriteString(data + "\r\n") - if err != nil { - return fmt.Errorf("unable to write line: %w", err) - } - return writer.Flush() - } - var authMsg string - - data, err := reader.ReadString('\n') - if err != nil { - _ = writeLine("535 Authentication failed") - return - } - data = strings.TrimSpace(data) - decodedMessage, err := base64.StdEncoding.DecodeString(data) - if err != nil { - _ = writeLine("535 Authentication failed") - return - } - splits := strings.Split(string(decodedMessage), ",") - if len(splits) != 4 { - _ = writeLine("535 Authentication failed - expected 4 parts") - return - } - if !s.tlsServer && splits[0] != "n" { - _ = writeLine("535 Authentication failed - expected n to be in the first part") - return - } - if s.tlsServer && !strings.HasPrefix(splits[0], "p=") { - _ = writeLine("535 Authentication failed - expected p= to be in the first part") - return - } - if splits[2] != "n=username" { - _ = writeLine("535 Authentication failed - expected n=username to be in the third part") - return - } - if !strings.HasPrefix(splits[3], "r=") { - _ = writeLine("535 Authentication failed - expected r= to be in the fourth part") - return - } - authMsg = splits[2] + "," + splits[3] - - clientNonce := s.extractNonce(string(decodedMessage)) - if clientNonce == "" { - _ = writeLine("535 Authentication failed") - return - } - - s.nonce = clientNonce + "server_nonce" - serverFirstMessage := fmt.Sprintf("r=%s,s=%s,i=4096", s.nonce, - base64.StdEncoding.EncodeToString([]byte("salt"))) - _ = writeLine(fmt.Sprintf("334 %s", base64.StdEncoding.EncodeToString([]byte(serverFirstMessage)))) - authMsg = authMsg + "," + serverFirstMessage - - data, err = reader.ReadString('\n') - if err != nil { - _ = writeLine("535 Authentication failed") - return - } - data = strings.TrimSpace(data) - decodedFinalMessage, err := base64.StdEncoding.DecodeString(data) - if err != nil { - _ = writeLine("535 Authentication failed") - return - } - splits = strings.Split(string(decodedFinalMessage), ",") - - if !s.tlsServer && splits[0] != "c=biws" { - _ = writeLine("535 Authentication failed - expected c=biws to be in the first part") - return - } - if s.tlsServer { - if !strings.HasPrefix(splits[0], "c=") { - _ = writeLine("535 Authentication failed - expected c= to be in the first part") - return - } - channelBind, err := base64.StdEncoding.DecodeString(splits[0][2:]) - if err != nil { - _ = writeLine("535 Authentication failed - base64 channel bind is not valid - " + err.Error()) - return - } - if !strings.HasPrefix(string(channelBind), "p=") { - _ = writeLine("535 Authentication failed - expected channel binding to start with p=-") - return - } - cbType := string(channelBind[2:]) - if !strings.HasPrefix(cbType, "tls-unique") && !strings.HasPrefix(cbType, "tls-exporter") { - _ = writeLine("535 Authentication failed - expected channel binding type tls-unique or tls-exporter") - return - } - } - - if !strings.HasPrefix(splits[1], "r=") { - _ = writeLine("535 Authentication failed - expected r to be in the second part") - return - } - if !strings.Contains(splits[1], "server_nonce") { - _ = writeLine("535 Authentication failed - expected server_nonce to be in the second part") - return - } - if !strings.HasPrefix(splits[2], "p=") { - _ = writeLine("535 Authentication failed - expected p to be in the third part") - return - } - - authMsg = authMsg + "," + splits[0] + "," + splits[1] - saltedPwd := pbkdf2.Key([]byte("password"), []byte("salt"), 4096, s.h().Size(), s.h) - mac := hmac.New(s.h, saltedPwd) - mac.Write([]byte("Server Key")) - skey := mac.Sum(nil) - mac.Reset() - - mac = hmac.New(s.h, skey) - mac.Write([]byte(authMsg)) - ssig := mac.Sum(nil) - mac.Reset() - - serverFinalMessage := fmt.Sprintf("v=%s", base64.StdEncoding.EncodeToString(ssig)) - _ = writeLine(fmt.Sprintf("334 %s", base64.StdEncoding.EncodeToString([]byte(serverFinalMessage)))) - - _, err = reader.ReadString('\n') - if err != nil { - _ = writeLine("535 Authentication failed") - return - } - - _ = writeLine("235 Authentication successful") -} - -func (s *testSCRAMSMTPServer) extractNonce(message string) string { - parts := strings.Split(message, ",") - for _, part := range parts { - if strings.HasPrefix(part, "r=") { - return part[2:] - } - } - return "" -} - -func startSMTPServer(tlsServer bool, hostname, port string, h func() hash.Hash) { - server := &testSCRAMSMTPServer{ - hostname: hostname, - port: port, - tlsServer: tlsServer, - h: h, - } - listener, err := net.Listen("tcp", fmt.Sprintf("%s:%s", hostname, port)) - if err != nil { - fmt.Printf("Failed to start SMTP server: %v", err) - } - defer func() { - _ = listener.Close() - }() - - cert, err := tls.X509KeyPair(localhostCert, localhostKey) - if err != nil { - fmt.Printf("error creating TLS cert: %s", err) - return - } - tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}} - - for { - conn, err := listener.Accept() - if err != nil { - fmt.Printf("Failed to accept connection: %v", err) - continue - } - if server.tlsServer { - conn = tls.Server(conn, &tlsConfig) - } - go server.handleConnection(conn) - } -} - */ // faker is a struct embedding io.ReadWriter to simulate network connections for testing purposes. type faker struct { io.ReadWriter + failOnRead bool + failOnClose bool } -func (f faker) Close() error { return nil } +func (f faker) Close() error { + if f.failOnClose { + return fmt.Errorf("faker: failed to close connection") + } + return nil +} func (f faker) LocalAddr() net.Addr { return nil } func (f faker) RemoteAddr() net.Addr { return nil } func (f faker) SetDeadline(time.Time) error { return nil } @@ -3486,6 +3145,8 @@ type serverProps struct { FailOnAuth bool FailOnDataInit bool FailOnDataClose bool + FailOnDial bool + FailOnEhlo bool FailOnHelo bool FailOnMailFrom bool FailOnNoop bool @@ -3582,6 +3243,10 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server } if !props.IsTLS { + if props.FailOnDial { + writeLine("421 4.4.1 Service not available") + return + } writeLine("220 go-mail test server ready ESMTP") } @@ -3595,9 +3260,9 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server var datastring string data = strings.TrimSpace(data) switch { - case strings.HasPrefix(data, "EHLO"), strings.HasPrefix(data, "HELO"): + case strings.HasPrefix(data, "HELO"): if len(strings.Split(data, " ")) != 2 { - writeLine("501 Syntax: EHLO hostname") + writeLine("501 Syntax: HELO hostname") break } if props.FailOnHelo { @@ -3605,6 +3270,16 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server break } writeLine("250-localhost.localdomain\r\n" + props.FeatureSet) + case strings.HasPrefix(data, "EHLO"): + if len(strings.Split(data, " ")) != 2 { + writeLine("501 Syntax: EHLO hostname") + break + } + if props.FailOnEhlo { + writeLine("500 5.5.2 Error: fail on EHLO") + break + } + writeLine("250-localhost.localdomain\r\n" + props.FeatureSet) case strings.HasPrefix(data, "MAIL FROM:"): if props.FailOnMailFrom { writeLine("500 5.5.2 Error: fail on MAIL FROM") @@ -3886,3 +3561,17 @@ type randReader struct{} func (r *randReader) Read([]byte) (int, error) { return 0, errors.New("broken reader") } + +// toServerEmptyAuth is an implementation of Auth that only implements +// the Start method, and returns "FOOAUTH", nil, nil. Notably, it returns +// zero bytes for "toServer" so we can test that we don't send spaces at +// the end of the line. See TestClientAuthTrimSpace. +type toServerEmptyAuth struct{} + +func (toServerEmptyAuth) Start(_ *ServerInfo) (proto string, toServer []byte, err error) { + return "FOOAUTH", nil, nil +} + +func (toServerEmptyAuth) Next(_ []byte, _ bool) (toServer []byte, err error) { + return nil, fmt.Errorf("unexpected call") +} From af7964450a47848ab3c2731789561ee46580fd04 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Sat, 9 Nov 2024 14:01:02 +0100 Subject: [PATCH 16/45] Add test cases for invalid HELO/EHLO commands Add tests to ensure HELO/EHLO commands fail with empty name, newline in name, and double execution. This improves validation and robustness of the SMTP client implementation. --- smtp/smtp_test.go | 81 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index da7ad8b..194315f 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -1585,6 +1585,87 @@ func TestClient_Hello(t *testing.T) { t.Errorf("failed to send HELO/EHLO to test server: %s", err) } }) + t.Run("client HELO/EHLO with empty name should fail", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Errorf("failed to dial to test server: %s", err) + } + if err = client.Hello(""); err == nil { + t.Error("HELO/EHLO with empty name should fail") + } + }) + t.Run("client HELO/EHLO with newline in name should fail", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Errorf("failed to dial to test server: %s", err) + } + if err = client.Hello(TestServerAddr + "\r\n"); err == nil { + t.Error("HELO/EHLO with newline should fail") + } + }) + t.Run("client double HELO/EHLO should fail", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Errorf("failed to dial to test server: %s", err) + } + if err = client.Hello(TestServerAddr); err != nil { + t.Errorf("failed to send HELO/EHLO to test server: %s", err) + } + if err = client.Hello(TestServerAddr); err == nil { + t.Error("double HELO/EHLO should fail") + } + }) } // Issue 17794: don't send a trailing space on AUTH command when there's no password. From 50505e1339a1fb3ea2cf17f1b868bb5b0da70db8 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Sat, 9 Nov 2024 14:26:44 +0100 Subject: [PATCH 17/45] Add test for Client cmd failure on textproto command This commit introduces a new test case for the `Client`'s `cmd` method to ensure it fails correctly when the `textproto` command encounters a broken writer. It also adds a `failWriter` struct that simulates a write failure to facilitate this testing scenario. --- smtp/smtp_test.go | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index 194315f..d7a2a43 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -1668,6 +1668,29 @@ func TestClient_Hello(t *testing.T) { }) } +func TestClient_cmd(t *testing.T) { + t.Run("cmd fails on textproto cmd", func(t *testing.T) { + server := "220 server ready\r\n" + var fake faker + fake.failOnClose = true + fake.ReadWriter = struct { + io.Reader + io.Writer + }{ + strings.NewReader(server), + &failWriter{}, + } + client, err := NewClient(fake, "faker.host") + if err != nil { + t.Errorf("failed to create client: %s", err) + } + _, _, err = client.cmd(250, "HELO faker.host") + if err == nil { + t.Error("cmd should fail on textproto cmd with broken writer") + } + }) +} + // Issue 17794: don't send a trailing space on AUTH command when there's no password. func TestClient_Auth_trimSpace(t *testing.T) { server := "220 hello world\r\n" + @@ -3656,3 +3679,10 @@ func (toServerEmptyAuth) Start(_ *ServerInfo) (proto string, toServer []byte, er func (toServerEmptyAuth) Next(_ []byte, _ bool) (toServer []byte, err error) { return nil, fmt.Errorf("unexpected call") } + +// failWriter is a struct type that implements the io.Writer interface, but always returns an error on Write. +type failWriter struct{} + +func (w *failWriter) Write([]byte) (int, error) { + return 0, errors.New("broken writer") +} From 8f28babc473bf970567ad7cf99fcfb255c00e001 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Sat, 9 Nov 2024 14:58:23 +0100 Subject: [PATCH 18/45] Add tests for Client StartTLS functionality Introduce new tests to cover the Client's behavior when initiating a STARTTLS session under different conditions: normal operation, failure on EHLO/HELO, and a server not supporting STARTTLS. This ensures robustness in handling STARTTLS interactions. --- smtp/smtp_test.go | 102 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index d7a2a43..9dc3c10 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -1691,6 +1691,108 @@ func TestClient_cmd(t *testing.T) { }) } +func TestClient_StartTLS(t *testing.T) { + t.Run("normal STARTTLS should succeed", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-STARTTLS\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Errorf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + tlsConfig := &tls.Config{InsecureSkipVerify: true} + if err = client.StartTLS(tlsConfig); err != nil { + t.Errorf("failed to initialize STARTTLS session: %s", err) + } + }) + t.Run("STARTTLS fails on EHLO/HELO", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-STARTTLS\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnEhlo: true, + FailOnHelo: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Errorf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + tlsConfig := &tls.Config{InsecureSkipVerify: true} + if err = client.StartTLS(tlsConfig); err == nil { + t.Error("STARTTLS should fail on EHLO") + } + }) + t.Run("STARTTLS fails on server not supporting STARTTLS", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnSTARTTLS: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Errorf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + tlsConfig := &tls.Config{InsecureSkipVerify: true} + if err = client.StartTLS(tlsConfig); err == nil { + t.Error("STARTTLS should fail for server not supporting it") + } + }) +} + // Issue 17794: don't send a trailing space on AUTH command when there's no password. func TestClient_Auth_trimSpace(t *testing.T) { server := "220 hello world\r\n" + From b7ffce62aafa76e84d54e46aa8a7b7fab6d3b513 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Sat, 9 Nov 2024 15:22:23 +0100 Subject: [PATCH 19/45] Add TLS connection state tests for SMTP client Introduce tests to verify TLS connection state handling in the SMTP client. Ensure that normal TLS connections return a valid state, and non-TLS connections do not wrongly indicate a TLS state. --- smtp/smtp_test.go | 74 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index 9dc3c10..6281666 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -1793,6 +1793,80 @@ func TestClient_StartTLS(t *testing.T) { }) } +func TestClient_TLSConnectionState(t *testing.T) { + t.Run("normal TLS connection should return a state", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-STARTTLS\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Errorf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + tlsConfig := &tls.Config{InsecureSkipVerify: true, MinVersion: tls.VersionTLS12} + if err = client.StartTLS(tlsConfig); err != nil { + t.Errorf("failed to initialize STARTTLS session: %s", err) + } + state, ok := client.TLSConnectionState() + if !ok { + t.Errorf("failed to get TLS connection state") + } + if state.Version < tls.VersionTLS12 { + t.Errorf("TLS connection state version is %d, should be >= %d", state.Version, tls.VersionTLS12) + } + }) + t.Run("no TLS state on non-TLS connection", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-STARTTLS\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Errorf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + _, ok := client.TLSConnectionState() + if ok { + t.Error("non-TLS connection should not have TLS connection state") + } + }) +} + // Issue 17794: don't send a trailing space on AUTH command when there's no password. func TestClient_Auth_trimSpace(t *testing.T) { server := "220 hello world\r\n" + From 0df228178a77121bd836cdc2223ca75a309d5575 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Sun, 10 Nov 2024 14:06:05 +0100 Subject: [PATCH 20/45] Add extensive client tests for Mail, Verify, and Auth commands Introduce new tests for the SMTP client covering scenarios for Mail, Verify, and Auth commands to ensure correct behavior under various conditions. Updated `simpleSMTPServer` implementation to handle more cases including VRFY and SMTPUTF8. --- smtp/smtp_test.go | 494 ++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 436 insertions(+), 58 deletions(-) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index 6281666..a6185d5 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -1867,44 +1867,408 @@ func TestClient_TLSConnectionState(t *testing.T) { }) } -// Issue 17794: don't send a trailing space on AUTH command when there's no password. -func TestClient_Auth_trimSpace(t *testing.T) { - server := "220 hello world\r\n" + - "200 some more" - var wrote strings.Builder - var fake faker - fake.ReadWriter = struct { - io.Reader - io.Writer - }{ - strings.NewReader(server), - &wrote, - } - c, err := NewClient(fake, "fake.host") - if err != nil { - t.Fatalf("NewClient: %v", err) - } - c.tls = true - c.didHello = true - _ = c.Auth(toServerEmptyAuth{}) - if err = c.Close(); err != nil { - t.Errorf("close failed: %s", err) - } - if got, want := wrote.String(), "AUTH FOOAUTH\r\n*\r\nQUIT\r\n"; got != want { - t.Errorf("wrote %q; want %q", got, want) - } +func TestClient_Verify(t *testing.T) { + t.Run("Verify on existing user succeeds", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-STARTTLS\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Errorf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if err = client.Verify("toni.tester@example.com"); err != nil { + t.Errorf("failed to verify user: %s", err) + } + }) + t.Run("Verify on non-existing user fails", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-STARTTLS\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + VRFYUserUnknown: true, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Errorf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if err = client.Verify("toni.tester@example.com"); err == nil { + t.Error("verify on non-existing user should fail") + } + }) + t.Run("Verify with newlines should fails", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-STARTTLS\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Errorf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if err = client.Verify("toni.tester@example.com\r\n"); err == nil { + t.Error("verify with new lines should fail") + } + }) + t.Run("Verify should fail on HELO/EHLO", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-STARTTLS\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnEhlo: true, + FailOnHelo: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Errorf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if err = client.Verify("toni.tester@example.com"); err == nil { + t.Error("verify with new lines should fail") + } + }) +} + +func TestClient_Auth(t *testing.T) { + t.Run("Auth fails on EHLO/HELO", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-STARTTLS\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnEhlo: true, + FailOnHelo: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Errorf("failed to dial to test server: %s", err) + } + auth := LoginAuth("username", "password", TestServerAddr, false) + if err = client.Auth(auth); err == nil { + t.Error("auth should fail on EHLO/HELO") + } + }) + t.Run("Auth fails on auth-start", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-STARTTLS\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnAuth: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Errorf("failed to dial to test server: %s", err) + } + auth := LoginAuth("username", "password", "not.localhost.com", false) + if err = client.Auth(auth); err == nil { + t.Error("auth should fail on auth-start, then on quit") + } + expErr := "wrong host name" + if !strings.EqualFold(expErr, err.Error()) { + t.Errorf("expected error: %q, got: %q", expErr, err.Error()) + } + }) + t.Run("Auth fails on auth-start and then on quit", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-STARTTLS\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnAuth: true, + FailOnQuit: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Errorf("failed to dial to test server: %s", err) + } + auth := LoginAuth("username", "password", "not.localhost.com", false) + if err = client.Auth(auth); err == nil { + t.Error("auth should fail on auth-start, then on quit") + } + expErr := "wrong host name, 500 5.1.2 Error: quit failed" + if !strings.EqualFold(expErr, err.Error()) { + t.Errorf("expected error: %q, got: %q", expErr, err.Error()) + } + }) + // Issue 17794: don't send a trailing space on AUTH command when there's no password. + t.Run("No trailing space on AUTH when there is no password (Issue 17794)", func(t *testing.T) { + server := "220 hello world\r\n" + + "200 some more" + var wrote strings.Builder + var fake faker + fake.ReadWriter = struct { + io.Reader + io.Writer + }{ + strings.NewReader(server), + &wrote, + } + c, err := NewClient(fake, "fake.host") + if err != nil { + t.Fatalf("NewClient: %v", err) + } + c.tls = true + c.didHello = true + _ = c.Auth(toServerEmptyAuth{}) + if err = c.Close(); err != nil { + t.Errorf("close failed: %s", err) + } + if got, want := wrote.String(), "AUTH FOOAUTH\r\n*\r\nQUIT\r\n"; got != want { + t.Errorf("wrote %q; want %q", got, want) + } + }) +} + +func TestClient_Mail(t *testing.T) { + t.Run("normal from address succeeds", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Errorf("failed to dial to test server: %s", err) + } + if err = client.Mail("valid-from@domain.tld"); err != nil { + t.Errorf("failed to set mail from address: %s", err) + } + }) + t.Run("from address with new lines fails", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Errorf("failed to dial to test server: %s", err) + } + if err = client.Mail("valid-from@domain.tld\r\n"); err == nil { + t.Error("mail from address with new lines should fail") + } + }) + t.Run("from address fails on EHLO/HELO", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnEhlo: true, + FailOnHelo: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Errorf("failed to dial to test server: %s", err) + } + if err = client.Mail("valid-from@domain.tld"); err == nil { + t.Error("mail from address should fail on EHLO/HELO") + } + }) + t.Run("from address and server supports 8BITMIME", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-8BITMIME\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + EchoCommandAsError: "MAIL FROM:", + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Errorf("failed to dial to test server: %s", err) + } + if err = client.Mail("valid-from@domain.tld"); err == nil { + t.Error("server should echo the command as error but didn't") + } + sent := strings.Replace(err.Error(), "500 ", "", -1) + expected := "MAIL FROM: BODY=8BITMIME" + if !strings.EqualFold(sent, expected) { + t.Errorf("expected mail from command to be %q, but sent %q", expected, sent) + } + }) + t.Run("from address and server supports SMTPUTF8", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-SMTPUTF8\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + EchoCommandAsError: "MAIL FROM:", + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Errorf("failed to dial to test server: %s", err) + } + if err = client.Mail("valid-from@domain.tld"); err == nil { + t.Error("server should echo the command as error but didn't") + } + sent := strings.Replace(err.Error(), "500 ", "", -1) + expected := "MAIL FROM: SMTPUTF8" + if !strings.EqualFold(sent, expected) { + t.Errorf("expected mail from command to be %q, but sent %q", expected, sent) + } + }) } /* - - - - - - - - - func TestBasic(t *testing.T) { server := strings.Join(strings.Split(basicServer, "\n"), "\r\n") client := strings.Join(strings.Split(basicClient, "\n"), "\r\n") @@ -3422,26 +3786,28 @@ func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", " // serverProps represents the configuration properties for the SMTP server. type serverProps struct { - FailOnAuth bool - FailOnDataInit bool - FailOnDataClose bool - FailOnDial bool - FailOnEhlo bool - FailOnHelo bool - FailOnMailFrom bool - FailOnNoop bool - FailOnQuit bool - FailOnReset bool - FailOnSTARTTLS bool - FailTemp bool - FeatureSet string - ListenPort int - SSLListener bool - IsSCRAMPlus bool - IsTLS bool - SupportDSN bool - TestSCRAM bool - HashFunc func() hash.Hash + EchoCommandAsError string + FailOnAuth bool + FailOnDataInit bool + FailOnDataClose bool + FailOnDial bool + FailOnEhlo bool + FailOnHelo bool + FailOnMailFrom bool + FailOnNoop bool + FailOnQuit bool + FailOnReset bool + FailOnSTARTTLS bool + FailTemp bool + FeatureSet string + ListenPort int + HashFunc func() hash.Hash + IsSCRAMPlus bool + IsTLS bool + SupportDSN bool + SSLListener bool + TestSCRAM bool + VRFYUserUnknown bool } // simpleSMTPServer starts a simple TCP server that resonds to SMTP commands. @@ -3540,6 +3906,9 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server var datastring string data = strings.TrimSpace(data) switch { + case props.EchoCommandAsError != "" && strings.HasPrefix(data, props.EchoCommandAsError): + writeLine("500 " + data) + break case strings.HasPrefix(data, "HELO"): if len(strings.Split(data, " ")) != 2 { writeLine("501 Syntax: HELO hostname") @@ -3643,8 +4012,17 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server break } writeOK() - case strings.EqualFold(data, "vrfy"): - writeOK() + case strings.HasPrefix(data, "VRFY"): + if props.VRFYUserUnknown { + writeLine("550 5.1.1 User unknown") + break + } + parts := strings.SplitN(data, " ", 2) + if len(parts) != 2 { + writeLine("500 5.0.0 Error: invalid syntax for VRFY") + break + } + writeLine(fmt.Sprintf("250 2.0.0 Ok: %s OK", parts[1])) case strings.EqualFold(data, "rset"): if props.FailOnReset { writeLine("500 5.1.2 Error: reset failed") @@ -3674,7 +4052,7 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server props.IsTLS = true handleTestServerConnection(connection, t, props) default: - writeLine("500 5.5.2 Error: bad syntax") + writeLine("500 5.5.2 Error: bad syntax - " + data) } } } From 0d8d097ae1a5ac8ed739abb2253e53606baa8ade Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Sun, 10 Nov 2024 14:10:50 +0100 Subject: [PATCH 21/45] Add tests for DSN, 8BITMIME, and SMTPUTF8 support in SMTP Introduce new test cases to verify the SMTP server's ability to handle DSN, 8BITMIME, and SMTPUTF8 features. These tests ensure correct response behavior when these features are supported by the server. --- smtp/smtp_test.go | 66 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index a6185d5..bf31b05 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -2266,6 +2266,72 @@ func TestClient_Mail(t *testing.T) { t.Errorf("expected mail from command to be %q, but sent %q", expected, sent) } }) + t.Run("from address and server supports DSN", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + EchoCommandAsError: "MAIL FROM:", + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Errorf("failed to dial to test server: %s", err) + } + client.dsnmrtype = "FULL" + if err = client.Mail("valid-from@domain.tld"); err == nil { + t.Error("server should echo the command as error but didn't") + } + sent := strings.Replace(err.Error(), "500 ", "", -1) + expected := "MAIL FROM: RET=FULL" + if !strings.EqualFold(sent, expected) { + t.Errorf("expected mail from command to be %q, but sent %q", expected, sent) + } + }) + t.Run("from address and server supports DSN, SMTPUTF8 and 8BITMIME", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250-8BITMIME\r\n250-SMTPUTF8\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + EchoCommandAsError: "MAIL FROM:", + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Errorf("failed to dial to test server: %s", err) + } + client.dsnmrtype = "FULL" + if err = client.Mail("valid-from@domain.tld"); err == nil { + t.Error("server should echo the command as error but didn't") + } + sent := strings.Replace(err.Error(), "500 ", "", -1) + expected := "MAIL FROM: BODY=8BITMIME SMTPUTF8 RET=FULL" + if !strings.EqualFold(sent, expected) { + t.Errorf("expected mail from command to be %q, but sent %q", expected, sent) + } + }) } /* From d446b491e2dfbb0f23315080f15014de5c756e11 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Sun, 10 Nov 2024 14:31:18 +0100 Subject: [PATCH 22/45] Add client cleanup to SMTP tests and new TestClient_Rcpt Update SMTP tests to use t.Cleanup for client cleanup to ensure proper resource release. Introduce a new test, TestClient_Rcpt, to verify recipient address handling under various conditions. --- smtp/smtp_test.go | 152 +++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 145 insertions(+), 7 deletions(-) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index bf31b05..a999f4c 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -2142,8 +2142,13 @@ func TestClient_Mail(t *testing.T) { client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) if err != nil { - t.Errorf("failed to dial to test server: %s", err) + t.Fatalf("failed to dial to test server: %s", err) } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) if err = client.Mail("valid-from@domain.tld"); err != nil { t.Errorf("failed to set mail from address: %s", err) } @@ -2168,8 +2173,13 @@ func TestClient_Mail(t *testing.T) { client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) if err != nil { - t.Errorf("failed to dial to test server: %s", err) + t.Fatalf("failed to dial to test server: %s", err) } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) if err = client.Mail("valid-from@domain.tld\r\n"); err == nil { t.Error("mail from address with new lines should fail") } @@ -2196,8 +2206,13 @@ func TestClient_Mail(t *testing.T) { client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) if err != nil { - t.Errorf("failed to dial to test server: %s", err) + t.Fatalf("failed to dial to test server: %s", err) } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) if err = client.Mail("valid-from@domain.tld"); err == nil { t.Error("mail from address should fail on EHLO/HELO") } @@ -2223,8 +2238,13 @@ func TestClient_Mail(t *testing.T) { client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) if err != nil { - t.Errorf("failed to dial to test server: %s", err) + t.Fatalf("failed to dial to test server: %s", err) } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) if err = client.Mail("valid-from@domain.tld"); err == nil { t.Error("server should echo the command as error but didn't") } @@ -2255,8 +2275,13 @@ func TestClient_Mail(t *testing.T) { client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) if err != nil { - t.Errorf("failed to dial to test server: %s", err) + t.Fatalf("failed to dial to test server: %s", err) } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) if err = client.Mail("valid-from@domain.tld"); err == nil { t.Error("server should echo the command as error but didn't") } @@ -2287,8 +2312,13 @@ func TestClient_Mail(t *testing.T) { client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) if err != nil { - t.Errorf("failed to dial to test server: %s", err) + t.Fatalf("failed to dial to test server: %s", err) } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) client.dsnmrtype = "FULL" if err = client.Mail("valid-from@domain.tld"); err == nil { t.Error("server should echo the command as error but didn't") @@ -2320,8 +2350,13 @@ func TestClient_Mail(t *testing.T) { client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) if err != nil { - t.Errorf("failed to dial to test server: %s", err) + t.Fatalf("failed to dial to test server: %s", err) } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) client.dsnmrtype = "FULL" if err = client.Mail("valid-from@domain.tld"); err == nil { t.Error("server should echo the command as error but didn't") @@ -2334,6 +2369,109 @@ func TestClient_Mail(t *testing.T) { }) } +func TestClient_Rcpt(t *testing.T) { + t.Run("normal recipient address succeeds", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250-8BITMIME\r\n250-SMTPUTF8\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if err = client.Rcpt("valid-to@domain.tld"); err != nil { + t.Errorf("failed to set recipient address: %s", err) + } + }) + t.Run("recipient address with newlines fails", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if err = client.Rcpt("valid-to@domain.tld\r\n"); err == nil { + t.Error("recpient address with newlines should fail") + } + }) + t.Run("recipient address with DSN", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + EchoCommandAsError: "RCPT TO:", + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if err = client.Hello(TestServerAddr); err != nil { + t.Fatalf("failed to send hello to test server: %s", err) + } + client.dsnrntype = "SUCCESS" + if err = client.Rcpt("valid-to@domain.tld"); err == nil { + t.Error("recpient address with newlines should fail") + } + sent := strings.Replace(err.Error(), "500 ", "", -1) + expected := "RCPT TO: NOTIFY=SUCCESS" + if !strings.EqualFold(sent, expected) { + t.Errorf("expected rcpt to command to be %q, but sent %q", expected, sent) + } + }) +} + /* func TestBasic(t *testing.T) { server := strings.Join(strings.Split(basicServer, "\n"), "\r\n") From 75bfdd2855390742da580fafa08454b5c9e76437 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Mon, 11 Nov 2024 12:28:52 +0100 Subject: [PATCH 23/45] Update smtp tests to use t.Fatalf for critical failures Changed `t.Errorf` to `t.Fatalf` in multiple instances within the test cases. This ensures that the tests halt immediately upon a critical failure, improving test reliability and debugging clarity. --- smtp/smtp_test.go | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index a999f4c..66d3ee5 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -1472,7 +1472,7 @@ func TestNewClient(t *testing.T) { client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) if err != nil { - t.Errorf("failed to create client: %s", err) + t.Fatalf("failed to create client: %s", err) } if err := client.Close(); err != nil { t.Errorf("failed to close client: %s", err) @@ -1550,7 +1550,7 @@ func TestClient_hello(t *testing.T) { client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) if err != nil { - t.Errorf("failed to dial to test server: %s", err) + t.Fatalf("failed to dial to test server: %s", err) } if err = client.hello(); err == nil { t.Error("helo should fail on test server") @@ -1579,7 +1579,7 @@ func TestClient_Hello(t *testing.T) { client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) if err != nil { - t.Errorf("failed to dial to test server: %s", err) + t.Fatalf("failed to dial to test server: %s", err) } if err = client.Hello(TestServerAddr); err != nil { t.Errorf("failed to send HELO/EHLO to test server: %s", err) @@ -1605,7 +1605,7 @@ func TestClient_Hello(t *testing.T) { client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) if err != nil { - t.Errorf("failed to dial to test server: %s", err) + t.Fatalf("failed to dial to test server: %s", err) } if err = client.Hello(""); err == nil { t.Error("HELO/EHLO with empty name should fail") @@ -1657,7 +1657,7 @@ func TestClient_Hello(t *testing.T) { client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) if err != nil { - t.Errorf("failed to dial to test server: %s", err) + t.Fatalf("failed to dial to test server: %s", err) } if err = client.Hello(TestServerAddr); err != nil { t.Errorf("failed to send HELO/EHLO to test server: %s", err) @@ -1712,7 +1712,7 @@ func TestClient_StartTLS(t *testing.T) { client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) if err != nil { - t.Errorf("failed to dial to test server: %s", err) + t.Fatalf("failed to dial to test server: %s", err) } t.Cleanup(func() { if err = client.Close(); err != nil { @@ -1746,7 +1746,7 @@ func TestClient_StartTLS(t *testing.T) { client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) if err != nil { - t.Errorf("failed to dial to test server: %s", err) + t.Fatalf("failed to dial to test server: %s", err) } t.Cleanup(func() { if err = client.Close(); err != nil { @@ -1779,7 +1779,7 @@ func TestClient_StartTLS(t *testing.T) { client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) if err != nil { - t.Errorf("failed to dial to test server: %s", err) + t.Fatalf("failed to dial to test server: %s", err) } t.Cleanup(func() { if err = client.Close(); err != nil { @@ -1814,7 +1814,7 @@ func TestClient_TLSConnectionState(t *testing.T) { client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) if err != nil { - t.Errorf("failed to dial to test server: %s", err) + t.Fatalf("failed to dial to test server: %s", err) } t.Cleanup(func() { if err = client.Close(); err != nil { @@ -1853,7 +1853,7 @@ func TestClient_TLSConnectionState(t *testing.T) { client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) if err != nil { - t.Errorf("failed to dial to test server: %s", err) + t.Fatalf("failed to dial to test server: %s", err) } t.Cleanup(func() { if err = client.Close(); err != nil { @@ -1888,7 +1888,7 @@ func TestClient_Verify(t *testing.T) { client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) if err != nil { - t.Errorf("failed to dial to test server: %s", err) + t.Fatalf("failed to dial to test server: %s", err) } t.Cleanup(func() { if err = client.Close(); err != nil { @@ -1920,7 +1920,7 @@ func TestClient_Verify(t *testing.T) { client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) if err != nil { - t.Errorf("failed to dial to test server: %s", err) + t.Fatalf("failed to dial to test server: %s", err) } t.Cleanup(func() { if err = client.Close(); err != nil { @@ -1951,7 +1951,7 @@ func TestClient_Verify(t *testing.T) { client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) if err != nil { - t.Errorf("failed to dial to test server: %s", err) + t.Fatalf("failed to dial to test server: %s", err) } t.Cleanup(func() { if err = client.Close(); err != nil { @@ -1984,7 +1984,7 @@ func TestClient_Verify(t *testing.T) { client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) if err != nil { - t.Errorf("failed to dial to test server: %s", err) + t.Fatalf("failed to dial to test server: %s", err) } t.Cleanup(func() { if err = client.Close(); err != nil { @@ -2020,7 +2020,7 @@ func TestClient_Auth(t *testing.T) { client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) if err != nil { - t.Errorf("failed to dial to test server: %s", err) + t.Fatalf("failed to dial to test server: %s", err) } auth := LoginAuth("username", "password", TestServerAddr, false) if err = client.Auth(auth); err == nil { @@ -2048,7 +2048,7 @@ func TestClient_Auth(t *testing.T) { client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) if err != nil { - t.Errorf("failed to dial to test server: %s", err) + t.Fatalf("failed to dial to test server: %s", err) } auth := LoginAuth("username", "password", "not.localhost.com", false) if err = client.Auth(auth); err == nil { @@ -2081,7 +2081,7 @@ func TestClient_Auth(t *testing.T) { client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) if err != nil { - t.Errorf("failed to dial to test server: %s", err) + t.Fatalf("failed to dial to test server: %s", err) } auth := LoginAuth("username", "password", "not.localhost.com", false) if err = client.Auth(auth); err == nil { From e9c7bdbb4e963c5e489a020a77f7e6c2e64b8abc Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Mon, 11 Nov 2024 12:54:10 +0100 Subject: [PATCH 24/45] Refactor TLS config initialization in tests Replace repetitive TLS configuration code with a reusable `getTLSConfig` helper function for consistency and maintainability. Additionally, update port configuration and add new tests for mail data transmission. --- smtp/smtp_test.go | 120 ++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 100 insertions(+), 20 deletions(-) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index 66d3ee5..8613dd1 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -22,6 +22,7 @@ import ( "crypto/sha1" "crypto/sha256" "crypto/tls" + "crypto/x509" "encoding/base64" "errors" "fmt" @@ -42,7 +43,7 @@ const ( // TestServerAddr is the address the simple SMTP test server listens on TestServerAddr = "127.0.0.1" // TestServerPortBase is the base port for the simple SMTP test server - TestServerPortBase = 12025 + TestServerPortBase = 30025 ) // PortAdder is an atomic counter used to increment port numbers for the test SMTP server instances. @@ -1087,13 +1088,8 @@ func TestScramAuth(t *testing.T) { var client *Client switch tt.tls { case true: - cert, err := tls.X509KeyPair(localhostCert, localhostKey) - if err != nil { - fmt.Printf("error creating TLS cert: %s", err) - return - } - tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true} - conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", TestServerAddr, serverPort), &tlsConfig) + tlsConfig := getTLSConfig(t) + conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", TestServerAddr, serverPort), tlsConfig) if err != nil { t.Fatalf("failed to dial TLS server: %v", err) } @@ -1161,13 +1157,8 @@ func TestScramAuth(t *testing.T) { var client *Client switch tt.tls { case true: - cert, err := tls.X509KeyPair(localhostCert, localhostKey) - if err != nil { - fmt.Printf("error creating TLS cert: %s", err) - return - } - tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true} - conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", TestServerAddr, serverPort), &tlsConfig) + tlsConfig := getTLSConfig(t) + conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", TestServerAddr, serverPort), tlsConfig) if err != nil { t.Fatalf("failed to dial TLS server: %v", err) } @@ -1719,7 +1710,7 @@ func TestClient_StartTLS(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - tlsConfig := &tls.Config{InsecureSkipVerify: true} + tlsConfig := getTLSConfig(t) if err = client.StartTLS(tlsConfig); err != nil { t.Errorf("failed to initialize STARTTLS session: %s", err) } @@ -1753,7 +1744,7 @@ func TestClient_StartTLS(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - tlsConfig := &tls.Config{InsecureSkipVerify: true} + tlsConfig := getTLSConfig(t) if err = client.StartTLS(tlsConfig); err == nil { t.Error("STARTTLS should fail on EHLO") } @@ -1786,7 +1777,7 @@ func TestClient_StartTLS(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - tlsConfig := &tls.Config{InsecureSkipVerify: true} + tlsConfig := getTLSConfig(t) if err = client.StartTLS(tlsConfig); err == nil { t.Error("STARTTLS should fail for server not supporting it") } @@ -1821,7 +1812,8 @@ func TestClient_TLSConnectionState(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - tlsConfig := &tls.Config{InsecureSkipVerify: true, MinVersion: tls.VersionTLS12} + tlsConfig := getTLSConfig(t) + tlsConfig.MinVersion = tls.VersionTLS12 if err = client.StartTLS(tlsConfig); err != nil { t.Errorf("failed to initialize STARTTLS session: %s", err) } @@ -2472,6 +2464,79 @@ func TestClient_Rcpt(t *testing.T) { }) } +func TestClient_Data(t *testing.T) { + t.Run("normal mail data transmission succeeds", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + writer, err := client.Data() + if err != nil { + t.Fatalf("failed to create data writer: %s", err) + } + t.Cleanup(func() { + if err = writer.Close(); err != nil { + t.Errorf("failed to close data writer: %s", err) + } + }) + if _, err = writer.Write([]byte("test message")); err != nil { + t.Errorf("failed to write data to test server: %s", err) + } + }) + t.Run("mail data transmission fails on DATA command", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnDataInit: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if _, err = client.Data(); err == nil { + t.Error("expected data writer to fail") + } + }) +} + /* func TestBasic(t *testing.T) { server := strings.Join(strings.Split(basicServer, "\n"), "\r\n") @@ -4251,7 +4316,7 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server break } writeLine("220 Ready to start TLS") - tlsConfig := &tls.Config{Certificates: []tls.Certificate{keypair}} + tlsConfig := &tls.Config{Certificates: []tls.Certificate{keypair}, ServerName: "example.com"} connection = tls.Server(connection, tlsConfig) props.IsTLS = true handleTestServerConnection(connection, t, props) @@ -4444,3 +4509,18 @@ type failWriter struct{} func (w *failWriter) Write([]byte) (int, error) { return 0, errors.New("broken writer") } + +func getTLSConfig(t *testing.T) *tls.Config { + t.Helper() + cert, err := tls.X509KeyPair(localhostCert, localhostKey) + if err != nil { + t.Fatalf("unable to load host certifcate: %s", err) + } + testRootCAs := x509.NewCertPool() + testRootCAs.AppendCertsFromPEM(localhostCert) + return &tls.Config{ + Certificates: []tls.Certificate{cert}, + RootCAs: testRootCAs, + ServerName: "example.com", + } +} From 3fffcd15f618a8196f90871eba8b51aeaa71f18f Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Mon, 11 Nov 2024 12:54:39 +0100 Subject: [PATCH 25/45] Remove deprecated test hook for STARTTLS The testHookStartTLS variable and its related conditional code have been removed from the smtp.go file. This cleanup streamlines the TLS initiation process and removes unnecessary test-specific hooks no longer in use. --- smtp/smtp.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/smtp/smtp.go b/smtp/smtp.go index 4841ec8..943ca9d 100644 --- a/smtp/smtp.go +++ b/smtp/smtp.go @@ -432,8 +432,6 @@ func (c *Client) Data() (io.WriteCloser, error) { return datacloser, nil } -var testHookStartTLS func(*tls.Config) // nil, except for tests - // SendMail connects to the server at addr, switches to TLS if // possible, authenticates with the optional mechanism a if possible, // and then sends an email from address from, to addresses to, with @@ -475,9 +473,6 @@ func SendMail(addr string, a Auth, from string, to []string, msg []byte) error { } if ok, _ := c.Extension("STARTTLS"); ok { config := &tls.Config{ServerName: c.serverName} - if testHookStartTLS != nil { - testHookStartTLS(config) - } if err = c.StartTLS(config); err != nil { return err } From c3252626e37e3cd71ec0f82b31b55c29a1876b8e Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Mon, 11 Nov 2024 13:31:02 +0100 Subject: [PATCH 26/45] Reverted change made in 3fffcd15f618a8196f90871eba8b51aeaa71f18f --- smtp/smtp.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/smtp/smtp.go b/smtp/smtp.go index 943ca9d..4841ec8 100644 --- a/smtp/smtp.go +++ b/smtp/smtp.go @@ -432,6 +432,8 @@ func (c *Client) Data() (io.WriteCloser, error) { return datacloser, nil } +var testHookStartTLS func(*tls.Config) // nil, except for tests + // SendMail connects to the server at addr, switches to TLS if // possible, authenticates with the optional mechanism a if possible, // and then sends an email from address from, to addresses to, with @@ -473,6 +475,9 @@ func SendMail(addr string, a Auth, from string, to []string, msg []byte) error { } if ok, _ := c.Extension("STARTTLS"); ok { config := &tls.Config{ServerName: c.serverName} + if testHookStartTLS != nil { + testHookStartTLS(config) + } if err = c.StartTLS(config); err != nil { return err } From c6da3936762d5e50bb8d7c1960f4e3a49dba6093 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Mon, 11 Nov 2024 13:31:25 +0100 Subject: [PATCH 27/45] Add echo buffer to SMTP server tests Refactored SMTP server tests to use an echo buffer for capturing responses. This allows for better validation of command responses without relying on error messages. Additionally, added a new test case to validate a full SendMail transaction with TLS and authentication. --- smtp/smtp_test.go | 174 ++++++++++++++++++++++++++++++---------------- 1 file changed, 115 insertions(+), 59 deletions(-) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index 8613dd1..e7649f8 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -2215,11 +2215,12 @@ func TestClient_Mail(t *testing.T) { PortAdder.Add(1) serverPort := int(TestServerPortBase + PortAdder.Load()) featureSet := "250-8BITMIME\r\n250 STARTTLS" + echoBuffer := bytes.NewBuffer(nil) go func() { if err := simpleSMTPServer(ctx, t, &serverProps{ - EchoCommandAsError: "MAIL FROM:", - FeatureSet: featureSet, - ListenPort: serverPort, + EchoBuffer: echoBuffer, + FeatureSet: featureSet, + ListenPort: serverPort, }, ); err != nil { t.Errorf("failed to start test server: %s", err) @@ -2237,13 +2238,13 @@ func TestClient_Mail(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.Mail("valid-from@domain.tld"); err == nil { - t.Error("server should echo the command as error but didn't") + if err = client.Mail("valid-from@domain.tld"); err != nil { + t.Errorf("failed to set mail from address: %s", err) } - sent := strings.Replace(err.Error(), "500 ", "", -1) expected := "MAIL FROM: BODY=8BITMIME" - if !strings.EqualFold(sent, expected) { - t.Errorf("expected mail from command to be %q, but sent %q", expected, sent) + resp := strings.Split(echoBuffer.String(), "\r\n") + if !strings.EqualFold(resp[5], expected) { + t.Errorf("expected mail from command to be %q, but sent %q", expected, resp[5]) } }) t.Run("from address and server supports SMTPUTF8", func(t *testing.T) { @@ -2252,11 +2253,12 @@ func TestClient_Mail(t *testing.T) { PortAdder.Add(1) serverPort := int(TestServerPortBase + PortAdder.Load()) featureSet := "250-SMTPUTF8\r\n250 STARTTLS" + echoBuffer := bytes.NewBuffer(nil) go func() { if err := simpleSMTPServer(ctx, t, &serverProps{ - EchoCommandAsError: "MAIL FROM:", - FeatureSet: featureSet, - ListenPort: serverPort, + EchoBuffer: echoBuffer, + FeatureSet: featureSet, + ListenPort: serverPort, }, ); err != nil { t.Errorf("failed to start test server: %s", err) @@ -2274,13 +2276,13 @@ func TestClient_Mail(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.Mail("valid-from@domain.tld"); err == nil { - t.Error("server should echo the command as error but didn't") + if err = client.Mail("valid-from@domain.tld"); err != nil { + t.Errorf("failed to set mail from address: %s", err) } - sent := strings.Replace(err.Error(), "500 ", "", -1) expected := "MAIL FROM: SMTPUTF8" - if !strings.EqualFold(sent, expected) { - t.Errorf("expected mail from command to be %q, but sent %q", expected, sent) + resp := strings.Split(echoBuffer.String(), "\r\n") + if !strings.EqualFold(resp[5], expected) { + t.Errorf("expected mail from command to be %q, but sent %q", expected, resp[5]) } }) t.Run("from address and server supports DSN", func(t *testing.T) { @@ -2289,11 +2291,12 @@ func TestClient_Mail(t *testing.T) { PortAdder.Add(1) serverPort := int(TestServerPortBase + PortAdder.Load()) featureSet := "250-DSN\r\n250 STARTTLS" + echoBuffer := bytes.NewBuffer(nil) go func() { if err := simpleSMTPServer(ctx, t, &serverProps{ - EchoCommandAsError: "MAIL FROM:", - FeatureSet: featureSet, - ListenPort: serverPort, + EchoBuffer: echoBuffer, + FeatureSet: featureSet, + ListenPort: serverPort, }, ); err != nil { t.Errorf("failed to start test server: %s", err) @@ -2315,10 +2318,10 @@ func TestClient_Mail(t *testing.T) { if err = client.Mail("valid-from@domain.tld"); err == nil { t.Error("server should echo the command as error but didn't") } - sent := strings.Replace(err.Error(), "500 ", "", -1) expected := "MAIL FROM: RET=FULL" - if !strings.EqualFold(sent, expected) { - t.Errorf("expected mail from command to be %q, but sent %q", expected, sent) + resp := strings.Split(echoBuffer.String(), "\r\n") + if !strings.EqualFold(resp[5], expected) { + t.Errorf("expected mail from command to be %q, but sent %q", expected, resp[5]) } }) t.Run("from address and server supports DSN, SMTPUTF8 and 8BITMIME", func(t *testing.T) { @@ -2327,11 +2330,12 @@ func TestClient_Mail(t *testing.T) { PortAdder.Add(1) serverPort := int(TestServerPortBase + PortAdder.Load()) featureSet := "250-DSN\r\n250-8BITMIME\r\n250-SMTPUTF8\r\n250 STARTTLS" + echoBuffer := bytes.NewBuffer(nil) go func() { if err := simpleSMTPServer(ctx, t, &serverProps{ - EchoCommandAsError: "MAIL FROM:", - FeatureSet: featureSet, - ListenPort: serverPort, + EchoBuffer: echoBuffer, + FeatureSet: featureSet, + ListenPort: serverPort, }, ); err != nil { t.Errorf("failed to start test server: %s", err) @@ -2353,10 +2357,10 @@ func TestClient_Mail(t *testing.T) { if err = client.Mail("valid-from@domain.tld"); err == nil { t.Error("server should echo the command as error but didn't") } - sent := strings.Replace(err.Error(), "500 ", "", -1) expected := "MAIL FROM: BODY=8BITMIME SMTPUTF8 RET=FULL" - if !strings.EqualFold(sent, expected) { - t.Errorf("expected mail from command to be %q, but sent %q", expected, sent) + resp := strings.Split(echoBuffer.String(), "\r\n") + if !strings.EqualFold(resp[7], expected) { + t.Errorf("expected mail from command to be %q, but sent %q", expected, resp[7]) } }) } @@ -2428,11 +2432,12 @@ func TestClient_Rcpt(t *testing.T) { PortAdder.Add(1) serverPort := int(TestServerPortBase + PortAdder.Load()) featureSet := "250-DSN\r\n250 STARTTLS" + echoBuffer := bytes.NewBuffer(nil) go func() { if err := simpleSMTPServer(ctx, t, &serverProps{ - EchoCommandAsError: "RCPT TO:", - FeatureSet: featureSet, - ListenPort: serverPort, + EchoBuffer: echoBuffer, + FeatureSet: featureSet, + ListenPort: serverPort, }, ); err != nil { t.Errorf("failed to start test server: %s", err) @@ -2456,10 +2461,10 @@ func TestClient_Rcpt(t *testing.T) { if err = client.Rcpt("valid-to@domain.tld"); err == nil { t.Error("recpient address with newlines should fail") } - sent := strings.Replace(err.Error(), "500 ", "", -1) expected := "RCPT TO: NOTIFY=SUCCESS" - if !strings.EqualFold(sent, expected) { - t.Errorf("expected rcpt to command to be %q, but sent %q", expected, sent) + resp := strings.Split(echoBuffer.String(), "\r\n") + if !strings.EqualFold(resp[5], expected) { + t.Errorf("expected rcpt to command to be %q, but sent %q", expected, resp[5]) } }) } @@ -2537,6 +2542,45 @@ func TestClient_Data(t *testing.T) { }) } +func TestSendMail(t *testing.T) { + t.Run("full SendMail transaction with TLS and auth", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS" + echoBuffer := bytes.NewBuffer(nil) + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + EchoBuffer: echoBuffer, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + addr := fmt.Sprintf("%s:%d", TestServerAddr, serverPort) + testHookStartTLS = func(config *tls.Config) { + testConfig := getTLSConfig(t) + config.ServerName = testConfig.ServerName + config.RootCAs = testConfig.RootCAs + config.Certificates = testConfig.Certificates + } + auth := LoginAuth("username", "password", TestServerAddr, false) + if err := SendMail(addr, auth, "valid-from@domain.tld", []string{"valid-to@domain.tld"}, + []byte("test message")); err != nil { + t.Fatalf("failed to send mail: %s", err) + } + resp := strings.Split(echoBuffer.String(), "\r\n") + for i, line := range resp { + t.Logf("response line %d: %q", i, line) + } + }) +} + /* func TestBasic(t *testing.T) { server := strings.Join(strings.Split(basicServer, "\n"), "\r\n") @@ -4055,28 +4099,28 @@ func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", " // serverProps represents the configuration properties for the SMTP server. type serverProps struct { - EchoCommandAsError string - FailOnAuth bool - FailOnDataInit bool - FailOnDataClose bool - FailOnDial bool - FailOnEhlo bool - FailOnHelo bool - FailOnMailFrom bool - FailOnNoop bool - FailOnQuit bool - FailOnReset bool - FailOnSTARTTLS bool - FailTemp bool - FeatureSet string - ListenPort int - HashFunc func() hash.Hash - IsSCRAMPlus bool - IsTLS bool - SupportDSN bool - SSLListener bool - TestSCRAM bool - VRFYUserUnknown bool + EchoBuffer io.Writer + FailOnAuth bool + FailOnDataInit bool + FailOnDataClose bool + FailOnDial bool + FailOnEhlo bool + FailOnHelo bool + FailOnMailFrom bool + FailOnNoop bool + FailOnQuit bool + FailOnReset bool + FailOnSTARTTLS bool + FailTemp bool + FeatureSet string + ListenPort int + HashFunc func() hash.Hash + IsSCRAMPlus bool + IsTLS bool + SupportDSN bool + SSLListener bool + TestSCRAM bool + VRFYUserUnknown bool } // simpleSMTPServer starts a simple TCP server that resonds to SMTP commands. @@ -4151,6 +4195,11 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server if err != nil { t.Logf("failed to write line: %s", err) } + if props.EchoBuffer != nil { + if _, err := props.EchoBuffer.Write([]byte(data + "\r\n")); err != nil { + t.Errorf("failed write to echo buffer: %s", err) + } + } _ = writer.Flush() } writeOK := func() { @@ -4171,13 +4220,15 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server break } time.Sleep(time.Millisecond) + if props.EchoBuffer != nil { + if _, err = props.EchoBuffer.Write([]byte(data)); err != nil { + t.Errorf("failed write to echo buffer: %s", err) + } + } var datastring string data = strings.TrimSpace(data) switch { - case props.EchoCommandAsError != "" && strings.HasPrefix(data, props.EchoCommandAsError): - writeLine("500 " + data) - break case strings.HasPrefix(data, "HELO"): if len(strings.Split(data, " ")) != 2 { writeLine("501 Syntax: HELO hostname") @@ -4260,6 +4311,11 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server t.Logf("failed to read data from connection: %s", derr) break } + if props.EchoBuffer != nil { + if _, err = props.EchoBuffer.Write([]byte(ddata)); err != nil { + t.Errorf("failed write to echo buffer: %s", err) + } + } ddata = strings.TrimSpace(ddata) if ddata == "." { if props.FailOnDataClose { From 87accd289ef855ca876d784ea8b68d5c9326346d Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Mon, 11 Nov 2024 16:41:44 +0100 Subject: [PATCH 28/45] Fix formatting in smtp_test.go and add new test cases Corrected indentation inconsistencies in the smtp_test.go file. Added multiple test cases to verify the failure scenarios for `SendMail` under various conditions including invalid hostnames, newlines in the address fields, and server failures during EHLO/HELO, STARTTLS, and authentication stages. --- smtp/smtp_test.go | 265 +++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 253 insertions(+), 12 deletions(-) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index e7649f8..b09eda6 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -379,7 +379,8 @@ func TestPlainAuth(t *testing.T) { go func() { if err := simpleSMTPServer(ctx, t, &serverProps{ FeatureSet: featureSet, - ListenPort: serverPort}, + ListenPort: serverPort, + }, ); err != nil { t.Errorf("failed to start test server: %s", err) return @@ -411,7 +412,8 @@ func TestPlainAuth(t *testing.T) { if err := simpleSMTPServer(ctx, t, &serverProps{ FailOnAuth: true, FeatureSet: featureSet, - ListenPort: serverPort}, + ListenPort: serverPort, + }, ); err != nil { t.Errorf("failed to start test server: %s", err) return @@ -532,7 +534,8 @@ func TestPlainAuth_noEnc(t *testing.T) { go func() { if err := simpleSMTPServer(ctx, t, &serverProps{ FeatureSet: featureSet, - ListenPort: serverPort}, + ListenPort: serverPort, + }, ); err != nil { t.Errorf("failed to start test server: %s", err) return @@ -564,7 +567,8 @@ func TestPlainAuth_noEnc(t *testing.T) { if err := simpleSMTPServer(ctx, t, &serverProps{ FailOnAuth: true, FeatureSet: featureSet, - ListenPort: serverPort}, + ListenPort: serverPort, + }, ); err != nil { t.Errorf("failed to start test server: %s", err) return @@ -681,7 +685,8 @@ func TestLoginAuth(t *testing.T) { go func() { if err := simpleSMTPServer(ctx, t, &serverProps{ FeatureSet: featureSet, - ListenPort: serverPort}, + ListenPort: serverPort, + }, ); err != nil { t.Errorf("failed to start test server: %s", err) return @@ -713,7 +718,8 @@ func TestLoginAuth(t *testing.T) { if err := simpleSMTPServer(ctx, t, &serverProps{ FailOnAuth: true, FeatureSet: featureSet, - ListenPort: serverPort}, + ListenPort: serverPort, + }, ); err != nil { t.Errorf("failed to start test server: %s", err) return @@ -827,7 +833,8 @@ func TestLoginAuth_noEnc(t *testing.T) { go func() { if err := simpleSMTPServer(ctx, t, &serverProps{ FeatureSet: featureSet, - ListenPort: serverPort}, + ListenPort: serverPort, + }, ); err != nil { t.Errorf("failed to start test server: %s", err) return @@ -859,7 +866,8 @@ func TestLoginAuth_noEnc(t *testing.T) { if err := simpleSMTPServer(ctx, t, &serverProps{ FailOnAuth: true, FeatureSet: featureSet, - ListenPort: serverPort}, + ListenPort: serverPort, + }, ); err != nil { t.Errorf("failed to start test server: %s", err) return @@ -996,7 +1004,8 @@ func TestXOAuth2Auth(t *testing.T) { go func() { if err := simpleSMTPServer(ctx, t, &serverProps{ FeatureSet: featureSet, - ListenPort: serverPort}, + ListenPort: serverPort, + }, ); err != nil { t.Errorf("failed to start test server: %s", err) return @@ -1028,7 +1037,8 @@ func TestXOAuth2Auth(t *testing.T) { if err := simpleSMTPServer(ctx, t, &serverProps{ FailOnAuth: true, FeatureSet: featureSet, - ListenPort: serverPort}, + ListenPort: serverPort, + }, ); err != nil { t.Errorf("failed to start test server: %s", err) return @@ -2544,6 +2554,34 @@ func TestClient_Data(t *testing.T) { func TestSendMail(t *testing.T) { t.Run("full SendMail transaction with TLS and auth", func(t *testing.T) { + want := []string{ + "220 go-mail test server ready ESMTP", + "EHLO localhost", + "250-localhost.localdomain", + "250-AUTH LOGIN", + "250-DSN", + "250 STARTTLS", + "STARTTLS", + "220 Ready to start TLS", + "EHLO localhost", + "250-localhost.localdomain", + "250-AUTH LOGIN", + "250-DSN", + "250 STARTTLS", + "AUTH LOGIN", + "235 2.7.0 Authentication successful", + "MAIL FROM:", + "250 2.0.0 OK", + "RCPT TO:", + "250 2.0.0 OK", + "DATA", + "354 End data with .", + "test message", + ".", + "250 2.0.0 Ok: queued as 1234567890", + "QUIT", + "221 2.0.0 Bye", + } ctx, cancel := context.WithCancel(context.Background()) defer cancel() PortAdder.Add(1) @@ -2575,8 +2613,211 @@ func TestSendMail(t *testing.T) { t.Fatalf("failed to send mail: %s", err) } resp := strings.Split(echoBuffer.String(), "\r\n") - for i, line := range resp { - t.Logf("response line %d: %q", i, line) + if len(resp)-1 != len(want) { + t.Fatalf("expected %d lines, but got %d", len(want), len(resp)) + } + for i := 0; i < len(want); i++ { + if !strings.EqualFold(resp[i], want[i]) { + t.Errorf("expected line %d to be %q, but got %q", i, resp[i], want[i]) + } + } + }) + t.Run("SendMail newline in from should fail", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + addr := fmt.Sprintf("%s:%d", TestServerAddr, serverPort) + testHookStartTLS = func(config *tls.Config) { + testConfig := getTLSConfig(t) + config.ServerName = testConfig.ServerName + config.RootCAs = testConfig.RootCAs + config.Certificates = testConfig.Certificates + } + auth := LoginAuth("username", "password", TestServerAddr, false) + if err := SendMail(addr, auth, "valid-from@domain.tld\r\n", []string{"valid-to@domain.tld"}, + []byte("test message")); err == nil { + t.Error("expected SendMail to fail with newlines in from address") + } + }) + t.Run("SendMail newline in to should fail", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS" + echoBuffer := bytes.NewBuffer(nil) + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + EchoBuffer: echoBuffer, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + addr := fmt.Sprintf("%s:%d", TestServerAddr, serverPort) + testHookStartTLS = func(config *tls.Config) { + testConfig := getTLSConfig(t) + config.ServerName = testConfig.ServerName + config.RootCAs = testConfig.RootCAs + config.Certificates = testConfig.Certificates + } + auth := LoginAuth("username", "password", TestServerAddr, false) + if err := SendMail(addr, auth, "valid-from@domain.tld", []string{"valid-to@domain.tld\r\n"}, + []byte("test message")); err == nil { + t.Error("expected SendMail to fail with newlines in to address") + } + }) + t.Run("SendMail with invalid hostname should fail", func(t *testing.T) { + addr := "invalid.invalid-hostname.tld:1234" + testHookStartTLS = func(config *tls.Config) { + testConfig := getTLSConfig(t) + config.ServerName = testConfig.ServerName + config.RootCAs = testConfig.RootCAs + config.Certificates = testConfig.Certificates + } + auth := LoginAuth("username", "password", TestServerAddr, false) + if err := SendMail(addr, auth, "valid-from@domain.tld", []string{"valid-to@domain.tld"}, + []byte("test message")); err == nil { + t.Error("expected SendMail to fail with invalid server address") + } + }) + t.Run("SendMail should fail on EHLO/HELO", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnEhlo: true, + FailOnHelo: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + addr := fmt.Sprintf("%s:%d", TestServerAddr, serverPort) + testHookStartTLS = func(config *tls.Config) { + testConfig := getTLSConfig(t) + config.ServerName = testConfig.ServerName + config.RootCAs = testConfig.RootCAs + config.Certificates = testConfig.Certificates + } + auth := LoginAuth("username", "password", TestServerAddr, false) + if err := SendMail(addr, auth, "valid-from@domain.tld", []string{"valid-to@domain.tld"}, + []byte("test message")); err == nil { + t.Error("expected SendMail to fail on EHLO/HELO") + } + }) + t.Run("SendMail should fail on STARTTLS", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + testHookStartTLS = func(config *tls.Config) { + config.ServerName = "invalid.invalid-hostname.tld" + config.RootCAs = nil + config.Certificates = nil + } + time.Sleep(time.Millisecond * 30) + addr := fmt.Sprintf("%s:%d", TestServerAddr, serverPort) + auth := LoginAuth("username", "password", TestServerAddr, false) + if err := SendMail(addr, auth, "valid-from@domain.tld", []string{"valid-to@domain.tld"}, + []byte("test message")); err == nil { + t.Error("expected SendMail to fail on STARTTLS") + } + }) + t.Run("SendMail should fail on no auth support", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + addr := fmt.Sprintf("%s:%d", TestServerAddr, serverPort) + testHookStartTLS = func(config *tls.Config) { + testConfig := getTLSConfig(t) + config.ServerName = testConfig.ServerName + config.RootCAs = testConfig.RootCAs + config.Certificates = testConfig.Certificates + } + auth := LoginAuth("username", "password", TestServerAddr, false) + if err := SendMail(addr, auth, "valid-from@domain.tld", []string{"valid-to@domain.tld"}, + []byte("test message")); err == nil { + t.Error("expected SendMail to fail on no auth support") + } + }) + t.Run("SendMail should fail on auth", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnAuth: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + addr := fmt.Sprintf("%s:%d", TestServerAddr, serverPort) + testHookStartTLS = func(config *tls.Config) { + testConfig := getTLSConfig(t) + config.ServerName = testConfig.ServerName + config.RootCAs = testConfig.RootCAs + config.Certificates = testConfig.Certificates + } + auth := LoginAuth("username", "password", TestServerAddr, false) + if err := SendMail(addr, auth, "valid-from@domain.tld", []string{"valid-to@domain.tld"}, + []byte("test message")); err == nil { + t.Error("expected SendMail to fail on auth") } }) } From 1bfec504ed38d6ca6b3274b8ff5aa5267a0c37fa Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Mon, 11 Nov 2024 17:12:15 +0100 Subject: [PATCH 29/45] Add new SMTP test cases for various failure scenarios This commit introduces multiple test cases to handle SMTP failure scenarios such as invalid host, newline in addresses, EHLO/HELO failures, and malformed data injections. Additionally, it includes improvements in error handling for these specific conditions. --- smtp/smtp_test.go | 576 ++++++++++++++++++++-------------------------- 1 file changed, 246 insertions(+), 330 deletions(-) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index b09eda6..c54b868 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -2295,6 +2295,44 @@ func TestClient_Mail(t *testing.T) { t.Errorf("expected mail from command to be %q, but sent %q", expected, resp[5]) } }) + t.Run("from address and server supports SMTPUTF8 with unicode address", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-SMTPUTF8\r\n250 STARTTLS" + echoBuffer := bytes.NewBuffer(nil) + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + EchoBuffer: echoBuffer, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if err = client.Mail("valid-from+📧@domain.tld"); err != nil { + t.Errorf("failed to set mail from address: %s", err) + } + expected := "MAIL FROM: SMTPUTF8" + resp := strings.Split(echoBuffer.String(), "\r\n") + if !strings.EqualFold(resp[5], expected) { + t.Errorf("expected mail from command to be %q, but sent %q", expected, resp[5]) + } + }) t.Run("from address and server supports DSN", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -2325,8 +2363,8 @@ func TestClient_Mail(t *testing.T) { } }) client.dsnmrtype = "FULL" - if err = client.Mail("valid-from@domain.tld"); err == nil { - t.Error("server should echo the command as error but didn't") + if err = client.Mail("valid-from@domain.tld"); err != nil { + t.Errorf("failed to set mail from address: %s", err) } expected := "MAIL FROM: RET=FULL" resp := strings.Split(echoBuffer.String(), "\r\n") @@ -2364,8 +2402,8 @@ func TestClient_Mail(t *testing.T) { } }) client.dsnmrtype = "FULL" - if err = client.Mail("valid-from@domain.tld"); err == nil { - t.Error("server should echo the command as error but didn't") + if err = client.Mail("valid-from@domain.tld"); err != nil { + t.Errorf("failed to set mail from address: %s", err) } expected := "MAIL FROM: BODY=8BITMIME SMTPUTF8 RET=FULL" resp := strings.Split(echoBuffer.String(), "\r\n") @@ -2553,6 +2591,153 @@ func TestClient_Data(t *testing.T) { } func TestSendMail(t *testing.T) { + tests := []struct { + name string + featureSet string + hostname string + tlsConfig *tls.Config + props *serverProps + fromAddr string + toAddr string + message []byte + }{ + { + "fail on newline in MAIL FROM address", + "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS", + TestServerAddr, + getTLSConfig(t), + &serverProps{}, + "valid-from@domain.tld\r\n", + "valid-to@domain.tld", + []byte("test message"), + }, + { + "fail on newline in RCPT TO address", + "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS", + TestServerAddr, + getTLSConfig(t), + &serverProps{}, + "valid-from@domain.tld", + "valid-to@domain.tld\r\n", + []byte("test message"), + }, + { + "fail on invalid host address", + "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS", + "invalid.invalid-host@domain.tld", + getTLSConfig(t), + &serverProps{}, + "valid-from@domain.tld", + "valid-to@domain.tld", + []byte("test message"), + }, + { + "fail on EHLO/HELO", + "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS", + TestServerAddr, + getTLSConfig(t), + &serverProps{FailOnEhlo: true, FailOnHelo: true}, + "valid-from@domain.tld", + "valid-to@domain.tld", + []byte("test message"), + }, + { + "fail on STARTTLS", + "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS", + TestServerAddr, + &tls.Config{ServerName: "invalid.invalid-host@domain.tld"}, + &serverProps{}, + "valid-from@domain.tld", + "valid-to@domain.tld", + []byte("test message"), + }, + { + "fail on no server AUTH support", + "250-DSN\r\n250 STARTTLS", + TestServerAddr, + getTLSConfig(t), + &serverProps{}, + "valid-from@domain.tld", + "valid-to@domain.tld", + []byte("test message"), + }, + { + "fail on AUTH", + "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS", + TestServerAddr, + getTLSConfig(t), + &serverProps{FailOnAuth: true}, + "valid-from@domain.tld", + "valid-to@domain.tld", + []byte("test message"), + }, + { + "fail on MAIL FROM", + "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS", + TestServerAddr, + getTLSConfig(t), + &serverProps{FailOnMailFrom: true}, + "valid-from@domain.tld", + "valid-to@domain.tld", + []byte("test message"), + }, + { + "fail on RCPT TO", + "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS", + TestServerAddr, + getTLSConfig(t), + &serverProps{FailOnRcptTo: true}, + "valid-from@domain.tld", + "valid-to@domain.tld", + []byte("test message"), + }, + { + "fail on DATA (init phase)", + "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS", + TestServerAddr, + getTLSConfig(t), + &serverProps{FailOnDataInit: true}, + "valid-from@domain.tld", + "valid-to@domain.tld", + []byte("test message"), + }, + { + "fail on DATA (closing phase)", + "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS", + TestServerAddr, + getTLSConfig(t), + &serverProps{FailOnDataClose: true}, + "valid-from@domain.tld", + "valid-to@domain.tld", + []byte("test message"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + tt.props.ListenPort = int(TestServerPortBase + PortAdder.Load()) + tt.props.FeatureSet = tt.featureSet + go func() { + if err := simpleSMTPServer(ctx, t, tt.props); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + addr := fmt.Sprintf("%s:%d", tt.hostname, tt.props.ListenPort) + testHookStartTLS = func(config *tls.Config) { + config.ServerName = tt.tlsConfig.ServerName + config.RootCAs = tt.tlsConfig.RootCAs + config.Certificates = tt.tlsConfig.Certificates + } + auth := LoginAuth("username", "password", TestServerAddr, false) + if err := SendMail(addr, auth, tt.fromAddr, []string{tt.toAddr}, tt.message); err == nil { + t.Error("expected SendMail to " + tt.name) + } + }) + } t.Run("full SendMail transaction with TLS and auth", func(t *testing.T) { want := []string{ "220 go-mail test server ready ESMTP", @@ -2622,37 +2807,41 @@ func TestSendMail(t *testing.T) { } } }) - t.Run("SendMail newline in from should fail", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - PortAdder.Add(1) - serverPort := int(TestServerPortBase + PortAdder.Load()) - featureSet := "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS" - go func() { - if err := simpleSMTPServer(ctx, t, &serverProps{ - FeatureSet: featureSet, - ListenPort: serverPort, - }, - ); err != nil { - t.Errorf("failed to start test server: %s", err) - return - } - }() - time.Sleep(time.Millisecond * 30) - addr := fmt.Sprintf("%s:%d", TestServerAddr, serverPort) - testHookStartTLS = func(config *tls.Config) { - testConfig := getTLSConfig(t) - config.ServerName = testConfig.ServerName - config.RootCAs = testConfig.RootCAs - config.Certificates = testConfig.Certificates + t.Run("full SendMail transaction with leading dots", func(t *testing.T) { + want := []string{ + "220 go-mail test server ready ESMTP", + "EHLO localhost", + "250-localhost.localdomain", + "250-AUTH LOGIN", + "250-DSN", + "250 STARTTLS", + "STARTTLS", + "220 Ready to start TLS", + "EHLO localhost", + "250-localhost.localdomain", + "250-AUTH LOGIN", + "250-DSN", + "250 STARTTLS", + "AUTH LOGIN", + "235 2.7.0 Authentication successful", + "MAIL FROM:", + "250 2.0.0 OK", + "RCPT TO:", + "250 2.0.0 OK", + "DATA", + "354 End data with .", + "From: user@gmail.com", + "To: golang-nuts@googlegroups.com", + "Subject: Hooray for Go", + "", + "Line 1", + "..Leading dot line .", + "Goodbye.", + ".", + "250 2.0.0 Ok: queued as 1234567890", + "QUIT", + "221 2.0.0 Bye", } - auth := LoginAuth("username", "password", TestServerAddr, false) - if err := SendMail(addr, auth, "valid-from@domain.tld\r\n", []string{"valid-to@domain.tld"}, - []byte("test message")); err == nil { - t.Error("expected SendMail to fail with newlines in from address") - } - }) - t.Run("SendMail newline in to should fail", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() PortAdder.Add(1) @@ -2678,309 +2867,31 @@ func TestSendMail(t *testing.T) { config.RootCAs = testConfig.RootCAs config.Certificates = testConfig.Certificates } - auth := LoginAuth("username", "password", TestServerAddr, false) - if err := SendMail(addr, auth, "valid-from@domain.tld", []string{"valid-to@domain.tld\r\n"}, - []byte("test message")); err == nil { - t.Error("expected SendMail to fail with newlines in to address") - } - }) - t.Run("SendMail with invalid hostname should fail", func(t *testing.T) { - addr := "invalid.invalid-hostname.tld:1234" - testHookStartTLS = func(config *tls.Config) { - testConfig := getTLSConfig(t) - config.ServerName = testConfig.ServerName - config.RootCAs = testConfig.RootCAs - config.Certificates = testConfig.Certificates - } - auth := LoginAuth("username", "password", TestServerAddr, false) - if err := SendMail(addr, auth, "valid-from@domain.tld", []string{"valid-to@domain.tld"}, - []byte("test message")); err == nil { - t.Error("expected SendMail to fail with invalid server address") - } - }) - t.Run("SendMail should fail on EHLO/HELO", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - PortAdder.Add(1) - serverPort := int(TestServerPortBase + PortAdder.Load()) - featureSet := "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS" - go func() { - if err := simpleSMTPServer(ctx, t, &serverProps{ - FailOnEhlo: true, - FailOnHelo: true, - FeatureSet: featureSet, - ListenPort: serverPort, - }, - ); err != nil { - t.Errorf("failed to start test server: %s", err) - return - } - }() - time.Sleep(time.Millisecond * 30) - addr := fmt.Sprintf("%s:%d", TestServerAddr, serverPort) - testHookStartTLS = func(config *tls.Config) { - testConfig := getTLSConfig(t) - config.ServerName = testConfig.ServerName - config.RootCAs = testConfig.RootCAs - config.Certificates = testConfig.Certificates - } - auth := LoginAuth("username", "password", TestServerAddr, false) - if err := SendMail(addr, auth, "valid-from@domain.tld", []string{"valid-to@domain.tld"}, - []byte("test message")); err == nil { - t.Error("expected SendMail to fail on EHLO/HELO") - } - }) - t.Run("SendMail should fail on STARTTLS", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - PortAdder.Add(1) - serverPort := int(TestServerPortBase + PortAdder.Load()) - featureSet := "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS" - go func() { - if err := simpleSMTPServer(ctx, t, &serverProps{ - FeatureSet: featureSet, - ListenPort: serverPort, - }, - ); err != nil { - t.Errorf("failed to start test server: %s", err) - return - } - }() - testHookStartTLS = func(config *tls.Config) { - config.ServerName = "invalid.invalid-hostname.tld" - config.RootCAs = nil - config.Certificates = nil - } - time.Sleep(time.Millisecond * 30) - addr := fmt.Sprintf("%s:%d", TestServerAddr, serverPort) - auth := LoginAuth("username", "password", TestServerAddr, false) - if err := SendMail(addr, auth, "valid-from@domain.tld", []string{"valid-to@domain.tld"}, - []byte("test message")); err == nil { - t.Error("expected SendMail to fail on STARTTLS") - } - }) - t.Run("SendMail should fail on no auth support", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - PortAdder.Add(1) - serverPort := int(TestServerPortBase + PortAdder.Load()) - featureSet := "250-DSN\r\n250 STARTTLS" - go func() { - if err := simpleSMTPServer(ctx, t, &serverProps{ - FeatureSet: featureSet, - ListenPort: serverPort, - }, - ); err != nil { - t.Errorf("failed to start test server: %s", err) - return - } - }() - time.Sleep(time.Millisecond * 30) - addr := fmt.Sprintf("%s:%d", TestServerAddr, serverPort) - testHookStartTLS = func(config *tls.Config) { - testConfig := getTLSConfig(t) - config.ServerName = testConfig.ServerName - config.RootCAs = testConfig.RootCAs - config.Certificates = testConfig.Certificates - } - auth := LoginAuth("username", "password", TestServerAddr, false) - if err := SendMail(addr, auth, "valid-from@domain.tld", []string{"valid-to@domain.tld"}, - []byte("test message")); err == nil { - t.Error("expected SendMail to fail on no auth support") - } - }) - t.Run("SendMail should fail on auth", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - PortAdder.Add(1) - serverPort := int(TestServerPortBase + PortAdder.Load()) - featureSet := "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS" - go func() { - if err := simpleSMTPServer(ctx, t, &serverProps{ - FailOnAuth: true, - FeatureSet: featureSet, - ListenPort: serverPort, - }, - ); err != nil { - t.Errorf("failed to start test server: %s", err) - return - } - }() - time.Sleep(time.Millisecond * 30) - addr := fmt.Sprintf("%s:%d", TestServerAddr, serverPort) - testHookStartTLS = func(config *tls.Config) { - testConfig := getTLSConfig(t) - config.ServerName = testConfig.ServerName - config.RootCAs = testConfig.RootCAs - config.Certificates = testConfig.Certificates - } - auth := LoginAuth("username", "password", TestServerAddr, false) - if err := SendMail(addr, auth, "valid-from@domain.tld", []string{"valid-to@domain.tld"}, - []byte("test message")); err == nil { - t.Error("expected SendMail to fail on auth") - } - }) -} - -/* -func TestBasic(t *testing.T) { - server := strings.Join(strings.Split(basicServer, "\n"), "\r\n") - client := strings.Join(strings.Split(basicClient, "\n"), "\r\n") - - var cmdbuf strings.Builder - bcmdbuf := bufio.NewWriter(&cmdbuf) - var fake faker - fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf) - c := &Client{Text: textproto.NewConn(fake), localName: "localhost"} - - if err := c.helo(); err != nil { - t.Fatalf("HELO failed: %s", err) - } - if err := c.ehlo(); err == nil { - t.Fatalf("Expected first EHLO to fail") - } - if err := c.ehlo(); err != nil { - t.Fatalf("Second EHLO failed: %s", err) - } - - c.didHello = true - if ok, args := c.Extension("aUtH"); !ok || args != "LOGIN PLAIN" { - t.Fatalf("Expected AUTH supported") - } - if ok, _ := c.Extension("DSN"); ok { - t.Fatalf("Shouldn't support DSN") - } - - if err := c.Mail("user@gmail.com"); err == nil { - t.Fatalf("MAIL should require authentication") - } - - if err := c.Verify("user1@gmail.com"); err == nil { - t.Fatalf("First VRFY: expected no verification") - } - if err := c.Verify("user2@gmail.com>\r\nDATA\r\nAnother injected message body\r\n.\r\nQUIT\r\n"); err == nil { - t.Fatalf("VRFY should have failed due to a message injection attempt") - } - if err := c.Verify("user2@gmail.com"); err != nil { - t.Fatalf("Second VRFY: expected verification, got %s", err) - } - - // fake TLS so authentication won't complain - c.tls = true - c.serverName = "smtp.google.com" - if err := c.Auth(PlainAuth("", "user", "pass", "smtp.google.com", false)); err != nil { - t.Fatalf("AUTH failed: %s", err) - } - - if err := c.Rcpt("golang-nuts@googlegroups.com>\r\nDATA\r\nInjected message body\r\n.\r\nQUIT\r\n"); err == nil { - t.Fatalf("RCPT should have failed due to a message injection attempt") - } - if err := c.Mail("user@gmail.com>\r\nDATA\r\nAnother injected message body\r\n.\r\nQUIT\r\n"); err == nil { - t.Fatalf("MAIL should have failed due to a message injection attempt") - } - if err := c.Mail("user@gmail.com"); err != nil { - t.Fatalf("MAIL failed: %s", err) - } - if err := c.Rcpt("golang-nuts@googlegroups.com"); err != nil { - t.Fatalf("RCPT failed: %s", err) - } - msg := `From: user@gmail.com + message := []byte(`From: user@gmail.com To: golang-nuts@googlegroups.com Subject: Hooray for Go Line 1 .Leading dot line . -Goodbye.` - w, err := c.Data() - if err != nil { - t.Fatalf("DATA failed: %s", err) - } - if _, err := w.Write([]byte(msg)); err != nil { - t.Fatalf("Data write failed: %s", err) - } - if err := w.Close(); err != nil { - t.Fatalf("Bad data response: %s", err) - } - - if err := c.Quit(); err != nil { - t.Fatalf("QUIT failed: %s", err) - } - - if err := bcmdbuf.Flush(); err != nil { - t.Errorf("flush failed: %s", err) - } - actualcmds := cmdbuf.String() - if client != actualcmds { - t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client) - } +Goodbye.`) + auth := LoginAuth("username", "password", TestServerAddr, false) + if err := SendMail(addr, auth, "valid-from@domain.tld", []string{"valid-to@domain.tld"}, message); err != nil { + t.Fatalf("failed to send mail: %s", err) + } + resp := strings.Split(echoBuffer.String(), "\r\n") + if len(resp)-1 != len(want) { + t.Errorf("expected %d lines, but got %d", len(want), len(resp)) + } + for i := 0; i < len(want); i++ { + if !strings.EqualFold(resp[i], want[i]) { + t.Errorf("expected line %d to be %q, but got %q", i, resp[i], want[i]) + } + } + }) } -var basicServer = `250 mx.google.com at your service -502 Unrecognized command. -250-mx.google.com at your service -250-SIZE 35651584 -250-AUTH LOGIN PLAIN -250 8BITMIME -530 Authentication required -252 Send some mail, I'll try my best -250 User is valid -235 Accepted -250 Sender OK -250 Receiver OK -354 Go ahead -250 Data OK -221 OK -` +/* -var basicClient = `HELO localhost -EHLO localhost -EHLO localhost -MAIL FROM: BODY=8BITMIME -VRFY user1@gmail.com -VRFY user2@gmail.com -AUTH PLAIN AHVzZXIAcGFzcw== -MAIL FROM: BODY=8BITMIME -RCPT TO: -DATA -From: user@gmail.com -To: golang-nuts@googlegroups.com -Subject: Hooray for Go - -Line 1 -..Leading dot line . -Goodbye. -. -QUIT -` - -func TestHELOFailed(t *testing.T) { - serverLines := `502 EH? -502 EH? -221 OK -` - clientLines := `EHLO localhost -HELO localhost -QUIT -` - server := strings.Join(strings.Split(serverLines, "\n"), "\r\n") - client := strings.Join(strings.Split(clientLines, "\n"), "\r\n") - var cmdbuf strings.Builder - bcmdbuf := bufio.NewWriter(&cmdbuf) - var fake faker - fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf) - c := &Client{Text: textproto.NewConn(fake), localName: "localhost"} - if err := c.Hello("localhost"); err == nil { - t.Fatal("expected EHLO to fail") - } - if err := c.Quit(); err != nil { - t.Errorf("QUIT failed: %s", err) - } - _ = bcmdbuf.Flush() - actual := cmdbuf.String() - if client != actual { - t.Errorf("Got:\n%s\nWant:\n%s", actual, client) - } -} func TestExtensions(t *testing.T) { fake := func(server string) (c *Client, bcmdbuf *bufio.Writer, cmdbuf *strings.Builder) { @@ -4351,6 +4262,7 @@ type serverProps struct { FailOnNoop bool FailOnQuit bool FailOnReset bool + FailOnRcptTo bool FailOnSTARTTLS bool FailTemp bool FeatureSet string @@ -4502,12 +4414,16 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server from = strings.ReplaceAll(from, "RET=FULL", "") } from = strings.TrimSpace(from) - if !strings.EqualFold(from, "") { + if !strings.HasPrefix(from, "") { writeLine(fmt.Sprintf("503 5.1.2 Invalid from: %s", from)) break } writeOK() case strings.HasPrefix(data, "RCPT TO:"): + if props.FailOnRcptTo { + writeLine("500 5.5.2 Error: fail on RCPT TO") + break + } to := strings.TrimPrefix(data, "RCPT TO:") if props.SupportDSN { to = strings.ReplaceAll(to, "NOTIFY=FAILURE,SUCCESS", "") From 2d384a7d3765878001d9dbf4fc86267f4014738f Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Mon, 11 Nov 2024 17:29:43 +0100 Subject: [PATCH 30/45] Add unit tests for SMTP client extensions, reset, and noop Introduced new unit tests to verify the SMTP client's behavior with extensions, reset, and noop commands under various server conditions. Updated server response handling to correctly manage feature sets when they are empty. --- smtp/smtp_test.go | 239 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 237 insertions(+), 2 deletions(-) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index c54b868..401a6ac 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -2890,6 +2890,233 @@ Goodbye.`) }) } +func TestClient_Extension(t *testing.T) { + t.Run("extension check fails on EHLO/HELO", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnEhlo: true, + FailOnHelo: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if ok, _ := client.Extension("DSN"); ok { + t.Error("expected client extension check to fail on EHLO/HELO") + } + }) +} + +func TestClient_Reset(t *testing.T) { + t.Run("reset on functioning client conneciton", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if err = client.Reset(); err != nil { + t.Errorf("failed to reset client: %s", err) + } + }) + t.Run("reset fails on RSET", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnReset: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if err = client.Reset(); err == nil { + t.Error("expected client reset to fail") + } + }) + t.Run("reset fails on EHLO/HELO", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnEhlo: true, + FailOnHelo: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if err = client.Reset(); err == nil { + t.Error("expected client reset to fail") + } + }) +} + +func TestClient_Noop(t *testing.T) { + t.Run("noop on functioning client conneciton", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if err = client.Noop(); err != nil { + t.Errorf("failed client no-operation: %s", err) + } + }) + t.Run("noop fails on EHLO/HELO", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnEhlo: true, + FailOnHelo: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if err = client.Noop(); err == nil { + t.Error("expected client no-operation to fail") + } + }) + t.Run("noop fails on NOOP", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnNoop: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if err = client.Noop(); err == nil { + t.Error("expected client no-operation to fail") + } + }) +} + /* @@ -4391,7 +4618,11 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server writeLine("500 5.5.2 Error: fail on HELO") break } - writeLine("250-localhost.localdomain\r\n" + props.FeatureSet) + if props.FeatureSet != "" { + writeLine("250-localhost.localdomain\r\n" + props.FeatureSet) + break + } + writeLine("250 localhost.localdomain\r\n") case strings.HasPrefix(data, "EHLO"): if len(strings.Split(data, " ")) != 2 { writeLine("501 Syntax: EHLO hostname") @@ -4401,7 +4632,11 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server writeLine("500 5.5.2 Error: fail on EHLO") break } - writeLine("250-localhost.localdomain\r\n" + props.FeatureSet) + if props.FeatureSet != "" { + writeLine("250-localhost.localdomain\r\n" + props.FeatureSet) + break + } + writeLine("250 localhost.localdomain\r\n") case strings.HasPrefix(data, "MAIL FROM:"): if props.FailOnMailFrom { writeLine("500 5.5.2 Error: fail on MAIL FROM") From 2cc670659eb16857018560b3d624b31c880fb14f Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Mon, 11 Nov 2024 17:36:24 +0100 Subject: [PATCH 31/45] Remove obsolete SMTP client tests Deleted multiple outdated and redundant tests from the SMTP client test suite to streamline and improve the maintainability of the test codebase. This change focuses on removing tests that are no longer relevant or have been superseded by other methods. --- smtp/smtp_test.go | 1337 --------------------------------------------- 1 file changed, 1337 deletions(-) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index 401a6ac..11cc2a0 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -3117,1343 +3117,6 @@ func TestClient_Noop(t *testing.T) { }) } -/* - - -func TestExtensions(t *testing.T) { - fake := func(server string) (c *Client, bcmdbuf *bufio.Writer, cmdbuf *strings.Builder) { - server = strings.Join(strings.Split(server, "\n"), "\r\n") - - cmdbuf = &strings.Builder{} - bcmdbuf = bufio.NewWriter(cmdbuf) - var fake faker - fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf) - c = &Client{Text: textproto.NewConn(fake), localName: "localhost"} - - return c, bcmdbuf, cmdbuf - } - - t.Run("helo", func(t *testing.T) { - const ( - basicServer = `250 mx.google.com at your service -250 Sender OK -221 Goodbye -` - - basicClient = `HELO localhost -MAIL FROM: -QUIT -` - ) - - c, bcmdbuf, cmdbuf := fake(basicServer) - - if err := c.helo(); err != nil { - t.Fatalf("HELO failed: %s", err) - } - c.didHello = true - if err := c.Mail("user@gmail.com"); err != nil { - t.Fatalf("MAIL FROM failed: %s", err) - } - if err := c.Quit(); err != nil { - t.Fatalf("QUIT failed: %s", err) - } - - if err := bcmdbuf.Flush(); err != nil { - t.Errorf("flush failed: %s", err) - } - actualcmds := cmdbuf.String() - client := strings.Join(strings.Split(basicClient, "\n"), "\r\n") - if client != actualcmds { - t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client) - } - }) - - t.Run("ehlo", func(t *testing.T) { - const ( - basicServer = `250-mx.google.com at your service -250 SIZE 35651584 -250 Sender OK -221 Goodbye -` - - basicClient = `EHLO localhost -MAIL FROM: -QUIT -` - ) - - c, bcmdbuf, cmdbuf := fake(basicServer) - - if err := c.Hello("localhost"); err != nil { - t.Fatalf("EHLO failed: %s", err) - } - if ok, _ := c.Extension("8BITMIME"); ok { - t.Fatalf("Shouldn't support 8BITMIME") - } - if ok, _ := c.Extension("SMTPUTF8"); ok { - t.Fatalf("Shouldn't support SMTPUTF8") - } - if err := c.Mail("user@gmail.com"); err != nil { - t.Fatalf("MAIL FROM failed: %s", err) - } - if err := c.Quit(); err != nil { - t.Fatalf("QUIT failed: %s", err) - } - - if err := bcmdbuf.Flush(); err != nil { - t.Errorf("flush failed: %s", err) - } - actualcmds := cmdbuf.String() - client := strings.Join(strings.Split(basicClient, "\n"), "\r\n") - if client != actualcmds { - t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client) - } - }) - - t.Run("ehlo 8bitmime", func(t *testing.T) { - const ( - basicServer = `250-mx.google.com at your service -250-SIZE 35651584 -250 8BITMIME -250 Sender OK -221 Goodbye -` - - basicClient = `EHLO localhost -MAIL FROM: BODY=8BITMIME -QUIT -` - ) - - c, bcmdbuf, cmdbuf := fake(basicServer) - - if err := c.Hello("localhost"); err != nil { - t.Fatalf("EHLO failed: %s", err) - } - if ok, _ := c.Extension("8BITMIME"); !ok { - t.Fatalf("Should support 8BITMIME") - } - if ok, _ := c.Extension("SMTPUTF8"); ok { - t.Fatalf("Shouldn't support SMTPUTF8") - } - if err := c.Mail("user@gmail.com"); err != nil { - t.Fatalf("MAIL FROM failed: %s", err) - } - if err := c.Quit(); err != nil { - t.Fatalf("QUIT failed: %s", err) - } - - if err := bcmdbuf.Flush(); err != nil { - t.Errorf("failed to flush: %s", err) - } - actualcmds := cmdbuf.String() - client := strings.Join(strings.Split(basicClient, "\n"), "\r\n") - if client != actualcmds { - t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client) - } - }) - - t.Run("ehlo smtputf8", func(t *testing.T) { - const ( - basicServer = `250-mx.google.com at your service -250-SIZE 35651584 -250 SMTPUTF8 -250 Sender OK -221 Goodbye -` - - basicClient = `EHLO localhost -MAIL FROM: SMTPUTF8 -QUIT -` - ) - - c, bcmdbuf, cmdbuf := fake(basicServer) - - if err := c.Hello("localhost"); err != nil { - t.Fatalf("EHLO failed: %s", err) - } - if ok, _ := c.Extension("8BITMIME"); ok { - t.Fatalf("Shouldn't support 8BITMIME") - } - if ok, _ := c.Extension("SMTPUTF8"); !ok { - t.Fatalf("Should support SMTPUTF8") - } - if err := c.Mail("user+📧@gmail.com"); err != nil { - t.Fatalf("MAIL FROM failed: %s", err) - } - if err := c.Quit(); err != nil { - t.Fatalf("QUIT failed: %s", err) - } - - if err := bcmdbuf.Flush(); err != nil { - t.Errorf("failed to flush: %s", err) - } - actualcmds := cmdbuf.String() - client := strings.Join(strings.Split(basicClient, "\n"), "\r\n") - if client != actualcmds { - t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client) - } - }) - - t.Run("ehlo 8bitmime smtputf8", func(t *testing.T) { - const ( - basicServer = `250-mx.google.com at your service -250-SIZE 35651584 -250-8BITMIME -250 SMTPUTF8 -250 Sender OK -221 Goodbye - ` - - basicClient = `EHLO localhost -MAIL FROM: BODY=8BITMIME SMTPUTF8 -QUIT -` - ) - - c, bcmdbuf, cmdbuf := fake(basicServer) - - if err := c.Hello("localhost"); err != nil { - t.Fatalf("EHLO failed: %s", err) - } - c.didHello = true - if ok, _ := c.Extension("8BITMIME"); !ok { - t.Fatalf("Should support 8BITMIME") - } - if ok, _ := c.Extension("SMTPUTF8"); !ok { - t.Fatalf("Should support SMTPUTF8") - } - if err := c.Mail("user+📧@gmail.com"); err != nil { - t.Fatalf("MAIL FROM failed: %s", err) - } - if err := c.Quit(); err != nil { - t.Fatalf("QUIT failed: %s", err) - } - - if err := bcmdbuf.Flush(); err != nil { - t.Errorf("failed to flush: %s", err) - } - actualcmds := cmdbuf.String() - client := strings.Join(strings.Split(basicClient, "\n"), "\r\n") - if client != actualcmds { - t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client) - } - }) -} - -func TestNewClient(t *testing.T) { - server := strings.Join(strings.Split(newClientServer, "\n"), "\r\n") - client := strings.Join(strings.Split(newClientClient, "\n"), "\r\n") - - var cmdbuf strings.Builder - bcmdbuf := bufio.NewWriter(&cmdbuf) - out := func() string { - if err := bcmdbuf.Flush(); err != nil { - t.Errorf("failed to flush: %s", err) - } - return cmdbuf.String() - } - var fake faker - fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf) - c, err := NewClient(fake, "fake.host") - if err != nil { - t.Fatalf("NewClient: %v\n(after %v)", err, out()) - } - defer func() { - _ = c.Close() - }() - if ok, args := c.Extension("aUtH"); !ok || args != "LOGIN PLAIN" { - t.Fatalf("Expected AUTH supported") - } - if ok, _ := c.Extension("DSN"); ok { - t.Fatalf("Shouldn't support DSN") - } - if err := c.Quit(); err != nil { - t.Fatalf("QUIT failed: %s", err) - } - - actualcmds := out() - if client != actualcmds { - t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client) - } -} - -// TestClient_SetDebugLog tests the Client method with the Client.SetDebugLog method -// to enable debug logging -func TestClient_SetDebugLog(t *testing.T) { - server := strings.Join(strings.Split(newClientServer, "\n"), "\r\n") - - var cmdbuf strings.Builder - bcmdbuf := bufio.NewWriter(&cmdbuf) - out := func() string { - if err := bcmdbuf.Flush(); err != nil { - t.Errorf("failed to flush: %s", err) - } - return cmdbuf.String() - } - var fake faker - fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf) - c, err := NewClient(fake, "fake.host") - if err != nil { - t.Fatalf("NewClient: %v\n(after %v)", err, out()) - } - defer func() { - _ = c.Close() - }() - c.SetDebugLog(true) - if !c.debug { - t.Errorf("Expected DebugLog flag to be true but received false") - } -} - -// TestClient_SetLogger tests the Client method with the Client.SetLogger method -// to provide a custom logger -func TestClient_SetLogger(t *testing.T) { - server := strings.Join(strings.Split(newClientServer, "\n"), "\r\n") - - var cmdbuf strings.Builder - bcmdbuf := bufio.NewWriter(&cmdbuf) - out := func() string { - if err := bcmdbuf.Flush(); err != nil { - t.Errorf("failed to flush: %s", err) - } - return cmdbuf.String() - } - var fake faker - fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf) - c, err := NewClient(fake, "fake.host") - if err != nil { - t.Fatalf("NewClient: %v\n(after %v)", err, out()) - } - defer func() { - _ = c.Close() - }() - c.SetLogger(log.New(os.Stderr, log.LevelDebug)) - if c.logger == nil { - t.Errorf("Expected Logger to be set but received nil") - } - c.logger.Debugf(log.Log{Direction: log.DirServerToClient, Format: "%s", Messages: []interface{}{"test"}}) - c.SetLogger(nil) - c.logger.Debugf(log.Log{Direction: log.DirServerToClient, Format: "%s", Messages: []interface{}{"test"}}) -} - -func TestClient_SetLogAuthData(t *testing.T) { - server := strings.Join(strings.Split(newClientServer, "\n"), "\r\n") - - var cmdbuf strings.Builder - bcmdbuf := bufio.NewWriter(&cmdbuf) - out := func() string { - if err := bcmdbuf.Flush(); err != nil { - t.Errorf("failed to flush: %s", err) - } - return cmdbuf.String() - } - var fake faker - fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf) - c, err := NewClient(fake, "fake.host") - if err != nil { - t.Fatalf("NewClient: %v\n(after %v)", err, out()) - } - defer func() { - _ = c.Close() - }() - c.SetLogAuthData() - if !c.logAuthData { - t.Error("Expected logAuthData to be true but received false") - } -} - -var newClientServer = `220 hello world -250-mx.google.com at your service -250-SIZE 35651584 -250-AUTH LOGIN PLAIN -250 8BITMIME -221 OK -` - -var newClientClient = `EHLO localhost -QUIT -` - -func TestNewClient2(t *testing.T) { - server := strings.Join(strings.Split(newClient2Server, "\n"), "\r\n") - client := strings.Join(strings.Split(newClient2Client, "\n"), "\r\n") - - var cmdbuf strings.Builder - bcmdbuf := bufio.NewWriter(&cmdbuf) - var fake faker - fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf) - c, err := NewClient(fake, "fake.host") - if err != nil { - t.Fatalf("NewClient: %v", err) - } - defer func() { - _ = c.Close() - }() - if ok, _ := c.Extension("DSN"); ok { - t.Fatalf("Shouldn't support DSN") - } - if err := c.Quit(); err != nil { - t.Fatalf("QUIT failed: %s", err) - } - - if err := bcmdbuf.Flush(); err != nil { - t.Errorf("flush failed: %s", err) - } - actualcmds := cmdbuf.String() - if client != actualcmds { - t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client) - } -} - -var newClient2Server = `220 hello world -502 EH? -250-mx.google.com at your service -250-SIZE 35651584 -250-AUTH LOGIN PLAIN -250 8BITMIME -221 OK -` - -var newClient2Client = `EHLO localhost -HELO localhost -QUIT -` - -func TestNewClientWithTLS(t *testing.T) { - cert, err := tls.X509KeyPair(localhostCert, localhostKey) - if err != nil { - t.Fatalf("loadcert: %v", err) - } - - config := tls.Config{Certificates: []tls.Certificate{cert}} - - ln, err := tls.Listen("tcp", "127.0.0.1:0", &config) - if err != nil { - ln, err = tls.Listen("tcp", "[::1]:0", &config) - if err != nil { - t.Fatalf("server: listen: %v", err) - } - } - - go func() { - conn, err := ln.Accept() - if err != nil { - t.Errorf("server: accept: %v", err) - return - } - defer func() { - _ = conn.Close() - }() - - _, err = conn.Write([]byte("220 SIGNS\r\n")) - if err != nil { - t.Errorf("server: write: %v", err) - return - } - }() - - config.InsecureSkipVerify = true - conn, err := tls.Dial("tcp", ln.Addr().String(), &config) - if err != nil { - t.Fatalf("client: dial: %v", err) - } - defer func() { - _ = conn.Close() - }() - - client, err := NewClient(conn, ln.Addr().String()) - if err != nil { - t.Fatalf("smtp: newclient: %v", err) - } - if !client.tls { - t.Errorf("client.tls Got: %t Expected: %t", client.tls, true) - } -} - -func TestHello(t *testing.T) { - if len(helloServer) != len(helloClient) { - t.Fatalf("Hello server and client size mismatch") - } - - tf := func(fake faker, i int) error { - c, err := NewClient(fake, "fake.host") - if err != nil { - t.Fatalf("NewClient: %v", err) - } - defer func() { - _ = c.Close() - }() - c.localName = "customhost" - err = nil - - switch i { - case 0: - err = c.Hello("hostinjection>\n\rDATA\r\nInjected message body\r\n.\r\nQUIT\r\n") - if err == nil { - t.Errorf("Expected Hello to be rejected due to a message injection attempt") - } - err = c.Hello("customhost") - case 1: - err = c.StartTLS(nil) - if err.Error() == "502 Not implemented" { - err = nil - } - case 2: - err = c.Verify("test@example.com") - case 3: - c.tls = true - c.serverName = "smtp.google.com" - err = c.Auth(PlainAuth("", "user", "pass", "smtp.google.com", false)) - case 4: - err = c.Mail("test@example.com") - case 5: - ok, _ := c.Extension("feature") - if ok { - t.Errorf("Expected FEATURE not to be supported") - } - case 6: - err = c.Reset() - case 7: - err = c.Quit() - case 8: - err = c.Verify("test@example.com") - if err != nil { - err = c.Hello("customhost") - if err != nil { - t.Errorf("Want error, got none") - } - } - case 9: - err = c.Noop() - default: - t.Fatalf("Unhandled command") - } - - if err != nil { - t.Errorf("Command %d failed: %v", i, err) - } - return nil - } - - for i := 0; i < len(helloServer); i++ { - server := strings.Join(strings.Split(baseHelloServer+helloServer[i], "\n"), "\r\n") - client := strings.Join(strings.Split(baseHelloClient+helloClient[i], "\n"), "\r\n") - var cmdbuf strings.Builder - bcmdbuf := bufio.NewWriter(&cmdbuf) - var fake faker - fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf) - - if err := tf(fake, i); err != nil { - t.Error(err) - } - - if err := bcmdbuf.Flush(); err != nil { - t.Errorf("flush failed: %s", err) - } - actualcmds := cmdbuf.String() - if client != actualcmds { - t.Errorf("Got:\n%s\nExpected:\n%s", actualcmds, client) - } - } -} - -var baseHelloServer = `220 hello world -502 EH? -250-mx.google.com at your service -250 FEATURE -` - -var helloServer = []string{ - "", - "502 Not implemented\n", - "250 User is valid\n", - "235 Accepted\n", - "250 Sender ok\n", - "", - "250 Reset ok\n", - "221 Goodbye\n", - "250 Sender ok\n", - "250 ok\n", -} - -var baseHelloClient = `EHLO customhost -HELO customhost -` - -var helloClient = []string{ - "", - "STARTTLS\n", - "VRFY test@example.com\n", - "AUTH PLAIN AHVzZXIAcGFzcw==\n", - "MAIL FROM:\n", - "", - "RSET\n", - "QUIT\n", - "VRFY test@example.com\n", - "NOOP\n", -} - -func TestSendMail(t *testing.T) { - server := strings.Join(strings.Split(sendMailServer, "\n"), "\r\n") - client := strings.Join(strings.Split(sendMailClient, "\n"), "\r\n") - var cmdbuf strings.Builder - bcmdbuf := bufio.NewWriter(&cmdbuf) - l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("Unable to create listener: %v", err) - } - defer func() { - _ = l.Close() - }() - - // prevent data race on bcmdbuf - done := make(chan struct{}) - go func(data []string) { - defer close(done) - - conn, err := l.Accept() - if err != nil { - t.Errorf("Accept error: %v", err) - return - } - defer func() { - _ = conn.Close() - }() - - tc := textproto.NewConn(conn) - for i := 0; i < len(data) && data[i] != ""; i++ { - if err := tc.PrintfLine("%s", data[i]); err != nil { - t.Errorf("printing to textproto failed: %s", err) - } - for len(data[i]) >= 4 && data[i][3] == '-' { - i++ - if err := tc.PrintfLine("%s", data[i]); err != nil { - t.Errorf("printing to textproto failed: %s", err) - } - } - if data[i] == "221 Goodbye" { - return - } - read := false - for !read || data[i] == "354 Go ahead" { - msg, err := tc.ReadLine() - if _, err := bcmdbuf.Write([]byte(msg + "\r\n")); err != nil { - t.Errorf("write failed: %s", err) - } - read = true - if err != nil { - t.Errorf("Read error: %v", err) - return - } - if data[i] == "354 Go ahead" && msg == "." { - break - } - } - } - }(strings.Split(server, "\r\n")) - - err = SendMail(l.Addr().String(), nil, "test@example.com", []string{"other@example.com>\n\rDATA\r\nInjected message body\r\n.\r\nQUIT\r\n"}, []byte(strings.Replace(`From: test@example.com -To: other@example.com -Subject: SendMail test - -SendMail is working for me. -`, "\n", "\r\n", -1))) - if err == nil { - t.Errorf("Expected SendMail to be rejected due to a message injection attempt") - } - - err = SendMail(l.Addr().String(), nil, "test@example.com", []string{"other@example.com"}, []byte(strings.Replace(`From: test@example.com -To: other@example.com -Subject: SendMail test - -SendMail is working for me. -`, "\n", "\r\n", -1))) - if err != nil { - t.Errorf("%v", err) - } - - <-done - if err := bcmdbuf.Flush(); err != nil { - t.Errorf("flush failed: %s", err) - } - actualcmds := cmdbuf.String() - if client != actualcmds { - t.Errorf("Got:\n%s\nExpected:\n%s", actualcmds, client) - } -} - -var sendMailServer = `220 hello world -502 EH? -250 mx.google.com at your service -250 Sender ok -250 Receiver ok -354 Go ahead -250 Data ok -221 Goodbye -` - -var sendMailClient = `EHLO localhost -HELO localhost -MAIL FROM: -RCPT TO: -DATA -From: test@example.com -To: other@example.com -Subject: SendMail test - -SendMail is working for me. -. -QUIT -` - -func TestSendMailWithAuth(t *testing.T) { - l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("Unable to create listener: %v", err) - } - defer func() { - _ = l.Close() - }() - - errCh := make(chan error) - go func() { - defer close(errCh) - conn, err := l.Accept() - if err != nil { - errCh <- fmt.Errorf("listener Accept: %w", err) - return - } - defer func() { - _ = conn.Close() - }() - - tc := textproto.NewConn(conn) - if err := tc.PrintfLine("220 hello world"); err != nil { - t.Errorf("textproto connetion print failed: %s", err) - } - msg, err := tc.ReadLine() - if err != nil { - errCh <- fmt.Errorf("textproto connection ReadLine error: %w", err) - return - } - const wantMsg = "EHLO localhost" - if msg != wantMsg { - errCh <- fmt.Errorf("unexpected response %q; want %q", msg, wantMsg) - return - } - err = tc.PrintfLine("250 mx.google.com at your service") - if err != nil { - errCh <- fmt.Errorf("textproto connection PrintfLine: %w", err) - return - } - }() - - err = SendMail(l.Addr().String(), PlainAuth("", "user", "pass", "smtp.google.com", false), "test@example.com", []string{"other@example.com"}, []byte(strings.Replace(`From: test@example.com -To: other@example.com -Subject: SendMail test - -SendMail is working for me. -`, "\n", "\r\n", -1))) - if err == nil { - t.Error("SendMail: Server doesn't support AUTH, expected to get an error, but got none ") - return - } - if err.Error() != "smtp: server doesn't support AUTH" { - t.Errorf("Expected: smtp: server doesn't support AUTH, got: %s", err) - } - err = <-errCh - if err != nil { - t.Fatalf("server error: %v", err) - } -} - -func TestAuthFailed(t *testing.T) { - server := strings.Join(strings.Split(authFailedServer, "\n"), "\r\n") - client := strings.Join(strings.Split(authFailedClient, "\n"), "\r\n") - var cmdbuf strings.Builder - bcmdbuf := bufio.NewWriter(&cmdbuf) - var fake faker - fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf) - c, err := NewClient(fake, "fake.host") - if err != nil { - t.Fatalf("NewClient: %v", err) - } - defer func() { - _ = c.Close() - }() - - c.tls = true - c.serverName = "smtp.google.com" - err = c.Auth(PlainAuth("", "user", "pass", "smtp.google.com", false)) - - if err == nil { - t.Error("Auth: expected error; got none") - } else if err.Error() != "535 Invalid credentials\nplease see www.example.com" { - t.Errorf("Auth: got error: %v, want: %s", err, "535 Invalid credentials\nplease see www.example.com") - } - - if err := bcmdbuf.Flush(); err != nil { - t.Errorf("flush failed: %s", err) - } - actualcmds := cmdbuf.String() - if client != actualcmds { - t.Errorf("Got:\n%s\nExpected:\n%s", actualcmds, client) - } -} - -var authFailedServer = `220 hello world -250-mx.google.com at your service -250 AUTH LOGIN PLAIN -535-Invalid credentials -535 please see www.example.com -221 Goodbye -` - -var authFailedClient = `EHLO localhost -AUTH PLAIN AHVzZXIAcGFzcw== -* -QUIT -` - -func TestTLSClient(t *testing.T) { - if runtime.GOOS == "freebsd" || runtime.GOOS == "js" || runtime.GOOS == "wasip1" { - SkipFlaky(t, 19229) - } - ln := newLocalListener(t) - defer func() { - _ = ln.Close() - }() - errc := make(chan error) - go func() { - errc <- sendMail(ln.Addr().String()) - }() - conn, err := ln.Accept() - if err != nil { - t.Fatalf("failed to accept connection: %v", err) - } - defer func() { - _ = conn.Close() - }() - if err := serverHandle(conn, t); err != nil { - t.Fatalf("failed to handle connection: %v", err) - } - if err := <-errc; err != nil { - t.Fatalf("client error: %v", err) - } -} - -func TestTLSConnState(t *testing.T) { - ln := newLocalListener(t) - defer func() { - _ = ln.Close() - }() - clientDone := make(chan bool) - serverDone := make(chan bool) - go func() { - defer close(serverDone) - c, err := ln.Accept() - if err != nil { - t.Errorf("Server accept: %v", err) - return - } - defer func() { - _ = c.Close() - }() - if err := serverHandle(c, t); err != nil { - t.Errorf("server error: %v", err) - } - }() - go func() { - defer close(clientDone) - c, err := Dial(ln.Addr().String()) - if err != nil { - t.Errorf("Client dial: %v", err) - return - } - defer func() { - _ = c.Quit() - }() - cfg := &tls.Config{ServerName: "example.com"} - testHookStartTLS(cfg) // set the RootCAs - if err := c.StartTLS(cfg); err != nil { - t.Errorf("StartTLS: %v", err) - return - } - cs, ok := c.TLSConnectionState() - if !ok { - t.Errorf("TLSConnectionState returned ok == false; want true") - return - } - if cs.Version == 0 || !cs.HandshakeComplete { - t.Errorf("ConnectionState = %#v; expect non-zero Version and HandshakeComplete", cs) - } - }() - <-clientDone - <-serverDone -} - -func TestClient_GetTLSConnectionState(t *testing.T) { - ln := newLocalListener(t) - defer func() { - _ = ln.Close() - }() - clientDone := make(chan bool) - serverDone := make(chan bool) - go func() { - defer close(serverDone) - c, err := ln.Accept() - if err != nil { - t.Errorf("Server accept: %v", err) - return - } - defer func() { - _ = c.Close() - }() - if err := serverHandle(c, t); err != nil { - t.Errorf("server error: %v", err) - } - }() - go func() { - defer close(clientDone) - c, err := Dial(ln.Addr().String()) - if err != nil { - t.Errorf("Client dial: %v", err) - return - } - defer func() { - _ = c.Quit() - }() - cfg := &tls.Config{ServerName: "example.com"} - testHookStartTLS(cfg) // set the RootCAs - if err := c.StartTLS(cfg); err != nil { - t.Errorf("StartTLS: %v", err) - return - } - cs, err := c.GetTLSConnectionState() - if err != nil { - t.Errorf("failed to get TLSConnectionState: %s", err) - return - } - if cs.Version == 0 || !cs.HandshakeComplete { - t.Errorf("ConnectionState = %#v; expect non-zero Version and HandshakeComplete", cs) - } - }() - <-clientDone - <-serverDone -} - -func TestClient_GetTLSConnectionState_noTLS(t *testing.T) { - ln := newLocalListener(t) - defer func() { - _ = ln.Close() - }() - clientDone := make(chan bool) - serverDone := make(chan bool) - go func() { - defer close(serverDone) - c, err := ln.Accept() - if err != nil { - t.Errorf("Server accept: %v", err) - return - } - defer func() { - _ = c.Close() - }() - if err := serverHandle(c, t); err != nil { - t.Errorf("server error: %v", err) - } - }() - go func() { - defer close(clientDone) - c, err := Dial(ln.Addr().String()) - if err != nil { - t.Errorf("Client dial: %v", err) - return - } - defer func() { - _ = c.Quit() - }() - _, err = c.GetTLSConnectionState() - if err == nil { - t.Error("GetTLSConnectionState: expected error; got nil") - return - } - }() - <-clientDone - <-serverDone -} - -func TestClient_GetTLSConnectionState_noConn(t *testing.T) { - ln := newLocalListener(t) - defer func() { - _ = ln.Close() - }() - clientDone := make(chan bool) - serverDone := make(chan bool) - go func() { - defer close(serverDone) - c, err := ln.Accept() - if err != nil { - t.Errorf("Server accept: %v", err) - return - } - defer func() { - _ = c.Close() - }() - if err := serverHandle(c, t); err != nil { - t.Errorf("server error: %v", err) - } - }() - go func() { - defer close(clientDone) - c, err := Dial(ln.Addr().String()) - if err != nil { - t.Errorf("Client dial: %v", err) - return - } - _ = c.Close() - _, err = c.GetTLSConnectionState() - if err == nil { - t.Error("GetTLSConnectionState: expected error; got nil") - return - } - }() - <-clientDone - <-serverDone -} - -func TestClient_GetTLSConnectionState_unableErr(t *testing.T) { - ln := newLocalListener(t) - defer func() { - _ = ln.Close() - }() - clientDone := make(chan bool) - serverDone := make(chan bool) - go func() { - defer close(serverDone) - c, err := ln.Accept() - if err != nil { - t.Errorf("Server accept: %v", err) - return - } - defer func() { - _ = c.Close() - }() - if err := serverHandle(c, t); err != nil { - t.Errorf("server error: %v", err) - } - }() - go func() { - defer close(clientDone) - c, err := Dial(ln.Addr().String()) - if err != nil { - t.Errorf("Client dial: %v", err) - return - } - defer func() { - _ = c.Quit() - }() - c.tls = true - _, err = c.GetTLSConnectionState() - if err == nil { - t.Error("GetTLSConnectionState: expected error; got nil") - return - } - }() - <-clientDone - <-serverDone -} - -func TestClient_HasConnection(t *testing.T) { - ln := newLocalListener(t) - defer func() { - _ = ln.Close() - }() - clientDone := make(chan bool) - serverDone := make(chan bool) - go func() { - defer close(serverDone) - c, err := ln.Accept() - if err != nil { - t.Errorf("Server accept: %v", err) - return - } - defer func() { - _ = c.Close() - }() - if err := serverHandle(c, t); err != nil { - t.Errorf("server error: %v", err) - } - }() - go func() { - defer close(clientDone) - c, err := Dial(ln.Addr().String()) - if err != nil { - t.Errorf("Client dial: %v", err) - return - } - cfg := &tls.Config{ServerName: "example.com"} - testHookStartTLS(cfg) // set the RootCAs - if err := c.StartTLS(cfg); err != nil { - t.Errorf("StartTLS: %v", err) - return - } - if !c.HasConnection() { - t.Error("HasConnection: expected true; got false") - return - } - if err = c.Quit(); err != nil { - t.Errorf("closing connection failed: %s", err) - return - } - if c.HasConnection() { - t.Error("HasConnection: expected false; got true") - } - }() - <-clientDone - <-serverDone -} - -func TestClient_SetDSNMailReturnOption(t *testing.T) { - ln := newLocalListener(t) - defer func() { - _ = ln.Close() - }() - clientDone := make(chan bool) - serverDone := make(chan bool) - go func() { - defer close(serverDone) - c, err := ln.Accept() - if err != nil { - t.Errorf("Server accept: %v", err) - return - } - defer func() { - _ = c.Close() - }() - if err := serverHandle(c, t); err != nil { - t.Errorf("server error: %v", err) - } - }() - go func() { - defer close(clientDone) - c, err := Dial(ln.Addr().String()) - if err != nil { - t.Errorf("Client dial: %v", err) - return - } - defer func() { - _ = c.Quit() - }() - c.SetDSNMailReturnOption("foo") - if c.dsnmrtype != "foo" { - t.Errorf("SetDSNMailReturnOption: expected %s; got %s", "foo", c.dsnrntype) - } - }() - <-clientDone - <-serverDone -} - -func TestClient_SetDSNRcptNotifyOption(t *testing.T) { - ln := newLocalListener(t) - defer func() { - _ = ln.Close() - }() - clientDone := make(chan bool) - serverDone := make(chan bool) - go func() { - defer close(serverDone) - c, err := ln.Accept() - if err != nil { - t.Errorf("Server accept: %v", err) - return - } - defer func() { - _ = c.Close() - }() - if err := serverHandle(c, t); err != nil { - t.Errorf("server error: %v", err) - } - }() - go func() { - defer close(clientDone) - c, err := Dial(ln.Addr().String()) - if err != nil { - t.Errorf("Client dial: %v", err) - return - } - defer func() { - _ = c.Quit() - }() - c.SetDSNRcptNotifyOption("foo") - if c.dsnrntype != "foo" { - t.Errorf("SetDSNMailReturnOption: expected %s; got %s", "foo", c.dsnrntype) - } - }() - <-clientDone - <-serverDone -} - -func TestClient_UpdateDeadline(t *testing.T) { - ln := newLocalListener(t) - defer func() { - _ = ln.Close() - }() - clientDone := make(chan bool) - serverDone := make(chan bool) - go func() { - defer close(serverDone) - c, err := ln.Accept() - if err != nil { - t.Errorf("Server accept: %v", err) - return - } - defer func() { - _ = c.Close() - }() - if err = serverHandle(c, t); err != nil { - t.Errorf("server error: %v", err) - } - }() - go func() { - defer close(clientDone) - c, err := Dial(ln.Addr().String()) - if err != nil { - t.Errorf("Client dial: %v", err) - return - } - defer func() { - _ = c.Close() - }() - if !c.HasConnection() { - t.Error("HasConnection: expected true; got false") - return - } - if err = c.UpdateDeadline(time.Millisecond * 20); err != nil { - t.Errorf("failed to update deadline: %s", err) - return - } - time.Sleep(time.Millisecond * 50) - if !c.HasConnection() { - t.Error("HasConnection: expected true; got false") - return - } - }() - <-clientDone - <-serverDone -} - -func newLocalListener(t *testing.T) net.Listener { - ln, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - ln, err = net.Listen("tcp6", "[::1]:0") - } - if err != nil { - t.Fatal(err) - } - return ln -} - -type smtpSender struct { - w io.Writer -} - -func (s smtpSender) send(f string) { - _, _ = s.w.Write([]byte(f + "\r\n")) -} - -// smtp server, finely tailored to deal with our own client only! -func serverHandle(c net.Conn, t *testing.T) error { - send := smtpSender{c}.send - send("220 127.0.0.1 ESMTP service ready") - s := bufio.NewScanner(c) - tf := func(config *tls.Config) error { - c = tls.Server(c, config) - defer func() { - _ = c.Close() - }() - return serverHandleTLS(c, t) - } - for s.Scan() { - switch s.Text() { - case "EHLO localhost": - send("250-127.0.0.1 ESMTP offers a warm hug of welcome") - send("250-STARTTLS") - send("250 Ok") - case "STARTTLS": - send("220 Go ahead") - keypair, err := tls.X509KeyPair(localhostCert, localhostKey) - if err != nil { - return err - } - config := &tls.Config{Certificates: []tls.Certificate{keypair}} - return tf(config) - case "QUIT": - return nil - default: - t.Fatalf("unrecognized command: %q", s.Text()) - } - } - return s.Err() -} - -func serverHandleTLS(c net.Conn, t *testing.T) error { - send := smtpSender{c}.send - s := bufio.NewScanner(c) - for s.Scan() { - switch s.Text() { - case "EHLO localhost": - send("250 Ok") - case "MAIL FROM:": - send("250 Ok") - case "RCPT TO:": - send("250 Ok") - case "DATA": - send("354 send the mail data, end with .") - send("250 Ok") - case "Subject: test": - case "": - case "howdy!": - case ".": - case "QUIT": - send("221 127.0.0.1 Service closing transmission channel") - return nil - default: - t.Fatalf("unrecognized command during TLS: %q", s.Text()) - } - } - return s.Err() -} - -func init() { - testRootCAs := x509.NewCertPool() - testRootCAs.AppendCertsFromPEM(localhostCert) - testHookStartTLS = func(config *tls.Config) { - config.RootCAs = testRootCAs - } -} - -func sendMail(hostPort string) error { - from := "joe1@example.com" - to := []string{"joe2@example.com"} - return SendMail(hostPort, nil, from, to, []byte("Subject: test\n\nhowdy!")) -} - -var flaky = flag.Bool("flaky", false, "run known-flaky tests too") - -func SkipFlaky(t testing.TB, issue int) { - t.Helper() - if !*flaky { - t.Skipf("skipping known flaky test without the -flaky flag; see golang.org/issue/%d", issue) - } -} - - -*/ - // faker is a struct embedding io.ReadWriter to simulate network connections for testing purposes. type faker struct { io.ReadWriter From 59f2778a3853e9da3341f5d181beb7c9d98c5d59 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Mon, 11 Nov 2024 17:44:31 +0100 Subject: [PATCH 32/45] Add tests for Client's SetDebugLog method These tests verify the behavior of the SetDebugLog method in various scenarios such as enabling and disabling debug logging and ensuring the logger type is as expected. This improves the robustness and reliability of the debug logging functionality in the Client class. --- smtp/smtp_test.go | 52 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index 11cc2a0..93cbf84 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -29,12 +29,15 @@ import ( "hash" "io" "net" + "os" "strings" "sync/atomic" "testing" "time" "golang.org/x/crypto/pbkdf2" + + "github.com/wneessen/go-mail/log" ) const ( @@ -3117,6 +3120,55 @@ func TestClient_Noop(t *testing.T) { }) } +func TestClient_SetDebugLog(t *testing.T) { + t.Run("set debug loggging to on with no logger defined", func(t *testing.T) { + client := &Client{} + client.SetDebugLog(true) + if !client.debug { + t.Fatalf("expected debug log to be true") + } + if client.logger == nil { + t.Fatalf("expected logger to be defined") + } + if !strings.EqualFold(fmt.Sprintf("%T", client.logger), "*log.Stdlog") { + t.Errorf("expected logger to be of type *log.Stdlog, got: %T", client.logger) + } + }) + t.Run("set debug loggging to on should not override logger", func(t *testing.T) { + client := &Client{logger: log.NewJSON(os.Stderr, log.LevelDebug)} + client.SetDebugLog(true) + if !client.debug { + t.Fatalf("expected debug log to be true") + } + if client.logger == nil { + t.Fatalf("expected logger to be defined") + } + if !strings.EqualFold(fmt.Sprintf("%T", client.logger), "*log.JSONlog") { + t.Errorf("expected logger to be of type *log.JSONlog, got: %T", client.logger) + } + }) + t.Run("set debug logggin to off with no logger defined", func(t *testing.T) { + client := &Client{} + client.SetDebugLog(false) + if client.debug { + t.Fatalf("expected debug log to be false") + } + if client.logger != nil { + t.Fatalf("expected logger to be nil") + } + }) + t.Run("set active logging to off should cancel out logger", func(t *testing.T) { + client := &Client{debug: true, logger: log.New(os.Stderr, log.LevelDebug)} + client.SetDebugLog(false) + if client.debug { + t.Fatalf("expected debug log to be false") + } + if client.logger != nil { + t.Fatalf("expected logger to be nil") + } + }) +} + // faker is a struct embedding io.ReadWriter to simulate network connections for testing purposes. type faker struct { io.ReadWriter From 9412f318745f79fd4a42e84030f964d4c54f31e0 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Mon, 11 Nov 2024 17:47:59 +0100 Subject: [PATCH 33/45] Lock mutex when setting the logger Add mutex locking to ensure thread-safety when setting the logger in the `smtp` package. This prevents potential race conditions and ensures that the logger is updated consistently in concurrent operations. --- smtp/smtp.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/smtp/smtp.go b/smtp/smtp.go index 4841ec8..ed86ac9 100644 --- a/smtp/smtp.go +++ b/smtp/smtp.go @@ -587,7 +587,9 @@ func (c *Client) SetLogger(l log.Logger) { if l == nil { return } + c.mutex.Lock() c.logger = l + c.mutex.Unlock() } // SetLogAuthData enables logging of authentication data in the Client. From 5e3d14f842ab970c58f24971ad58f0881a22dff1 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Mon, 11 Nov 2024 17:49:07 +0100 Subject: [PATCH 34/45] Add tests for Client's SetLogger method This commit introduces unit tests for the Client's SetLogger method. It verifies the correct logger type is set and ensures setting a nil logger does not override an existing logger. --- smtp/smtp_test.go | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index 93cbf84..5b47291 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -3169,6 +3169,30 @@ func TestClient_SetDebugLog(t *testing.T) { }) } +func TestClient_SetLogger(t *testing.T) { + t.Run("set logger to Stdlog logger", func(t *testing.T) { + client := &Client{} + client.SetLogger(log.New(os.Stderr, log.LevelDebug)) + if !strings.EqualFold(fmt.Sprintf("%T", client.logger), "*log.Stdlog") { + t.Errorf("expected logger to be of type *log.Stdlog, got: %T", client.logger) + } + }) + t.Run("set logger to JSONlog logger", func(t *testing.T) { + client := &Client{} + client.SetLogger(log.NewJSON(os.Stderr, log.LevelDebug)) + if !strings.EqualFold(fmt.Sprintf("%T", client.logger), "*log.JSONlog") { + t.Errorf("expected logger to be of type *log.JSONlog, got: %T", client.logger) + } + }) + t.Run("nil logger should just return and not set/override", func(t *testing.T) { + client := &Client{logger: log.NewJSON(os.Stderr, log.LevelDebug)} + client.SetLogger(nil) + if !strings.EqualFold(fmt.Sprintf("%T", client.logger), "*log.JSONlog") { + t.Errorf("expected logger to be of type *log.JSONlog, got: %T", client.logger) + } + }) +} + // faker is a struct embedding io.ReadWriter to simulate network connections for testing purposes. type faker struct { io.ReadWriter From 8fbd94a6755d97feb61a514777b8c51f7dd775b5 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Mon, 11 Nov 2024 18:07:31 +0100 Subject: [PATCH 35/45] Add nil check in UpdateDeadline method Ensure that the connection is not nil before setting the deadline in the UpdateDeadline method. This prevents a potential runtime panic when attempting to set a deadline on a nil connection. --- smtp/smtp.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/smtp/smtp.go b/smtp/smtp.go index ed86ac9..8a278ee 100644 --- a/smtp/smtp.go +++ b/smtp/smtp.go @@ -622,6 +622,9 @@ func (c *Client) HasConnection() bool { func (c *Client) UpdateDeadline(timeout time.Duration) error { c.mutex.Lock() defer c.mutex.Unlock() + if c.conn == nil { + return errors.New("smtp: client has no connection") + } if err := c.conn.SetDeadline(time.Now().Add(timeout)); err != nil { return fmt.Errorf("smtp: failed to update deadline: %w", err) } From 007d214c680d42c2f3111be00fa757b4753ab1e6 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Mon, 11 Nov 2024 18:08:56 +0100 Subject: [PATCH 36/45] Add comprehensive tests for SMTP client functionality Introduce new tests to verify behaviors of client methods including SetLogAuthData, SetDSNRcptNotifyOption, SetDSNMailReturnOption, HasConnection, and UpdateDeadline. Ensure these tests cover various scenarios such as successful operations, edge cases, and failure conditions. --- smtp/smtp_test.go | 203 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 203 insertions(+) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index 5b47291..48a4515 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -3193,6 +3193,209 @@ func TestClient_SetLogger(t *testing.T) { }) } +func TestClient_SetLogAuthData(t *testing.T) { + t.Run("set log auth data to true", func(t *testing.T) { + client := &Client{} + client.SetLogAuthData() + if !client.logAuthData { + t.Fatalf("expected log auth data to be true") + } + }) +} + +func TestClient_SetDSNRcptNotifyOption(t *testing.T) { + tests := []string{"NEVER", "SUCCESS", "FAILURE", "DELAY"} + for _, test := range tests { + t.Run("set dsn rcpt notify option to "+test, func(t *testing.T) { + client := &Client{} + client.SetDSNRcptNotifyOption(test) + if !strings.EqualFold(client.dsnrntype, test) { + t.Errorf("expected dsn rcpt notify option to be %s, got %s", test, client.dsnrntype) + } + }) + } +} + +func TestClient_SetDSNMailReturnOption(t *testing.T) { + tests := []string{"HDRS", "FULL"} + for _, test := range tests { + t.Run("set dsn mail return option to "+test, func(t *testing.T) { + client := &Client{} + client.SetDSNMailReturnOption(test) + if !strings.EqualFold(client.dsnmrtype, test) { + t.Errorf("expected dsn mail return option to be %s, got %s", test, client.dsnmrtype) + } + }) + } +} + +func TestClient_HasConnection(t *testing.T) { + t.Run("client has connection", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if !client.HasConnection() { + t.Error("expected client to have a connection") + } + }) + t.Run("client has no connection", func(t *testing.T) { + client := &Client{} + if client.HasConnection() { + t.Error("expected client to have no connection") + } + }) + t.Run("client has no connection after close", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + if client.HasConnection() { + t.Error("expected client to have no connection after close") + } + }) + t.Run("client has no connection after quit", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + if err = client.Quit(); err != nil { + t.Errorf("failed to quit client: %s", err) + } + if client.HasConnection() { + t.Error("expected client to have no connection after quit") + } + }) +} + +func TestClient_UpdateDeadline(t *testing.T) { + t.Run("update deadline on sane client succeeds", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if err = client.UpdateDeadline(time.Millisecond * 500); err != nil { + t.Errorf("failed to update connection deadline: %s", err) + } + }) + t.Run("update deadline on no connection should fail", func(t *testing.T) { + client := &Client{} + var err error + if err = client.UpdateDeadline(time.Millisecond * 500); err == nil { + t.Error("expected client deadline update to fail on no connection") + } + expError := "smtp: client has no connection" + if !strings.EqualFold(err.Error(), expError) { + t.Errorf("expected error to be %q, got: %q", expError, err) + } + }) + t.Run("update deadline on closed client should fail", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + if err = client.UpdateDeadline(time.Millisecond * 500); err == nil { + t.Error("expected client deadline update to fail on closed client") + } + }) +} + // faker is a struct embedding io.ReadWriter to simulate network connections for testing purposes. type faker struct { io.ReadWriter From 3b9085e19d5742b287d972d56dda1ea6a7fe2bab Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Mon, 11 Nov 2024 18:17:36 +0100 Subject: [PATCH 37/45] Add tests for GetTLSConnectionState method in SMTP client This commit introduces four test cases for the GetTLSConnectionState method. These tests cover scenarios with a valid TLS connection, no connection, a non-TLS connection, and a non-TLS connection with the TLS flag set on the client. This ensures comprehensive validation of the method's behavior across different states. --- smtp/smtp_test.go | 116 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 116 insertions(+) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index 48a4515..3fb84de 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -3396,6 +3396,122 @@ func TestClient_UpdateDeadline(t *testing.T) { }) } +func TestClient_GetTLSConnectionState(t *testing.T) { + t.Run("get state on sane client connection", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + tlsConfig := getTLSConfig(t) + tlsConfig.MinVersion = tls.VersionTLS12 + tlsConfig.MaxVersion = tls.VersionTLS12 + if err = client.StartTLS(tlsConfig); err != nil { + t.Fatalf("failed to start TLS on client: %s", err) + } + state, err := client.GetTLSConnectionState() + if err != nil { + t.Fatalf("failed to get TLS connection state: %s", err) + } + if state == nil { + t.Error("expected TLS connection state to be non-nil") + } + if state.Version != tls.VersionTLS12 { + t.Errorf("expected TLS connection state version to be %d, got: %d", tls.VersionTLS12, state.Version) + } + }) + t.Run("get state on no connection", func(t *testing.T) { + client := &Client{} + _, err := client.GetTLSConnectionState() + if err == nil { + t.Fatal("expected client to have no tls connection state") + } + }) + t.Run("get state on non-tls client connection", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250 DSN" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + _, err = client.GetTLSConnectionState() + if err == nil { + t.Error("expected client to have no tls connection state") + } + }) + t.Run("fail to get state on non-tls connection with tls flag set", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250 DSN" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + client.tls = true + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + _, err = client.GetTLSConnectionState() + if err == nil { + t.Error("expected client to have no tls connection state") + } + }) +} + // faker is a struct embedding io.ReadWriter to simulate network connections for testing purposes. type faker struct { io.ReadWriter From c58aa354548ced165f2201261b5fe1ceea00e8c2 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Mon, 11 Nov 2024 18:25:39 +0100 Subject: [PATCH 38/45] Add tests for Client debug logging behavior Introduce unit tests to verify the behavior of the debug logging in the Client. Confirm that logs are correctly produced when debug mode is enabled and appropriately suppressed when it is disabled. --- smtp/smtp_test.go | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index 3fb84de..03777d5 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -3512,6 +3512,34 @@ func TestClient_GetTLSConnectionState(t *testing.T) { }) } +func TestClient_debugLog(t *testing.T) { + t.Run("debug log is enabled", func(t *testing.T) { + buffer := bytes.NewBuffer(nil) + logger := log.New(buffer, log.LevelDebug) + client := &Client{logger: logger, debug: true} + client.debugLog(log.DirClientToServer, "%s", "simple string") + client.debugLog(log.DirServerToClient, "%d", 1234) + want := "DEBUG: C --> S: simple string" + if !strings.Contains(buffer.String(), want) { + t.Errorf("expected debug log to contain %q, got: %q", want, buffer.String()) + } + want = "DEBUG: C <-- S: 1234" + if !strings.Contains(buffer.String(), want) { + t.Errorf("expected debug log to contain %q, got: %q", want, buffer.String()) + } + }) + t.Run("debug log is disable", func(t *testing.T) { + buffer := bytes.NewBuffer(nil) + logger := log.New(buffer, log.LevelDebug) + client := &Client{logger: logger, debug: false} + client.debugLog(log.DirClientToServer, "%s", "simple string") + client.debugLog(log.DirServerToClient, "%d", 1234) + if buffer.Len() > 0 { + t.Errorf("expected debug log to be empty, got: %q", buffer.String()) + } + }) +} + // faker is a struct embedding io.ReadWriter to simulate network connections for testing purposes. type faker struct { io.ReadWriter From 7da30e09e1c46cf3519905f75a22f1f2ac7f2a0d Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Mon, 11 Nov 2024 18:35:10 +0100 Subject: [PATCH 39/45] Refactor smtp tests to improve clarity and error handling Removed unused variables and improved error handling in smtp_test.go. Adjusted to capture only error in auth.Next() calls, ensuring accurate validation. Added necessary error checks after creating new client connections to prevent test failures. --- smtp/smtp_test.go | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index 03777d5..9011363 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -670,7 +670,7 @@ func TestLoginAuth(t *testing.T) { if !bytes.Equal([]byte(pass), resp) { t.Errorf("expected response to second challange to be: %q, got: %q", pass, resp) } - resp, err = auth.Next([]byte("nonsense"), true) + _, err = auth.Next([]byte("nonsense"), true) if err == nil { t.Error("expected third server challange to fail, but didn't") } @@ -818,7 +818,7 @@ func TestLoginAuth_noEnc(t *testing.T) { if !bytes.Equal([]byte(pass), resp) { t.Errorf("expected response to second challange to be: %q, got: %q", pass, resp) } - resp, err = auth.Next([]byte("nonsense"), true) + _, err = auth.Next([]byte("nonsense"), true) if err == nil { t.Error("expected third server challange to fail, but didn't") } @@ -910,7 +910,7 @@ func TestXOAuth2Auth(t *testing.T) { if !bytes.Equal([]byte(""), resp) { t.Errorf("expected server response to be empty, got: %q", resp) } - resp, err = auth.Next([]byte("nonsense"), false) + _, err = auth.Next([]byte("nonsense"), false) if err != nil { t.Errorf("failed on first server challange: %s", err) } @@ -1107,6 +1107,9 @@ func TestScramAuth(t *testing.T) { t.Fatalf("failed to dial TLS server: %v", err) } client, err = NewClient(conn, TestServerAddr) + if err != nil { + t.Fatalf("failed to connect to test server: %s", err) + } case false: var err error client, err = Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) @@ -1176,6 +1179,9 @@ func TestScramAuth(t *testing.T) { t.Fatalf("failed to dial TLS server: %v", err) } client, err = NewClient(conn, TestServerAddr) + if err != nil { + t.Fatalf("failed to connect to test server: %s", err) + } case false: var err error client, err = Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) @@ -3434,7 +3440,7 @@ func TestClient_GetTLSConnectionState(t *testing.T) { t.Fatalf("failed to get TLS connection state: %s", err) } if state == nil { - t.Error("expected TLS connection state to be non-nil") + t.Fatal("expected TLS connection state to be non-nil") } if state.Version != tls.VersionTLS12 { t.Errorf("expected TLS connection state version to be %d, got: %d", tls.VersionTLS12, state.Version) @@ -3543,7 +3549,6 @@ func TestClient_debugLog(t *testing.T) { // faker is a struct embedding io.ReadWriter to simulate network connections for testing purposes. type faker struct { io.ReadWriter - failOnRead bool failOnClose bool } @@ -3865,10 +3870,9 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server // fields are present. We have actual real authentication tests for all SCRAM modes in the // go-mail client_test.go type testSCRAMSMTP struct { - authMechanism string - nonce string - h func() hash.Hash - tlsServer bool + nonce string + h func() hash.Hash + tlsServer bool } func (s *testSCRAMSMTP) handleSCRAMAuth(conn net.Conn) { From 61353d51e56656d659762285e6d694cd48a64397 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Mon, 11 Nov 2024 18:43:48 +0100 Subject: [PATCH 40/45] Fix variable declarations in test cases Changed variable declarations from '=' to ':=' to properly handle errors within the SMTP test cases. This ensures that errors are correctly captured and reported when writing to the EchoBuffer. --- smtp/smtp_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index 9011363..9ac6118 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -3692,7 +3692,7 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server } time.Sleep(time.Millisecond) if props.EchoBuffer != nil { - if _, err = props.EchoBuffer.Write([]byte(data)); err != nil { + if _, err := props.EchoBuffer.Write([]byte(data)); err != nil { t.Errorf("failed write to echo buffer: %s", err) } } @@ -3795,7 +3795,7 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server break } if props.EchoBuffer != nil { - if _, err = props.EchoBuffer.Write([]byte(ddata)); err != nil { + if _, err := props.EchoBuffer.Write([]byte(ddata)); err != nil { t.Errorf("failed write to echo buffer: %s", err) } } From 2156fbc01ee1c6a90d026ade835e27390c70c8b3 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Mon, 11 Nov 2024 18:51:14 +0100 Subject: [PATCH 41/45] Add mutex lock to handle concurrent SMTP test server connections Introduced a mutex to the SMTP test server properties to ensure thread-safe access when handling connections. This prevents race conditions and improves the reliability of the test server under concurrent load. --- smtp/smtp_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index 9ac6118..ee1ae5a 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -31,6 +31,7 @@ import ( "net" "os" "strings" + "sync" "sync/atomic" "testing" "time" @@ -3592,6 +3593,7 @@ type serverProps struct { SSLListener bool TestSCRAM bool VRFYUserUnknown bool + mutex sync.Mutex } // simpleSMTPServer starts a simple TCP server that resonds to SMTP commands. @@ -3643,7 +3645,9 @@ func simpleSMTPServer(ctx context.Context, t *testing.T, props *serverProps) err } return fmt.Errorf("unable to accept connection: %w", err) } + props.mutex.Lock() handleTestServerConnection(connection, t, props) + props.mutex.Unlock() } } } From 08034e6ff894b5cf9a36e3c5325ece0dc6604b0a Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Mon, 11 Nov 2024 18:58:31 +0100 Subject: [PATCH 42/45] Refactor echoBuffer parameter handling in tests Removed redundant mutex and streamlined anonymous goroutine syntax for test server setup by passing echoBuffer directly as a parameter. This change reduces unnecessary use of shared resources and simplifies the test code structure and fixes potential race conditions --- smtp/smtp_test.go | 52 ++++++++++++++++++++++------------------------- 1 file changed, 24 insertions(+), 28 deletions(-) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index ee1ae5a..e10748b 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -31,7 +31,6 @@ import ( "net" "os" "strings" - "sync" "sync/atomic" "testing" "time" @@ -2236,9 +2235,9 @@ func TestClient_Mail(t *testing.T) { serverPort := int(TestServerPortBase + PortAdder.Load()) featureSet := "250-8BITMIME\r\n250 STARTTLS" echoBuffer := bytes.NewBuffer(nil) - go func() { + go func(buf *bytes.Buffer) { if err := simpleSMTPServer(ctx, t, &serverProps{ - EchoBuffer: echoBuffer, + EchoBuffer: buf, FeatureSet: featureSet, ListenPort: serverPort, }, @@ -2246,7 +2245,7 @@ func TestClient_Mail(t *testing.T) { t.Errorf("failed to start test server: %s", err) return } - }() + }(echoBuffer) time.Sleep(time.Millisecond * 30) client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) @@ -2274,9 +2273,9 @@ func TestClient_Mail(t *testing.T) { serverPort := int(TestServerPortBase + PortAdder.Load()) featureSet := "250-SMTPUTF8\r\n250 STARTTLS" echoBuffer := bytes.NewBuffer(nil) - go func() { + go func(buf *bytes.Buffer) { if err := simpleSMTPServer(ctx, t, &serverProps{ - EchoBuffer: echoBuffer, + EchoBuffer: buf, FeatureSet: featureSet, ListenPort: serverPort, }, @@ -2284,7 +2283,7 @@ func TestClient_Mail(t *testing.T) { t.Errorf("failed to start test server: %s", err) return } - }() + }(echoBuffer) time.Sleep(time.Millisecond * 30) client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) @@ -2312,9 +2311,9 @@ func TestClient_Mail(t *testing.T) { serverPort := int(TestServerPortBase + PortAdder.Load()) featureSet := "250-SMTPUTF8\r\n250 STARTTLS" echoBuffer := bytes.NewBuffer(nil) - go func() { + go func(buf *bytes.Buffer) { if err := simpleSMTPServer(ctx, t, &serverProps{ - EchoBuffer: echoBuffer, + EchoBuffer: buf, FeatureSet: featureSet, ListenPort: serverPort, }, @@ -2322,7 +2321,7 @@ func TestClient_Mail(t *testing.T) { t.Errorf("failed to start test server: %s", err) return } - }() + }(echoBuffer) time.Sleep(time.Millisecond * 30) client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) @@ -2350,9 +2349,9 @@ func TestClient_Mail(t *testing.T) { serverPort := int(TestServerPortBase + PortAdder.Load()) featureSet := "250-DSN\r\n250 STARTTLS" echoBuffer := bytes.NewBuffer(nil) - go func() { + go func(buf *bytes.Buffer) { if err := simpleSMTPServer(ctx, t, &serverProps{ - EchoBuffer: echoBuffer, + EchoBuffer: buf, FeatureSet: featureSet, ListenPort: serverPort, }, @@ -2360,7 +2359,7 @@ func TestClient_Mail(t *testing.T) { t.Errorf("failed to start test server: %s", err) return } - }() + }(echoBuffer) time.Sleep(time.Millisecond * 30) client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) @@ -2389,9 +2388,9 @@ func TestClient_Mail(t *testing.T) { serverPort := int(TestServerPortBase + PortAdder.Load()) featureSet := "250-DSN\r\n250-8BITMIME\r\n250-SMTPUTF8\r\n250 STARTTLS" echoBuffer := bytes.NewBuffer(nil) - go func() { + go func(buf *bytes.Buffer) { if err := simpleSMTPServer(ctx, t, &serverProps{ - EchoBuffer: echoBuffer, + EchoBuffer: buf, FeatureSet: featureSet, ListenPort: serverPort, }, @@ -2399,7 +2398,7 @@ func TestClient_Mail(t *testing.T) { t.Errorf("failed to start test server: %s", err) return } - }() + }(echoBuffer) time.Sleep(time.Millisecond * 30) client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) @@ -2491,9 +2490,9 @@ func TestClient_Rcpt(t *testing.T) { serverPort := int(TestServerPortBase + PortAdder.Load()) featureSet := "250-DSN\r\n250 STARTTLS" echoBuffer := bytes.NewBuffer(nil) - go func() { + go func(buf *bytes.Buffer) { if err := simpleSMTPServer(ctx, t, &serverProps{ - EchoBuffer: echoBuffer, + EchoBuffer: buf, FeatureSet: featureSet, ListenPort: serverPort, }, @@ -2501,7 +2500,7 @@ func TestClient_Rcpt(t *testing.T) { t.Errorf("failed to start test server: %s", err) return } - }() + }(echoBuffer) time.Sleep(time.Millisecond * 30) client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) if err != nil { @@ -2783,9 +2782,9 @@ func TestSendMail(t *testing.T) { serverPort := int(TestServerPortBase + PortAdder.Load()) featureSet := "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS" echoBuffer := bytes.NewBuffer(nil) - go func() { + go func(buf *bytes.Buffer) { if err := simpleSMTPServer(ctx, t, &serverProps{ - EchoBuffer: echoBuffer, + EchoBuffer: buf, FeatureSet: featureSet, ListenPort: serverPort, }, @@ -2793,7 +2792,7 @@ func TestSendMail(t *testing.T) { t.Errorf("failed to start test server: %s", err) return } - }() + }(echoBuffer) time.Sleep(time.Millisecond * 30) addr := fmt.Sprintf("%s:%d", TestServerAddr, serverPort) testHookStartTLS = func(config *tls.Config) { @@ -2858,9 +2857,9 @@ func TestSendMail(t *testing.T) { serverPort := int(TestServerPortBase + PortAdder.Load()) featureSet := "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS" echoBuffer := bytes.NewBuffer(nil) - go func() { + go func(buf *bytes.Buffer) { if err := simpleSMTPServer(ctx, t, &serverProps{ - EchoBuffer: echoBuffer, + EchoBuffer: buf, FeatureSet: featureSet, ListenPort: serverPort, }, @@ -2868,7 +2867,7 @@ func TestSendMail(t *testing.T) { t.Errorf("failed to start test server: %s", err) return } - }() + }(echoBuffer) time.Sleep(time.Millisecond * 30) addr := fmt.Sprintf("%s:%d", TestServerAddr, serverPort) testHookStartTLS = func(config *tls.Config) { @@ -3593,7 +3592,6 @@ type serverProps struct { SSLListener bool TestSCRAM bool VRFYUserUnknown bool - mutex sync.Mutex } // simpleSMTPServer starts a simple TCP server that resonds to SMTP commands. @@ -3645,9 +3643,7 @@ func simpleSMTPServer(ctx context.Context, t *testing.T, props *serverProps) err } return fmt.Errorf("unable to accept connection: %w", err) } - props.mutex.Lock() handleTestServerConnection(connection, t, props) - props.mutex.Unlock() } } } From 77175a2952c64f6eea5c035eba1cea74de463225 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Mon, 11 Nov 2024 19:15:13 +0100 Subject: [PATCH 43/45] Refactor and relocate JSON logger tests for Go 1.21 compliance Removed JSON logger tests from smtp_test.go and relocated them to a new file smtp_121_test.go, ensuring compliance with Go 1.21. This change maintains test integrity while organizing tests by Go version compatibility. --- smtp/smtp_121_test.go | 59 +++++++++++++++++++++++++++++++++++++++++++ smtp/smtp_test.go | 27 -------------------- 2 files changed, 59 insertions(+), 27 deletions(-) create mode 100644 smtp/smtp_121_test.go diff --git a/smtp/smtp_121_test.go b/smtp/smtp_121_test.go new file mode 100644 index 0000000..ed722e0 --- /dev/null +++ b/smtp/smtp_121_test.go @@ -0,0 +1,59 @@ +// SPDX-FileCopyrightText: Copyright 2010 The Go Authors. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2022-2023 The go-mail Authors +// +// Original net/smtp code from the Go stdlib by the Go Authors. +// Use of this source code is governed by a BSD-style +// LICENSE file that can be found in this directory. +// +// go-mail specific modifications by the go-mail Authors. +// Licensed under the MIT License. +// See [PROJECT ROOT]/LICENSES directory for more information. +// +// SPDX-License-Identifier: BSD-3-Clause AND MIT + +//go:build go1.21 +// +build go1.21 + +package smtp + +import ( + "fmt" + "os" + "strings" + "testing" + + "github.com/wneessen/go-mail/log" +) + +func TestClient_SetDebugLog_JSON(t *testing.T) { + t.Run("set debug loggging to on should not override logger", func(t *testing.T) { + client := &Client{logger: log.NewJSON(os.Stderr, log.LevelDebug)} + client.SetDebugLog(true) + if !client.debug { + t.Fatalf("expected debug log to be true") + } + if client.logger == nil { + t.Fatalf("expected logger to be defined") + } + if !strings.EqualFold(fmt.Sprintf("%T", client.logger), "*log.JSONlog") { + t.Errorf("expected logger to be of type *log.JSONlog, got: %T", client.logger) + } + }) +} + +func TestClient_SetLogger_JSON(t *testing.T) { + t.Run("set logger to JSONlog logger", func(t *testing.T) { + client := &Client{} + client.SetLogger(log.NewJSON(os.Stderr, log.LevelDebug)) + if !strings.EqualFold(fmt.Sprintf("%T", client.logger), "*log.JSONlog") { + t.Errorf("expected logger to be of type *log.JSONlog, got: %T", client.logger) + } + }) + t.Run("nil logger should just return and not set/override", func(t *testing.T) { + client := &Client{logger: log.NewJSON(os.Stderr, log.LevelDebug)} + client.SetLogger(nil) + if !strings.EqualFold(fmt.Sprintf("%T", client.logger), "*log.JSONlog") { + t.Errorf("expected logger to be of type *log.JSONlog, got: %T", client.logger) + } + }) +} diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index e10748b..4d6316a 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -3140,19 +3140,6 @@ func TestClient_SetDebugLog(t *testing.T) { t.Errorf("expected logger to be of type *log.Stdlog, got: %T", client.logger) } }) - t.Run("set debug loggging to on should not override logger", func(t *testing.T) { - client := &Client{logger: log.NewJSON(os.Stderr, log.LevelDebug)} - client.SetDebugLog(true) - if !client.debug { - t.Fatalf("expected debug log to be true") - } - if client.logger == nil { - t.Fatalf("expected logger to be defined") - } - if !strings.EqualFold(fmt.Sprintf("%T", client.logger), "*log.JSONlog") { - t.Errorf("expected logger to be of type *log.JSONlog, got: %T", client.logger) - } - }) t.Run("set debug logggin to off with no logger defined", func(t *testing.T) { client := &Client{} client.SetDebugLog(false) @@ -3183,20 +3170,6 @@ func TestClient_SetLogger(t *testing.T) { t.Errorf("expected logger to be of type *log.Stdlog, got: %T", client.logger) } }) - t.Run("set logger to JSONlog logger", func(t *testing.T) { - client := &Client{} - client.SetLogger(log.NewJSON(os.Stderr, log.LevelDebug)) - if !strings.EqualFold(fmt.Sprintf("%T", client.logger), "*log.JSONlog") { - t.Errorf("expected logger to be of type *log.JSONlog, got: %T", client.logger) - } - }) - t.Run("nil logger should just return and not set/override", func(t *testing.T) { - client := &Client{logger: log.NewJSON(os.Stderr, log.LevelDebug)} - client.SetLogger(nil) - if !strings.EqualFold(fmt.Sprintf("%T", client.logger), "*log.JSONlog") { - t.Errorf("expected logger to be of type *log.JSONlog, got: %T", client.logger) - } - }) } func TestClient_SetLogAuthData(t *testing.T) { From f7bdd8fffc66a2fd380d8a7035dba0ba709360dd Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Mon, 11 Nov 2024 19:21:50 +0100 Subject: [PATCH 44/45] Refactor error variable in smtp test Renamed 'err' to 'berr' to avoid shadowing outer variable. This change ensures clearer error handling and avoids potential issues with variable scope and readability in the smtp_test.go file. --- smtp/smtp_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index 4d6316a..15ea7c5 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -3665,8 +3665,8 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server } time.Sleep(time.Millisecond) if props.EchoBuffer != nil { - if _, err := props.EchoBuffer.Write([]byte(data)); err != nil { - t.Errorf("failed write to echo buffer: %s", err) + if _, berr := props.EchoBuffer.Write([]byte(data)); berr != nil { + t.Errorf("failed write to echo buffer: %s", berr) } } From 800c266ccba97415266df0ccc0cbdf0e71cabcc3 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Mon, 11 Nov 2024 19:53:14 +0100 Subject: [PATCH 45/45] Refactor test server code for thread safety Moved serverProps outside goroutines to improve code readability and maintainability. Added a RWMutex to serverProps to ensure thread-safe access to EchoBuffer, preventing race conditions during concurrent writes. --- smtp/smtp_test.go | 160 ++++++++++++++++++++++++++-------------------- 1 file changed, 92 insertions(+), 68 deletions(-) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index 15ea7c5..931b74e 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -31,6 +31,7 @@ import ( "net" "os" "strings" + "sync" "sync/atomic" "testing" "time" @@ -2235,17 +2236,17 @@ func TestClient_Mail(t *testing.T) { serverPort := int(TestServerPortBase + PortAdder.Load()) featureSet := "250-8BITMIME\r\n250 STARTTLS" echoBuffer := bytes.NewBuffer(nil) - go func(buf *bytes.Buffer) { - if err := simpleSMTPServer(ctx, t, &serverProps{ - EchoBuffer: buf, - FeatureSet: featureSet, - ListenPort: serverPort, - }, - ); err != nil { + props := &serverProps{ + EchoBuffer: echoBuffer, + FeatureSet: featureSet, + ListenPort: serverPort, + } + go func() { + if err := simpleSMTPServer(ctx, t, props); err != nil { t.Errorf("failed to start test server: %s", err) return } - }(echoBuffer) + }() time.Sleep(time.Millisecond * 30) client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) @@ -2261,7 +2262,9 @@ func TestClient_Mail(t *testing.T) { t.Errorf("failed to set mail from address: %s", err) } expected := "MAIL FROM: BODY=8BITMIME" + props.BufferMutex.RLock() resp := strings.Split(echoBuffer.String(), "\r\n") + props.BufferMutex.RUnlock() if !strings.EqualFold(resp[5], expected) { t.Errorf("expected mail from command to be %q, but sent %q", expected, resp[5]) } @@ -2273,17 +2276,17 @@ func TestClient_Mail(t *testing.T) { serverPort := int(TestServerPortBase + PortAdder.Load()) featureSet := "250-SMTPUTF8\r\n250 STARTTLS" echoBuffer := bytes.NewBuffer(nil) - go func(buf *bytes.Buffer) { - if err := simpleSMTPServer(ctx, t, &serverProps{ - EchoBuffer: buf, - FeatureSet: featureSet, - ListenPort: serverPort, - }, - ); err != nil { + props := &serverProps{ + EchoBuffer: echoBuffer, + FeatureSet: featureSet, + ListenPort: serverPort, + } + go func() { + if err := simpleSMTPServer(ctx, t, props); err != nil { t.Errorf("failed to start test server: %s", err) return } - }(echoBuffer) + }() time.Sleep(time.Millisecond * 30) client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) @@ -2299,7 +2302,9 @@ func TestClient_Mail(t *testing.T) { t.Errorf("failed to set mail from address: %s", err) } expected := "MAIL FROM: SMTPUTF8" + props.BufferMutex.RLock() resp := strings.Split(echoBuffer.String(), "\r\n") + props.BufferMutex.RUnlock() if !strings.EqualFold(resp[5], expected) { t.Errorf("expected mail from command to be %q, but sent %q", expected, resp[5]) } @@ -2311,17 +2316,17 @@ func TestClient_Mail(t *testing.T) { serverPort := int(TestServerPortBase + PortAdder.Load()) featureSet := "250-SMTPUTF8\r\n250 STARTTLS" echoBuffer := bytes.NewBuffer(nil) - go func(buf *bytes.Buffer) { - if err := simpleSMTPServer(ctx, t, &serverProps{ - EchoBuffer: buf, - FeatureSet: featureSet, - ListenPort: serverPort, - }, - ); err != nil { + props := &serverProps{ + EchoBuffer: echoBuffer, + FeatureSet: featureSet, + ListenPort: serverPort, + } + go func() { + if err := simpleSMTPServer(ctx, t, props); err != nil { t.Errorf("failed to start test server: %s", err) return } - }(echoBuffer) + }() time.Sleep(time.Millisecond * 30) client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) @@ -2337,7 +2342,9 @@ func TestClient_Mail(t *testing.T) { t.Errorf("failed to set mail from address: %s", err) } expected := "MAIL FROM: SMTPUTF8" + props.BufferMutex.RLock() resp := strings.Split(echoBuffer.String(), "\r\n") + props.BufferMutex.RUnlock() if !strings.EqualFold(resp[5], expected) { t.Errorf("expected mail from command to be %q, but sent %q", expected, resp[5]) } @@ -2349,17 +2356,17 @@ func TestClient_Mail(t *testing.T) { serverPort := int(TestServerPortBase + PortAdder.Load()) featureSet := "250-DSN\r\n250 STARTTLS" echoBuffer := bytes.NewBuffer(nil) - go func(buf *bytes.Buffer) { - if err := simpleSMTPServer(ctx, t, &serverProps{ - EchoBuffer: buf, - FeatureSet: featureSet, - ListenPort: serverPort, - }, - ); err != nil { + props := &serverProps{ + EchoBuffer: echoBuffer, + FeatureSet: featureSet, + ListenPort: serverPort, + } + go func() { + if err := simpleSMTPServer(ctx, t, props); err != nil { t.Errorf("failed to start test server: %s", err) return } - }(echoBuffer) + }() time.Sleep(time.Millisecond * 30) client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) @@ -2376,7 +2383,9 @@ func TestClient_Mail(t *testing.T) { t.Errorf("failed to set mail from address: %s", err) } expected := "MAIL FROM: RET=FULL" + props.BufferMutex.RLock() resp := strings.Split(echoBuffer.String(), "\r\n") + props.BufferMutex.RUnlock() if !strings.EqualFold(resp[5], expected) { t.Errorf("expected mail from command to be %q, but sent %q", expected, resp[5]) } @@ -2388,17 +2397,17 @@ func TestClient_Mail(t *testing.T) { serverPort := int(TestServerPortBase + PortAdder.Load()) featureSet := "250-DSN\r\n250-8BITMIME\r\n250-SMTPUTF8\r\n250 STARTTLS" echoBuffer := bytes.NewBuffer(nil) - go func(buf *bytes.Buffer) { - if err := simpleSMTPServer(ctx, t, &serverProps{ - EchoBuffer: buf, - FeatureSet: featureSet, - ListenPort: serverPort, - }, - ); err != nil { + props := &serverProps{ + EchoBuffer: echoBuffer, + FeatureSet: featureSet, + ListenPort: serverPort, + } + go func() { + if err := simpleSMTPServer(ctx, t, props); err != nil { t.Errorf("failed to start test server: %s", err) return } - }(echoBuffer) + }() time.Sleep(time.Millisecond * 30) client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) @@ -2415,7 +2424,9 @@ func TestClient_Mail(t *testing.T) { t.Errorf("failed to set mail from address: %s", err) } expected := "MAIL FROM: BODY=8BITMIME SMTPUTF8 RET=FULL" + props.BufferMutex.RLock() resp := strings.Split(echoBuffer.String(), "\r\n") + props.BufferMutex.RUnlock() if !strings.EqualFold(resp[7], expected) { t.Errorf("expected mail from command to be %q, but sent %q", expected, resp[7]) } @@ -2490,17 +2501,17 @@ func TestClient_Rcpt(t *testing.T) { serverPort := int(TestServerPortBase + PortAdder.Load()) featureSet := "250-DSN\r\n250 STARTTLS" echoBuffer := bytes.NewBuffer(nil) - go func(buf *bytes.Buffer) { - if err := simpleSMTPServer(ctx, t, &serverProps{ - EchoBuffer: buf, - FeatureSet: featureSet, - ListenPort: serverPort, - }, - ); err != nil { + props := &serverProps{ + EchoBuffer: echoBuffer, + FeatureSet: featureSet, + ListenPort: serverPort, + } + go func() { + if err := simpleSMTPServer(ctx, t, props); err != nil { t.Errorf("failed to start test server: %s", err) return } - }(echoBuffer) + }() time.Sleep(time.Millisecond * 30) client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) if err != nil { @@ -2519,7 +2530,9 @@ func TestClient_Rcpt(t *testing.T) { t.Error("recpient address with newlines should fail") } expected := "RCPT TO: NOTIFY=SUCCESS" + props.BufferMutex.RLock() resp := strings.Split(echoBuffer.String(), "\r\n") + props.BufferMutex.RUnlock() if !strings.EqualFold(resp[5], expected) { t.Errorf("expected rcpt to command to be %q, but sent %q", expected, resp[5]) } @@ -2782,17 +2795,17 @@ func TestSendMail(t *testing.T) { serverPort := int(TestServerPortBase + PortAdder.Load()) featureSet := "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS" echoBuffer := bytes.NewBuffer(nil) - go func(buf *bytes.Buffer) { - if err := simpleSMTPServer(ctx, t, &serverProps{ - EchoBuffer: buf, - FeatureSet: featureSet, - ListenPort: serverPort, - }, - ); err != nil { + props := &serverProps{ + EchoBuffer: echoBuffer, + FeatureSet: featureSet, + ListenPort: serverPort, + } + go func() { + if err := simpleSMTPServer(ctx, t, props); err != nil { t.Errorf("failed to start test server: %s", err) return } - }(echoBuffer) + }() time.Sleep(time.Millisecond * 30) addr := fmt.Sprintf("%s:%d", TestServerAddr, serverPort) testHookStartTLS = func(config *tls.Config) { @@ -2806,7 +2819,9 @@ func TestSendMail(t *testing.T) { []byte("test message")); err != nil { t.Fatalf("failed to send mail: %s", err) } + props.BufferMutex.RLock() resp := strings.Split(echoBuffer.String(), "\r\n") + props.BufferMutex.RUnlock() if len(resp)-1 != len(want) { t.Fatalf("expected %d lines, but got %d", len(want), len(resp)) } @@ -2857,17 +2872,17 @@ func TestSendMail(t *testing.T) { serverPort := int(TestServerPortBase + PortAdder.Load()) featureSet := "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS" echoBuffer := bytes.NewBuffer(nil) - go func(buf *bytes.Buffer) { - if err := simpleSMTPServer(ctx, t, &serverProps{ - EchoBuffer: buf, - FeatureSet: featureSet, - ListenPort: serverPort, - }, - ); err != nil { + props := &serverProps{ + EchoBuffer: echoBuffer, + FeatureSet: featureSet, + ListenPort: serverPort, + } + go func() { + if err := simpleSMTPServer(ctx, t, props); err != nil { t.Errorf("failed to start test server: %s", err) return } - }(echoBuffer) + }() time.Sleep(time.Millisecond * 30) addr := fmt.Sprintf("%s:%d", TestServerAddr, serverPort) testHookStartTLS = func(config *tls.Config) { @@ -2887,7 +2902,9 @@ Goodbye.`) if err := SendMail(addr, auth, "valid-from@domain.tld", []string{"valid-to@domain.tld"}, message); err != nil { t.Fatalf("failed to send mail: %s", err) } + props.BufferMutex.RLock() resp := strings.Split(echoBuffer.String(), "\r\n") + props.BufferMutex.RUnlock() if len(resp)-1 != len(want) { t.Errorf("expected %d lines, but got %d", len(want), len(resp)) } @@ -3542,6 +3559,7 @@ func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", " // serverProps represents the configuration properties for the SMTP server. type serverProps struct { + BufferMutex sync.RWMutex EchoBuffer io.Writer FailOnAuth bool FailOnDataInit bool @@ -3640,9 +3658,11 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server t.Logf("failed to write line: %s", err) } if props.EchoBuffer != nil { - if _, err := props.EchoBuffer.Write([]byte(data + "\r\n")); err != nil { - t.Errorf("failed write to echo buffer: %s", err) + props.BufferMutex.Lock() + if _, berr := props.EchoBuffer.Write([]byte(data + "\r\n")); berr != nil { + t.Errorf("failed write to echo buffer: %s", berr) } + props.BufferMutex.Unlock() } _ = writer.Flush() } @@ -3665,9 +3685,11 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server } time.Sleep(time.Millisecond) if props.EchoBuffer != nil { + props.BufferMutex.Lock() if _, berr := props.EchoBuffer.Write([]byte(data)); berr != nil { t.Errorf("failed write to echo buffer: %s", berr) } + props.BufferMutex.Unlock() } var datastring string @@ -3768,9 +3790,11 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server break } if props.EchoBuffer != nil { - if _, err := props.EchoBuffer.Write([]byte(ddata)); err != nil { - t.Errorf("failed write to echo buffer: %s", err) + props.BufferMutex.Lock() + if _, berr := props.EchoBuffer.Write([]byte(ddata)); berr != nil { + t.Errorf("failed write to echo buffer: %s", berr) } + props.BufferMutex.Unlock() } ddata = strings.TrimSpace(ddata) if ddata == "." {