diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index 0d47760..451a349 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -16,11 +16,15 @@ package smtp import ( "bufio" "bytes" + "crypto/hmac" + "crypto/sha1" + "crypto/sha256" "crypto/tls" "crypto/x509" "encoding/base64" "flag" "fmt" + "hash" "io" "net" "net/textproto" @@ -30,6 +34,8 @@ import ( "testing" "time" + "golang.org/x/crypto/pbkdf2" + "github.com/wneessen/go-mail/log" ) @@ -374,7 +380,7 @@ func TestAuthSCRAMSHA1_OK(t *testing.T) { port := "2585" go func() { - startSMTPServer(false, hostname, port) + startSMTPServer(false, hostname, port, sha1.New) }() time.Sleep(time.Millisecond * 500) @@ -399,7 +405,7 @@ func TestAuthSCRAMSHA256_OK(t *testing.T) { port := "2586" go func() { - startSMTPServer(false, hostname, port) + startSMTPServer(false, hostname, port, sha256.New) }() time.Sleep(time.Millisecond * 500) @@ -419,12 +425,80 @@ func TestAuthSCRAMSHA256_OK(t *testing.T) { } } +func TestAuthSCRAMSHA1PLUS_OK(t *testing.T) { + hostname := "127.0.0.1" + port := "2590" + + go func() { + startSMTPServer(true, hostname, port, sha1.New) + }() + time.Sleep(time.Millisecond * 500) + + cert, err := tls.X509KeyPair(localhostCert, localhostKey) + if err != nil { + fmt.Printf("error creating TLS cert: %s", err) + return + } + tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true} + + conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port), &tlsConfig) + if err != nil { + t.Errorf("failed to dial server: %v", err) + } + client, err := NewClient(conn, hostname) + if err != nil { + t.Errorf("failed to create client: %v", err) + } + if err = client.Hello(hostname); err != nil { + t.Errorf("failed to send HELO: %v", err) + } + + tlsConnState := conn.ConnectionState() + if err = client.Auth(ScramSHA1PlusAuth("username", "password", &tlsConnState)); err != nil { + t.Errorf("failed to authenticate: %v", err) + } +} + +func TestAuthSCRAMSHA256PLUS_OK(t *testing.T) { + hostname := "127.0.0.1" + port := "2591" + + go func() { + startSMTPServer(true, hostname, port, sha256.New) + }() + time.Sleep(time.Millisecond * 500) + + cert, err := tls.X509KeyPair(localhostCert, localhostKey) + if err != nil { + fmt.Printf("error creating TLS cert: %s", err) + return + } + tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true} + + conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port), &tlsConfig) + if err != nil { + t.Errorf("failed to dial server: %v", err) + } + client, err := NewClient(conn, hostname) + if err != nil { + t.Errorf("failed to create client: %v", err) + } + if err = client.Hello(hostname); err != nil { + t.Errorf("failed to send HELO: %v", err) + } + + tlsConnState := conn.ConnectionState() + if err = client.Auth(ScramSHA256PlusAuth("username", "password", &tlsConnState)); err != nil { + t.Errorf("failed to authenticate: %v", err) + } +} + func TestAuthSCRAMSHA1_fail(t *testing.T) { hostname := "127.0.0.1" port := "2587" go func() { - startSMTPServer(true, hostname, port) + startSMTPServer(false, hostname, port, sha1.New) }() time.Sleep(time.Millisecond * 500) @@ -439,7 +513,7 @@ func TestAuthSCRAMSHA1_fail(t *testing.T) { if err = client.Hello(hostname); err != nil { t.Errorf("failed to send HELO: %v", err) } - if err = client.Auth(ScramSHA1Auth("username", "password")); err == nil { + if err = client.Auth(ScramSHA1Auth("username", "invalid")); err == nil { t.Errorf("expected auth error, got nil") } } @@ -449,7 +523,7 @@ func TestAuthSCRAMSHA256_fail(t *testing.T) { port := "2588" go func() { - startSMTPServer(true, hostname, port) + startSMTPServer(false, hostname, port, sha256.New) }() time.Sleep(time.Millisecond * 500) @@ -464,7 +538,73 @@ func TestAuthSCRAMSHA256_fail(t *testing.T) { if err = client.Hello(hostname); err != nil { t.Errorf("failed to send HELO: %v", err) } - if err = client.Auth(ScramSHA256Auth("username", "password")); err == nil { + if err = client.Auth(ScramSHA256Auth("username", "invalid")); err == nil { + t.Errorf("expected auth error, got nil") + } +} + +func TestAuthSCRAMSHA1PLUS_fail(t *testing.T) { + hostname := "127.0.0.1" + port := "2592" + + go func() { + startSMTPServer(true, hostname, port, sha1.New) + }() + time.Sleep(time.Millisecond * 500) + + cert, err := tls.X509KeyPair(localhostCert, localhostKey) + if err != nil { + fmt.Printf("error creating TLS cert: %s", err) + return + } + tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true} + + conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port), &tlsConfig) + if err != nil { + t.Errorf("failed to dial server: %v", err) + } + client, err := NewClient(conn, hostname) + if err != nil { + t.Errorf("failed to create client: %v", err) + } + if err = client.Hello(hostname); err != nil { + t.Errorf("failed to send HELO: %v", err) + } + tlsConnState := conn.ConnectionState() + if err = client.Auth(ScramSHA1PlusAuth("username", "invalid", &tlsConnState)); err == nil { + t.Errorf("expected auth error, got nil") + } +} + +func TestAuthSCRAMSHA256PLUS_fail(t *testing.T) { + hostname := "127.0.0.1" + port := "2593" + + go func() { + startSMTPServer(true, hostname, port, sha1.New) + }() + time.Sleep(time.Millisecond * 500) + + cert, err := tls.X509KeyPair(localhostCert, localhostKey) + if err != nil { + fmt.Printf("error creating TLS cert: %s", err) + return + } + tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true} + + conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port), &tlsConfig) + if err != nil { + t.Errorf("failed to dial server: %v", err) + } + client, err := NewClient(conn, hostname) + if err != nil { + t.Errorf("failed to create client: %v", err) + } + if err = client.Hello(hostname); err != nil { + t.Errorf("failed to send HELO: %v", err) + } + tlsConnState := conn.ConnectionState() + if err = client.Auth(ScramSHA256PlusAuth("username", "invalid", &tlsConnState)); err == nil { t.Errorf("expected auth error, got nil") } } @@ -1652,7 +1792,8 @@ type testSCRAMSMTPServer struct { nonce string hostname string port string - shouldFail bool + tlsServer bool + h func() hash.Hash } func (s *testSCRAMSMTPServer) handleConnection(conn net.Conn) { @@ -1705,7 +1846,8 @@ func (s *testSCRAMSMTPServer) handleConnection(conn net.Conn) { } authMechanism := parts[1] - if authMechanism != "SCRAM-SHA-1" && authMechanism != "SCRAM-SHA-256" { + if authMechanism != "SCRAM-SHA-1" && authMechanism != "SCRAM-SHA-256" && + authMechanism != "SCRAM-SHA-1-PLUS" && authMechanism != "SCRAM-SHA-256-PLUS" { _ = writeLine("504 Unrecognized authentication mechanism") return } @@ -1729,6 +1871,7 @@ func (s *testSCRAMSMTPServer) handleSCRAMAuth(conn net.Conn) { } return writer.Flush() } + var authMsg string data, err := reader.ReadString('\n') if err != nil { @@ -1746,10 +1889,14 @@ func (s *testSCRAMSMTPServer) handleSCRAMAuth(conn net.Conn) { _ = writeLine("535 Authentication failed - expected 4 parts") return } - if splits[0] != "n" { + if !s.tlsServer && splits[0] != "n" { _ = writeLine("535 Authentication failed - expected n to be in the first part") return } + if s.tlsServer && !strings.HasPrefix(splits[0], "p=") { + _ = writeLine("535 Authentication failed - expected p= to be in the first part") + return + } if splits[2] != "n=username" { _ = writeLine("535 Authentication failed - expected n=username to be in the third part") return @@ -1758,6 +1905,8 @@ func (s *testSCRAMSMTPServer) handleSCRAMAuth(conn net.Conn) { _ = writeLine("535 Authentication failed - expected r= to be in the fourth part") return } + authMsg = splits[2] + "," + splits[3] + clientNonce := s.extractNonce(string(decodedMessage)) if clientNonce == "" { _ = writeLine("535 Authentication failed") @@ -1765,8 +1914,10 @@ func (s *testSCRAMSMTPServer) handleSCRAMAuth(conn net.Conn) { } s.nonce = clientNonce + "server_nonce" - serverFirstMessage := fmt.Sprintf("r=%s,s=%s,i=0", s.nonce, "salt") + serverFirstMessage := fmt.Sprintf("r=%s,s=%s,i=4096", s.nonce, + base64.StdEncoding.EncodeToString([]byte("salt"))) _ = writeLine(fmt.Sprintf("334 %s", base64.StdEncoding.EncodeToString([]byte(serverFirstMessage)))) + authMsg = authMsg + "," + serverFirstMessage data, err = reader.ReadString('\n') if err != nil { @@ -1780,10 +1931,32 @@ func (s *testSCRAMSMTPServer) handleSCRAMAuth(conn net.Conn) { return } splits = strings.Split(string(decodedFinalMessage), ",") - if splits[0] != "c=biws" { + + if !s.tlsServer && splits[0] != "c=biws" { _ = writeLine("535 Authentication failed - expected c=biws to be in the first part") return } + if s.tlsServer { + if !strings.HasPrefix(splits[0], "c=") { + _ = writeLine("535 Authentication failed - expected c= to be in the first part") + return + } + channelBind, err := base64.StdEncoding.DecodeString(splits[0][2:]) + if err != nil { + _ = writeLine("535 Authentication failed - base64 channel bind is not valid - " + err.Error()) + return + } + if !strings.HasPrefix(string(channelBind), "p=") { + _ = writeLine("535 Authentication failed - expected channel binding to start with p=-") + return + } + cbType := string(channelBind[2:]) + if !strings.HasPrefix(cbType, "tls-unique") && !strings.HasPrefix(cbType, "tls-exporter") { + _ = writeLine("535 Authentication failed - expected channel binding type tls-unique or tls-exporter") + return + } + } + if !strings.HasPrefix(splits[1], "r=") { _ = writeLine("535 Authentication failed - expected r to be in the second part") return @@ -1797,10 +1970,27 @@ func (s *testSCRAMSMTPServer) handleSCRAMAuth(conn net.Conn) { return } - if s.shouldFail { + authMsg = authMsg + "," + splits[0] + "," + splits[1] + saltedPwd := pbkdf2.Key([]byte("password"), []byte("salt"), 4096, s.h().Size(), s.h) + mac := hmac.New(s.h, saltedPwd) + mac.Write([]byte("Server Key")) + skey := mac.Sum(nil) + mac.Reset() + + mac = hmac.New(s.h, skey) + mac.Write([]byte(authMsg)) + ssig := mac.Sum(nil) + mac.Reset() + + serverFinalMessage := fmt.Sprintf("v=%s", base64.StdEncoding.EncodeToString(ssig)) + _ = writeLine(fmt.Sprintf("334 %s", base64.StdEncoding.EncodeToString([]byte(serverFinalMessage)))) + + _, err = reader.ReadString('\n') + if err != nil { _ = writeLine("535 Authentication failed") return } + _ = writeLine("235 Authentication successful") } @@ -1814,11 +2004,12 @@ func (s *testSCRAMSMTPServer) extractNonce(message string) string { return "" } -func startSMTPServer(shouldFail bool, hostname, port string) { +func startSMTPServer(tlsServer bool, hostname, port string, h func() hash.Hash) { server := &testSCRAMSMTPServer{ - hostname: hostname, - port: port, - shouldFail: shouldFail, + hostname: hostname, + port: port, + tlsServer: tlsServer, + h: h, } listener, err := net.Listen("tcp", fmt.Sprintf("%s:%s", hostname, port)) if err != nil { @@ -1827,12 +2018,23 @@ func startSMTPServer(shouldFail bool, hostname, port string) { defer func() { _ = listener.Close() }() + + cert, err := tls.X509KeyPair(localhostCert, localhostKey) + if err != nil { + fmt.Printf("error creating TLS cert: %s", err) + return + } + tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}} + for { conn, err := listener.Accept() if err != nil { fmt.Printf("Failed to accept connection: %v", err) continue } + if server.tlsServer { + conn = tls.Server(conn, &tlsConfig) + } go server.handleConnection(conn) } }