diff --git a/client_test.go b/client_test.go index 8bad1e0..afc5743 100644 --- a/client_test.go +++ b/client_test.go @@ -9,6 +9,7 @@ import ( "crypto/tls" "errors" "fmt" + "io" "net" "os" "strconv" @@ -1269,3 +1270,114 @@ func getTestConnectionWithDSN(auth bool) (*Client, error) { } return c, nil } + +func TestXOAuth2OK(t *testing.T) { + server := []string{ + "220 Fake server ready ESMTP", + "250-fake.server", + "250-AUTH LOGIN XOAUTH2", + "250 8BITMIME", + "250 OK", + "235 2.7.0 Accepted", + "250 OK", + "221 OK", + } + var wrote strings.Builder + var fake faker + fake.ReadWriter = struct { + io.Reader + io.Writer + }{ + strings.NewReader(strings.Join(server, "\r\n")), + &wrote, + } + c, err := NewClient("fake.host", + WithDialContextFunc(getFakeDialFunc(fake)), + WithTLSPolicy(TLSOpportunistic), + WithSMTPAuth(SMTPAuthXOAUTH2), + WithUsername("user"), + WithPassword("token")) + if err != nil { + t.Fatalf("unable to create new client: %v", err) + } + if err := c.DialWithContext(context.Background()); err != nil { + t.Fatalf("unexpected dial error: %v", err) + } + if err := c.Close(); err != nil { + t.Fatalf("disconnect from test server failed: %v", err) + } + if !strings.Contains(wrote.String(), "AUTH XOAUTH2 dXNlcj11c2VyAWF1dGg9QmVhcmVyIHRva2VuAQE=\r\n") { + t.Fatalf("got %q; want AUTH XOAUTH2 dXNlcj11c2VyAWF1dGg9QmVhcmVyIHRva2VuAQE=\r\n", wrote.String()) + } +} + +func TestXOAuth2Unsupported(t *testing.T) { + server := []string{ + "220 Fake server ready ESMTP", + "250-fake.server", + "250-AUTH LOGIN PLAIN", + "250 8BITMIME", + "250 OK", + "250 OK", + "221 OK", + } + var wrote strings.Builder + var fake faker + fake.ReadWriter = struct { + io.Reader + io.Writer + }{ + strings.NewReader(strings.Join(server, "\r\n")), + &wrote, + } + c, err := NewClient("fake.host", + WithDialContextFunc(getFakeDialFunc(fake)), + WithTLSPolicy(TLSOpportunistic), + WithSMTPAuth(SMTPAuthXOAUTH2)) + if err != nil { + t.Fatalf("unable to create new client: %v", err) + } + if err := c.DialWithContext(context.Background()); err == nil { + t.Fatal("expected dial error got nil") + } else { + if !errors.Is(err, ErrXOauth2AuthNotSupported) { + t.Fatalf("expected %v; got %v", ErrXOauth2AuthNotSupported, err) + } + } + if err := c.Close(); err != nil { + t.Fatalf("disconnect from test server failed: %v", err) + } + client := strings.Split(wrote.String(), "\r\n") + if len(client) != 5 { + t.Fatalf("unexpected number of client requests got %d; want 5", len(client)) + } + if !strings.HasPrefix(client[0], "EHLO") { + t.Fatalf("expected EHLO, got %q", client[0]) + } + if client[1] != "NOOP" { + t.Fatalf("expected NOOP, got %q", client[1]) + } + if client[2] != "NOOP" { + t.Fatalf("expected NOOP, got %q", client[2]) + } + if client[3] != "QUIT" { + t.Fatalf("expected QUIT, got %q", client[3]) + } +} + +func getFakeDialFunc(conn net.Conn) DialContextFunc { + return func(ctx context.Context, network, address string) (net.Conn, error) { + return conn, nil + } +} + +type faker struct { + io.ReadWriter +} + +func (f faker) Close() error { return nil } +func (f faker) LocalAddr() net.Addr { return nil } +func (f faker) RemoteAddr() net.Addr { return nil } +func (f faker) SetDeadline(time.Time) error { return nil } +func (f faker) SetReadDeadline(time.Time) error { return nil } +func (f faker) SetWriteDeadline(time.Time) error { return nil } diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index e09cdf3..62d0f0b 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -200,7 +200,7 @@ func TestAuthLogin(t *testing.T) { } } -func TestXOAuthOK(t *testing.T) { +func TestXOAuth2OK(t *testing.T) { server := []string{ "220 Fake server ready ESMTP", "250-fake.server",