diff --git a/auth/auth_test.go b/auth/auth_test.go deleted file mode 100644 index 4df16ce..0000000 --- a/auth/auth_test.go +++ /dev/null @@ -1,104 +0,0 @@ -// SPDX-FileCopyrightText: 2022 Winni Neessen -// -// SPDX-License-Identifier: MIT - -package auth - -import ( - "bytes" - "net/smtp" - "testing" -) - -func TestAuth(t *testing.T) { - type authTest struct { - auth smtp.Auth - challenges []string - name string - responses []string - shouldfail []bool - } - - authTests := []authTest{ - { - LoginAuth("user", "pass", "testserver"), - []string{"Username:", "Password:", "Invalid:"}, - "LOGIN", - []string{"", "user", "pass", ""}, - []bool{false, false, true}, - }, - } - -testLoop: - for i, test := range authTests { - name, resp, err := test.auth.Start(&smtp.ServerInfo{Name: "testserver", TLS: true, Auth: nil}) - if name != test.name { - t.Errorf("#%d got name %s, expected %s", i, name, test.name) - } - if !bytes.Equal(resp, []byte(test.responses[0])) { - t.Errorf("#%d got response %s, expected %s", i, resp, test.responses[0]) - } - if err != nil { - t.Errorf("#%d error: %s", i, err) - } - for j := range test.challenges { - challenge := []byte(test.challenges[j]) - expected := []byte(test.responses[j+1]) - resp, err := test.auth.Next(challenge, true) - if err != nil && !test.shouldfail[j] { - t.Errorf("#%d error: %s", i, err) - continue testLoop - } - if !bytes.Equal(resp, expected) { - t.Errorf("#%d got %s, expected %s", i, resp, expected) - continue testLoop - } - } - } -} - -func TestAuthLogin(t *testing.T) { - tests := []struct { - authName string - server *smtp.ServerInfo - err string - }{ - { - authName: "servername", - server: &smtp.ServerInfo{Name: "servername", TLS: true}, - }, - { - // OK to use LoginAuth on localhost without TLS - authName: "localhost", - server: &smtp.ServerInfo{Name: "localhost", TLS: false}, - }, - { - // NOT OK on non-localhost, even if server says PLAIN is OK. - // (We don't know that the server is the real server.) - authName: "servername", - server: &smtp.ServerInfo{Name: "servername", Auth: []string{"PLAIN"}}, - err: "unencrypted connection", - }, - { - authName: "servername", - server: &smtp.ServerInfo{Name: "servername", Auth: []string{"CRAM-MD5"}}, - err: "unencrypted connection", - }, - { - authName: "servername", - server: &smtp.ServerInfo{Name: "attacker", TLS: true}, - err: "wrong host name", - }, - } - for i, tt := range tests { - auth := LoginAuth("foo", "bar", tt.authName) - _, _, err := auth.Start(tt.server) - got := "" - if err != nil { - got = err.Error() - } - if got != tt.err { - t.Errorf("%d. got error = %q; want %q", i, got, tt.err) - } - } -} diff --git a/client.go b/client.go index fd0e7f7..2bee9f2 100644 --- a/client.go +++ b/client.go @@ -10,12 +10,11 @@ import ( "errors" "fmt" "net" - "net/smtp" "os" "strings" "time" - "github.com/wneessen/go-mail/auth" + "github.com/wneessen/go-mail/smtp" ) // Defaults @@ -593,7 +592,7 @@ func (c *Client) auth() error { if !strings.Contains(sat, string(SMTPAuthLogin)) { return ErrLoginAuthNotSupported } - c.sa = auth.LoginAuth(c.user, c.pass, c.host) + c.sa = smtp.LoginAuth(c.user, c.pass, c.host) case SMTPAuthCramMD5: if !strings.Contains(sat, string(SMTPAuthCramMD5)) { return ErrCramMD5AuthNotSupported diff --git a/client_test.go b/client_test.go index b3bed52..c85a91c 100644 --- a/client_test.go +++ b/client_test.go @@ -9,14 +9,13 @@ import ( "crypto/tls" "errors" "fmt" - "net/smtp" "os" "strconv" "strings" "testing" "time" - "github.com/wneessen/go-mail/auth" + "github.com/wneessen/go-mail/smtp" ) // DefaultHost is used as default hostname for the Client @@ -496,7 +495,7 @@ func TestSetSMTPAuthCustom(t *testing.T) { }{ {"SMTPAuth: PLAIN", smtp.PlainAuth("", "", "", ""), "PLAIN", false}, {"SMTPAuth: CRAM-MD5", smtp.CRAMMD5Auth("", ""), "CRAM-MD5", false}, - {"SMTPAuth: LOGIN", auth.LoginAuth("", "", ""), "LOGIN", false}, + {"SMTPAuth: LOGIN", smtp.LoginAuth("", "", ""), "LOGIN", false}, } si := smtp.ServerInfo{TLS: true} for _, tt := range tests { @@ -584,7 +583,7 @@ func TestClient_DialWithContextInvalidAuth(t *testing.T) { } c.user = "invalid" c.pass = "invalid" - c.SetSMTPAuthCustom(auth.LoginAuth("invalid", "invalid", "invalid")) + c.SetSMTPAuthCustom(smtp.LoginAuth("invalid", "invalid", "invalid")) ctx := context.Background() if err := c.DialWithContext(ctx); err == nil { t.Errorf("dial succeeded but was supposed to fail") diff --git a/smtp/auth.go b/smtp/auth.go new file mode 100644 index 0000000..70cf7aa --- /dev/null +++ b/smtp/auth.go @@ -0,0 +1,31 @@ +package smtp + +// Auth is implemented by an SMTP authentication mechanism. +type Auth interface { + // Start begins an authentication with a server. + // It returns the name of the authentication protocol + // and optionally data to include in the initial AUTH message + // sent to the server. + // If it returns a non-nil error, the SMTP client aborts + // the authentication attempt and closes the connection. + Start(server *ServerInfo) (proto string, toServer []byte, err error) + + // Next continues the authentication. The server has just sent + // the fromServer data. If more is true, the server expects a + // response, which Next should return as toServer; otherwise + // Next should return toServer == nil. + // If Next returns a non-nil error, the SMTP client aborts + // the authentication attempt and closes the connection. + Next(fromServer []byte, more bool) (toServer []byte, err error) +} + +// ServerInfo records information about an SMTP server. +type ServerInfo struct { + Name string // SMTP server name + TLS bool // using TLS, with valid certificate for Name + Auth []string // advertised authentication mechanisms +} + +func isLocalhost(name string) bool { + return name == "localhost" || name == "127.0.0.1" || name == "::1" +} diff --git a/smtp/auth_cram_md5.go b/smtp/auth_cram_md5.go new file mode 100644 index 0000000..2f91289 --- /dev/null +++ b/smtp/auth_cram_md5.go @@ -0,0 +1,34 @@ +package smtp + +import ( + "crypto/hmac" + "crypto/md5" + "fmt" +) + +// cramMD5Auth is the type that satisfies the Auth interface for the "SMTP CRAM_MD5" auth +type cramMD5Auth struct { + username, secret string +} + +// CRAMMD5Auth returns an Auth that implements the CRAM-MD5 authentication +// mechanism as defined in RFC 2195. +// The returned Auth uses the given username and secret to authenticate +// to the server using the challenge-response mechanism. +func CRAMMD5Auth(username, secret string) Auth { + return &cramMD5Auth{username, secret} +} + +func (a *cramMD5Auth) Start(server *ServerInfo) (string, []byte, error) { + return "CRAM-MD5", nil, nil +} + +func (a *cramMD5Auth) Next(fromServer []byte, more bool) ([]byte, error) { + if more { + d := hmac.New(md5.New, []byte(a.secret)) + d.Write(fromServer) + s := make([]byte, 0, d.Size()) + return fmt.Appendf(nil, "%s %x", a.username, d.Sum(s)), nil + } + return nil, nil +} diff --git a/auth/login.go b/smtp/auth_login.go similarity index 88% rename from auth/login.go rename to smtp/auth_login.go index 51f9179..6f92bab 100644 --- a/auth/login.go +++ b/smtp/auth_login.go @@ -3,14 +3,14 @@ // SPDX-License-Identifier: MIT // Package auth implements the LOGIN and MD5-DIGEST smtp authentication mechanisms -package auth +package smtp import ( "errors" "fmt" - "net/smtp" ) +// loginAuth is the type that satisfies the Auth interface for the "SMTP LOGIN" auth type loginAuth struct { username, password string host string @@ -35,15 +35,11 @@ const ( // LoginAuth will only send the credentials if the connection is using TLS // or is connected to localhost. Otherwise authentication will fail with an // error, without sending the credentials. -func LoginAuth(username, password, host string) smtp.Auth { +func LoginAuth(username, password, host string) Auth { return &loginAuth{username, password, host} } -func isLocalhost(name string) bool { - return name == "localhost" || name == "127.0.0.1" || name == "::1" -} - -func (a *loginAuth) Start(server *smtp.ServerInfo) (string, []byte, error) { +func (a *loginAuth) Start(server *ServerInfo) (string, []byte, error) { // Must have TLS, or else localhost server. // Note: If TLS is not true, then we can't trust ANYTHING in ServerInfo. // In particular, it doesn't matter if the server advertises LOGIN auth. diff --git a/smtp/auth_plain.go b/smtp/auth_plain.go new file mode 100644 index 0000000..035699b --- /dev/null +++ b/smtp/auth_plain.go @@ -0,0 +1,47 @@ +package smtp + +import ( + "errors" +) + +// plainAuth is the type that satisfies the Auth interface for the "SMTP PLAIN" auth +type plainAuth struct { + identity, username, password string + host string +} + +// PlainAuth returns an Auth that implements the PLAIN authentication +// mechanism as defined in RFC 4616. The returned Auth uses the given +// username and password to authenticate to host and act as identity. +// Usually identity should be the empty string, to act as username. +// +// PlainAuth will only send the credentials if the connection is using TLS +// or is connected to localhost. Otherwise authentication will fail with an +// error, without sending the credentials. +func PlainAuth(identity, username, password, host string) Auth { + return &plainAuth{identity, username, password, host} +} + +func (a *plainAuth) Start(server *ServerInfo) (string, []byte, error) { + // Must have TLS, or else localhost server. + // Note: If TLS is not true, then we can't trust ANYTHING in ServerInfo. + // In particular, it doesn't matter if the server advertises PLAIN auth. + // That might just be the attacker saying + // "it's ok, you can trust me with your password." + if !server.TLS && !isLocalhost(server.Name) { + return "", nil, errors.New("unencrypted connection") + } + if server.Name != a.host { + return "", nil, errors.New("wrong host name") + } + resp := []byte(a.identity + "\x00" + a.username + "\x00" + a.password) + return "PLAIN", resp, nil +} + +func (a *plainAuth) Next(fromServer []byte, more bool) ([]byte, error) { + if more { + // We've already sent everything. + return nil, errors.New("unexpected server challenge") + } + return nil, nil +} diff --git a/smtp/example_test.go b/smtp/example_test.go new file mode 100644 index 0000000..452f085 --- /dev/null +++ b/smtp/example_test.go @@ -0,0 +1,84 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package smtp_test + +import ( + "fmt" + "log" + + "github.com/wneessen/go-mail/smtp" +) + +func Example() { + // Connect to the remote SMTP server. + c, err := smtp.Dial("mail.example.com:25") + if err != nil { + log.Fatal(err) + } + + // Set the sender and recipient first + if err := c.Mail("sender@example.org"); err != nil { + log.Fatal(err) + } + if err := c.Rcpt("recipient@example.net"); err != nil { + log.Fatal(err) + } + + // Send the email body. + wc, err := c.Data() + if err != nil { + log.Fatal(err) + } + _, err = fmt.Fprintf(wc, "This is the email body") + if err != nil { + log.Fatal(err) + } + err = wc.Close() + if err != nil { + log.Fatal(err) + } + + // Send the QUIT command and close the connection. + err = c.Quit() + if err != nil { + log.Fatal(err) + } +} + +// variables to make ExamplePlainAuth compile, without adding +// unnecessary noise there. +var ( + from = "gopher@example.net" + msg = []byte("dummy message") + recipients = []string{"foo@example.com"} +) + +func ExamplePlainAuth() { + // hostname is used by PlainAuth to validate the TLS certificate. + hostname := "mail.example.com" + auth := smtp.PlainAuth("", "user@example.com", "password", hostname) + + err := smtp.SendMail(hostname+":25", auth, from, recipients, msg) + if err != nil { + log.Fatal(err) + } +} + +func ExampleSendMail() { + // Set up authentication information. + auth := smtp.PlainAuth("", "user@example.com", "password", "mail.example.com") + + // Connect to the server, authenticate, set the sender and recipient, + // and send the email all in one step. + to := []string{"recipient@example.net"} + msg := []byte("To: recipient@example.net\r\n" + + "Subject: discount Gophers!\r\n" + + "\r\n" + + "This is the email body.\r\n") + err := smtp.SendMail("mail.example.com:25", auth, "sender@example.org", to, msg) + if err != nil { + log.Fatal(err) + } +} diff --git a/smtp/smtp.go b/smtp/smtp.go new file mode 100644 index 0000000..1fef69c --- /dev/null +++ b/smtp/smtp.go @@ -0,0 +1,438 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package smtp implements the Simple Mail Transfer Protocol as defined in RFC 5321. +// It also implements the following extensions: +// +// 8BITMIME RFC 1652 +// AUTH RFC 2554 +// STARTTLS RFC 3207 +// +// Additional extensions may be handled by clients. +// +// The smtp package is frozen and is not accepting new features. +// Some external packages provide more functionality. See: +// +// https://godoc.org/?q=smtp +package smtp + +import ( + "crypto/tls" + "encoding/base64" + "errors" + "fmt" + "io" + "net" + "net/textproto" + "strings" +) + +// A Client represents a client connection to an SMTP server. +type Client struct { + // Text is the textproto.Conn used by the Client. It is exported to allow for + // clients to add extensions. + Text *textproto.Conn + // keep a reference to the connection so it can be used to create a TLS + // connection later + conn net.Conn + // whether the Client is using TLS + tls bool + serverName string + // map of supported extensions + ext map[string]string + // supported auth mechanisms + auth []string + localName string // the name to use in HELO/EHLO + didHello bool // whether we've said HELO/EHLO + helloError error // the error from the hello +} + +// Dial returns a new Client connected to an SMTP server at addr. +// The addr must include a port, as in "mail.example.com:smtp". +func Dial(addr string) (*Client, error) { + conn, err := net.Dial("tcp", addr) + if err != nil { + return nil, err + } + host, _, _ := net.SplitHostPort(addr) + return NewClient(conn, host) +} + +// NewClient returns a new Client using an existing connection and host as a +// server name to be used when authenticating. +func NewClient(conn net.Conn, host string) (*Client, error) { + text := textproto.NewConn(conn) + _, _, err := text.ReadResponse(220) + if err != nil { + if cerr := text.Close(); cerr != nil { + return nil, fmt.Errorf("%w, %s", err, cerr) + } + return nil, err + } + c := &Client{Text: text, conn: conn, serverName: host, localName: "localhost"} + _, c.tls = conn.(*tls.Conn) + return c, nil +} + +// Close closes the connection. +func (c *Client) Close() error { + return c.Text.Close() +} + +// hello runs a hello exchange if needed. +func (c *Client) hello() error { + if !c.didHello { + c.didHello = true + err := c.ehlo() + if err != nil { + c.helloError = c.helo() + } + } + return c.helloError +} + +// Hello sends a HELO or EHLO to the server as the given host name. +// Calling this method is only necessary if the client needs control +// over the host name used. The client will introduce itself as "localhost" +// automatically otherwise. If Hello is called, it must be called before +// any of the other methods. +func (c *Client) Hello(localName string) error { + if err := validateLine(localName); err != nil { + return err + } + if c.didHello { + return errors.New("smtp: Hello called after other methods") + } + c.localName = localName + return c.hello() +} + +// cmd is a convenience function that sends a command and returns the response +func (c *Client) cmd(expectCode int, format string, args ...interface{}) (int, string, error) { + id, err := c.Text.Cmd(format, args...) + if err != nil { + return 0, "", err + } + c.Text.StartResponse(id) + defer c.Text.EndResponse(id) + code, msg, err := c.Text.ReadResponse(expectCode) + return code, msg, err +} + +// helo sends the HELO greeting to the server. It should be used only when the +// server does not support ehlo. +func (c *Client) helo() error { + c.ext = nil + _, _, err := c.cmd(250, "HELO %s", c.localName) + return err +} + +// ehlo sends the EHLO (extended hello) greeting to the server. It +// should be the preferred greeting for servers that support it. +func (c *Client) ehlo() error { + _, msg, err := c.cmd(250, "EHLO %s", c.localName) + if err != nil { + return err + } + ext := make(map[string]string) + extList := strings.Split(msg, "\n") + if len(extList) > 1 { + extList = extList[1:] + for _, line := range extList { + k, v, _ := strings.Cut(line, " ") + ext[k] = v + } + } + if mechs, ok := ext["AUTH"]; ok { + c.auth = strings.Split(mechs, " ") + } + c.ext = ext + return err +} + +// StartTLS sends the STARTTLS command and encrypts all further communication. +// Only servers that advertise the STARTTLS extension support this function. +func (c *Client) StartTLS(config *tls.Config) error { + if err := c.hello(); err != nil { + return err + } + _, _, err := c.cmd(220, "STARTTLS") + if err != nil { + return err + } + c.conn = tls.Client(c.conn, config) + c.Text = textproto.NewConn(c.conn) + c.tls = true + return c.ehlo() +} + +// TLSConnectionState returns the client's TLS connection state. +// The return values are their zero values if StartTLS did +// not succeed. +func (c *Client) TLSConnectionState() (state tls.ConnectionState, ok bool) { + tc, ok := c.conn.(*tls.Conn) + if !ok { + return + } + return tc.ConnectionState(), true +} + +// Verify checks the validity of an email address on the server. +// If Verify returns nil, the address is valid. A non-nil return +// does not necessarily indicate an invalid address. Many servers +// will not verify addresses for security reasons. +func (c *Client) Verify(addr string) error { + if err := validateLine(addr); err != nil { + return err + } + if err := c.hello(); err != nil { + return err + } + _, _, err := c.cmd(250, "VRFY %s", addr) + return err +} + +// Auth authenticates a client using the provided authentication mechanism. +// A failed authentication closes the connection. +// Only servers that advertise the AUTH extension support this function. +func (c *Client) Auth(a Auth) error { + if err := c.hello(); err != nil { + return err + } + encoding := base64.StdEncoding + mech, resp, err := a.Start(&ServerInfo{c.serverName, c.tls, c.auth}) + if err != nil { + if qerr := c.Quit(); qerr != nil { + return fmt.Errorf("%w, %s", err, qerr) + } + return err + } + resp64 := make([]byte, encoding.EncodedLen(len(resp))) + encoding.Encode(resp64, resp) + code, msg64, err := c.cmd(0, strings.TrimSpace(fmt.Sprintf("AUTH %s %s", mech, resp64))) + for err == nil { + var msg []byte + switch code { + case 334: + msg, err = encoding.DecodeString(msg64) + case 235: + // the last message isn't base64 because it isn't a challenge + msg = []byte(msg64) + default: + err = &textproto.Error{Code: code, Msg: msg64} + } + if err == nil { + resp, err = a.Next(msg, code == 334) + } + if err != nil { + // abort the AUTH + _, _, _ = c.cmd(501, "*") + _ = c.Quit() + break + } + if resp == nil { + break + } + resp64 = make([]byte, encoding.EncodedLen(len(resp))) + encoding.Encode(resp64, resp) + code, msg64, err = c.cmd(0, string(resp64)) + } + return err +} + +// Mail issues a MAIL command to the server using the provided email address. +// If the server supports the 8BITMIME extension, Mail adds the BODY=8BITMIME +// parameter. If the server supports the SMTPUTF8 extension, Mail adds the +// SMTPUTF8 parameter. +// This initiates a mail transaction and is followed by one or more Rcpt calls. +func (c *Client) Mail(from string) error { + if err := validateLine(from); err != nil { + return err + } + if err := c.hello(); err != nil { + return err + } + cmdStr := "MAIL FROM:<%s>" + if c.ext != nil { + if _, ok := c.ext["8BITMIME"]; ok { + cmdStr += " BODY=8BITMIME" + } + if _, ok := c.ext["SMTPUTF8"]; ok { + cmdStr += " SMTPUTF8" + } + } + _, _, err := c.cmd(250, cmdStr, from) + return err +} + +// Rcpt issues a RCPT command to the server using the provided email address. +// A call to Rcpt must be preceded by a call to Mail and may be followed by +// a Data call or another Rcpt call. +func (c *Client) Rcpt(to string) error { + if err := validateLine(to); err != nil { + return err + } + _, _, err := c.cmd(25, "RCPT TO:<%s>", to) + return err +} + +type dataCloser struct { + c *Client + io.WriteCloser +} + +func (d *dataCloser) Close() error { + _ = d.WriteCloser.Close() + _, _, err := d.c.Text.ReadResponse(250) + return err +} + +// Data issues a DATA command to the server and returns a writer that +// can be used to write the mail headers and body. The caller should +// close the writer before calling any more methods on c. A call to +// Data must be preceded by one or more calls to Rcpt. +func (c *Client) Data() (io.WriteCloser, error) { + _, _, err := c.cmd(354, "DATA") + if err != nil { + return nil, err + } + return &dataCloser{c, c.Text.DotWriter()}, nil +} + +var testHookStartTLS func(*tls.Config) // nil, except for tests + +// SendMail connects to the server at addr, switches to TLS if +// possible, authenticates with the optional mechanism a if possible, +// and then sends an email from address from, to addresses to, with +// message msg. +// The addr must include a port, as in "mail.example.com:smtp". +// +// The addresses in the to parameter are the SMTP RCPT addresses. +// +// The msg parameter should be an RFC 822-style email with headers +// first, a blank line, and then the message body. The lines of msg +// should be CRLF terminated. The msg headers should usually include +// fields such as "From", "To", "Subject", and "Cc". Sending "Bcc" +// messages is accomplished by including an email address in the to +// parameter but not including it in the msg headers. +// +// The SendMail function and the net/smtp package are low-level +// mechanisms and provide no support for DKIM signing, MIME +// attachments (see the mime/multipart package), or other mail +// functionality. Higher-level packages exist outside of the standard +// library. +func SendMail(addr string, a Auth, from string, to []string, msg []byte) error { + if err := validateLine(from); err != nil { + return err + } + for _, recp := range to { + if err := validateLine(recp); err != nil { + return err + } + } + c, err := Dial(addr) + if err != nil { + return err + } + defer func() { + _ = c.Close() + }() + if err = c.hello(); err != nil { + return err + } + if ok, _ := c.Extension("STARTTLS"); ok { + config := &tls.Config{ServerName: c.serverName} + if testHookStartTLS != nil { + testHookStartTLS(config) + } + if err = c.StartTLS(config); err != nil { + return err + } + } + if a != nil && c.ext != nil { + if _, ok := c.ext["AUTH"]; !ok { + return errors.New("smtp: server doesn't support AUTH") + } + if err = c.Auth(a); err != nil { + return err + } + } + if err = c.Mail(from); err != nil { + return err + } + for _, addr := range to { + if err = c.Rcpt(addr); err != nil { + return err + } + } + w, err := c.Data() + if err != nil { + return err + } + _, err = w.Write(msg) + if err != nil { + return err + } + err = w.Close() + if err != nil { + return err + } + return c.Quit() +} + +// Extension reports whether an extension is support by the server. +// The extension name is case-insensitive. If the extension is supported, +// Extension also returns a string that contains any parameters the +// server specifies for the extension. +func (c *Client) Extension(ext string) (bool, string) { + if err := c.hello(); err != nil { + return false, "" + } + if c.ext == nil { + return false, "" + } + ext = strings.ToUpper(ext) + param, ok := c.ext[ext] + return ok, param +} + +// Reset sends the RSET command to the server, aborting the current mail +// transaction. +func (c *Client) Reset() error { + if err := c.hello(); err != nil { + return err + } + _, _, err := c.cmd(250, "RSET") + return err +} + +// Noop sends the NOOP command to the server. It does nothing but check +// that the connection to the server is okay. +func (c *Client) Noop() error { + if err := c.hello(); err != nil { + return err + } + _, _, err := c.cmd(250, "NOOP") + return err +} + +// Quit sends the QUIT command and closes the connection to the server. +func (c *Client) Quit() error { + if err := c.hello(); err != nil { + return err + } + _, _, err := c.cmd(221, "QUIT") + if err != nil { + return err + } + return c.Text.Close() +} + +// validateLine checks to see if a line has CR or LF as per RFC 5321. +func validateLine(line string) error { + if strings.ContainsAny(line, "\n\r") { + return errors.New("smtp: A line must not contain CR or LF") + } + return nil +} diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go new file mode 100644 index 0000000..90b5306 --- /dev/null +++ b/smtp/smtp_test.go @@ -0,0 +1,1194 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package smtp + +import ( + "bufio" + "bytes" + "crypto/tls" + "crypto/x509" + "fmt" + "io" + "net" + "net/textproto" + "strings" + "testing" + "time" +) + +type authTest struct { + auth Auth + challenges []string + name string + responses []string +} + +var authTests = []authTest{ + {PlainAuth("", "user", "pass", "testserver"), []string{}, "PLAIN", []string{"\x00user\x00pass"}}, + {PlainAuth("foo", "bar", "baz", "testserver"), []string{}, "PLAIN", []string{"foo\x00bar\x00baz"}}, + {CRAMMD5Auth("user", "pass"), []string{"<123456.1322876914@testserver>"}, "CRAM-MD5", []string{"", "user 287eb355114cf5c471c26a875f1ca4ae"}}, + { + LoginAuth("user", "pass", "testserver"), + []string{"Username:", "Password:", "Invalid:"}, + "LOGIN", + []string{"", "user", "pass", ""}, + }, +} + +func TestAuth(t *testing.T) { +testLoop: + for i, test := range authTests { + name, resp, err := test.auth.Start(&ServerInfo{"testserver", true, nil}) + if name != test.name { + t.Errorf("#%d got name %s, expected %s", i, name, test.name) + } + if !bytes.Equal(resp, []byte(test.responses[0])) { + t.Errorf("#%d got response %s, expected %s", i, resp, test.responses[0]) + } + if err != nil { + t.Errorf("#%d error: %s", i, err) + } + for j := range test.challenges { + challenge := []byte(test.challenges[j]) + expected := []byte(test.responses[j+1]) + resp, err := test.auth.Next(challenge, true) + if err != nil { + t.Errorf("#%d error: %s", i, err) + continue testLoop + } + if !bytes.Equal(resp, expected) { + t.Errorf("#%d got %s, expected %s", i, resp, expected) + continue testLoop + } + } + } +} + +func TestAuthPlain(t *testing.T) { + tests := []struct { + authName string + server *ServerInfo + err string + }{ + { + authName: "servername", + server: &ServerInfo{Name: "servername", TLS: true}, + }, + { + // OK to use PlainAuth on localhost without TLS + authName: "localhost", + server: &ServerInfo{Name: "localhost", TLS: false}, + }, + { + // NOT OK on non-localhost, even if server says PLAIN is OK. + // (We don't know that the server is the real server.) + authName: "servername", + server: &ServerInfo{Name: "servername", Auth: []string{"PLAIN"}}, + err: "unencrypted connection", + }, + { + authName: "servername", + server: &ServerInfo{Name: "servername", Auth: []string{"CRAM-MD5"}}, + err: "unencrypted connection", + }, + { + authName: "servername", + server: &ServerInfo{Name: "attacker", TLS: true}, + err: "wrong host name", + }, + } + for i, tt := range tests { + auth := PlainAuth("foo", "bar", "baz", tt.authName) + _, _, err := auth.Start(tt.server) + got := "" + if err != nil { + got = err.Error() + } + if got != tt.err { + t.Errorf("%d. got error = %q; want %q", i, got, tt.err) + } + } +} + +func TestAuthLogin(t *testing.T) { + tests := []struct { + authName string + server *ServerInfo + err string + }{ + { + authName: "servername", + server: &ServerInfo{Name: "servername", TLS: true}, + }, + { + // OK to use LoginAuth on localhost without TLS + authName: "localhost", + server: &ServerInfo{Name: "localhost", TLS: false}, + }, + { + // NOT OK on non-localhost, even if server says PLAIN is OK. + // (We don't know that the server is the real server.) + authName: "servername", + server: &ServerInfo{Name: "servername", Auth: []string{"PLAIN"}}, + err: "unencrypted connection", + }, + { + authName: "servername", + server: &ServerInfo{Name: "servername", Auth: []string{"CRAM-MD5"}}, + err: "unencrypted connection", + }, + { + authName: "servername", + server: &ServerInfo{Name: "attacker", TLS: true}, + err: "wrong host name", + }, + } + for i, tt := range tests { + auth := LoginAuth("foo", "bar", tt.authName) + _, _, err := auth.Start(tt.server) + got := "" + if err != nil { + got = err.Error() + } + if got != tt.err { + t.Errorf("%d. got error = %q; want %q", i, got, tt.err) + } + } +} + +// Issue 17794: don't send a trailing space on AUTH command when there's no password. +func TestClientAuthTrimSpace(t *testing.T) { + server := "220 hello world\r\n" + + "200 some more" + var wrote strings.Builder + var fake faker + fake.ReadWriter = struct { + io.Reader + io.Writer + }{ + strings.NewReader(server), + &wrote, + } + c, err := NewClient(fake, "fake.host") + if err != nil { + t.Fatalf("NewClient: %v", err) + } + c.tls = true + c.didHello = true + c.Auth(toServerEmptyAuth{}) + c.Close() + if got, want := wrote.String(), "AUTH FOOAUTH\r\n*\r\nQUIT\r\n"; got != want { + t.Errorf("wrote %q; want %q", got, want) + } +} + +// toServerEmptyAuth is an implementation of Auth that only implements +// the Start method, and returns "FOOAUTH", nil, nil. Notably, it returns +// zero bytes for "toServer" so we can test that we don't send spaces at +// the end of the line. See TestClientAuthTrimSpace. +type toServerEmptyAuth struct{} + +func (toServerEmptyAuth) Start(server *ServerInfo) (proto string, toServer []byte, err error) { + return "FOOAUTH", nil, nil +} + +func (toServerEmptyAuth) Next(fromServer []byte, more bool) (toServer []byte, err error) { + panic("unexpected call") +} + +type faker struct { + io.ReadWriter +} + +func (f faker) Close() error { return nil } +func (f faker) LocalAddr() net.Addr { return nil } +func (f faker) RemoteAddr() net.Addr { return nil } +func (f faker) SetDeadline(time.Time) error { return nil } +func (f faker) SetReadDeadline(time.Time) error { return nil } +func (f faker) SetWriteDeadline(time.Time) error { return nil } + +func TestBasic(t *testing.T) { + server := strings.Join(strings.Split(basicServer, "\n"), "\r\n") + client := strings.Join(strings.Split(basicClient, "\n"), "\r\n") + + var cmdbuf strings.Builder + bcmdbuf := bufio.NewWriter(&cmdbuf) + var fake faker + fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf) + c := &Client{Text: textproto.NewConn(fake), localName: "localhost"} + + if err := c.helo(); err != nil { + t.Fatalf("HELO failed: %s", err) + } + if err := c.ehlo(); err == nil { + t.Fatalf("Expected first EHLO to fail") + } + if err := c.ehlo(); err != nil { + t.Fatalf("Second EHLO failed: %s", err) + } + + c.didHello = true + if ok, args := c.Extension("aUtH"); !ok || args != "LOGIN PLAIN" { + t.Fatalf("Expected AUTH supported") + } + if ok, _ := c.Extension("DSN"); ok { + t.Fatalf("Shouldn't support DSN") + } + + if err := c.Mail("user@gmail.com"); err == nil { + t.Fatalf("MAIL should require authentication") + } + + if err := c.Verify("user1@gmail.com"); err == nil { + t.Fatalf("First VRFY: expected no verification") + } + if err := c.Verify("user2@gmail.com>\r\nDATA\r\nAnother injected message body\r\n.\r\nQUIT\r\n"); err == nil { + t.Fatalf("VRFY should have failed due to a message injection attempt") + } + if err := c.Verify("user2@gmail.com"); err != nil { + t.Fatalf("Second VRFY: expected verification, got %s", err) + } + + // fake TLS so authentication won't complain + c.tls = true + c.serverName = "smtp.google.com" + if err := c.Auth(PlainAuth("", "user", "pass", "smtp.google.com")); err != nil { + t.Fatalf("AUTH failed: %s", err) + } + + if err := c.Rcpt("golang-nuts@googlegroups.com>\r\nDATA\r\nInjected message body\r\n.\r\nQUIT\r\n"); err == nil { + t.Fatalf("RCPT should have failed due to a message injection attempt") + } + if err := c.Mail("user@gmail.com>\r\nDATA\r\nAnother injected message body\r\n.\r\nQUIT\r\n"); err == nil { + t.Fatalf("MAIL should have failed due to a message injection attempt") + } + if err := c.Mail("user@gmail.com"); err != nil { + t.Fatalf("MAIL failed: %s", err) + } + if err := c.Rcpt("golang-nuts@googlegroups.com"); err != nil { + t.Fatalf("RCPT failed: %s", err) + } + msg := `From: user@gmail.com +To: golang-nuts@googlegroups.com +Subject: Hooray for Go + +Line 1 +.Leading dot line . +Goodbye.` + w, err := c.Data() + if err != nil { + t.Fatalf("DATA failed: %s", err) + } + if _, err := w.Write([]byte(msg)); err != nil { + t.Fatalf("Data write failed: %s", err) + } + if err := w.Close(); err != nil { + t.Fatalf("Bad data response: %s", err) + } + + if err := c.Quit(); err != nil { + t.Fatalf("QUIT failed: %s", err) + } + + bcmdbuf.Flush() + actualcmds := cmdbuf.String() + if client != actualcmds { + t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client) + } +} + +var basicServer = `250 mx.google.com at your service +502 Unrecognized command. +250-mx.google.com at your service +250-SIZE 35651584 +250-AUTH LOGIN PLAIN +250 8BITMIME +530 Authentication required +252 Send some mail, I'll try my best +250 User is valid +235 Accepted +250 Sender OK +250 Receiver OK +354 Go ahead +250 Data OK +221 OK +` + +var basicClient = `HELO localhost +EHLO localhost +EHLO localhost +MAIL FROM: BODY=8BITMIME +VRFY user1@gmail.com +VRFY user2@gmail.com +AUTH PLAIN AHVzZXIAcGFzcw== +MAIL FROM: BODY=8BITMIME +RCPT TO: +DATA +From: user@gmail.com +To: golang-nuts@googlegroups.com +Subject: Hooray for Go + +Line 1 +..Leading dot line . +Goodbye. +. +QUIT +` + +func TestExtensions(t *testing.T) { + fake := func(server string) (c *Client, bcmdbuf *bufio.Writer, cmdbuf *strings.Builder) { + server = strings.Join(strings.Split(server, "\n"), "\r\n") + + cmdbuf = &strings.Builder{} + bcmdbuf = bufio.NewWriter(cmdbuf) + var fake faker + fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf) + c = &Client{Text: textproto.NewConn(fake), localName: "localhost"} + + return c, bcmdbuf, cmdbuf + } + + t.Run("helo", func(t *testing.T) { + const ( + basicServer = `250 mx.google.com at your service +250 Sender OK +221 Goodbye +` + + basicClient = `HELO localhost +MAIL FROM: +QUIT +` + ) + + c, bcmdbuf, cmdbuf := fake(basicServer) + + if err := c.helo(); err != nil { + t.Fatalf("HELO failed: %s", err) + } + c.didHello = true + if err := c.Mail("user@gmail.com"); err != nil { + t.Fatalf("MAIL FROM failed: %s", err) + } + if err := c.Quit(); err != nil { + t.Fatalf("QUIT failed: %s", err) + } + + bcmdbuf.Flush() + actualcmds := cmdbuf.String() + client := strings.Join(strings.Split(basicClient, "\n"), "\r\n") + if client != actualcmds { + t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client) + } + }) + + t.Run("ehlo", func(t *testing.T) { + const ( + basicServer = `250-mx.google.com at your service +250 SIZE 35651584 +250 Sender OK +221 Goodbye +` + + basicClient = `EHLO localhost +MAIL FROM: +QUIT +` + ) + + c, bcmdbuf, cmdbuf := fake(basicServer) + + if err := c.Hello("localhost"); err != nil { + t.Fatalf("EHLO failed: %s", err) + } + if ok, _ := c.Extension("8BITMIME"); ok { + t.Fatalf("Shouldn't support 8BITMIME") + } + if ok, _ := c.Extension("SMTPUTF8"); ok { + t.Fatalf("Shouldn't support SMTPUTF8") + } + if err := c.Mail("user@gmail.com"); err != nil { + t.Fatalf("MAIL FROM failed: %s", err) + } + if err := c.Quit(); err != nil { + t.Fatalf("QUIT failed: %s", err) + } + + bcmdbuf.Flush() + actualcmds := cmdbuf.String() + client := strings.Join(strings.Split(basicClient, "\n"), "\r\n") + if client != actualcmds { + t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client) + } + }) + + t.Run("ehlo 8bitmime", func(t *testing.T) { + const ( + basicServer = `250-mx.google.com at your service +250-SIZE 35651584 +250 8BITMIME +250 Sender OK +221 Goodbye +` + + basicClient = `EHLO localhost +MAIL FROM: BODY=8BITMIME +QUIT +` + ) + + c, bcmdbuf, cmdbuf := fake(basicServer) + + if err := c.Hello("localhost"); err != nil { + t.Fatalf("EHLO failed: %s", err) + } + if ok, _ := c.Extension("8BITMIME"); !ok { + t.Fatalf("Should support 8BITMIME") + } + if ok, _ := c.Extension("SMTPUTF8"); ok { + t.Fatalf("Shouldn't support SMTPUTF8") + } + if err := c.Mail("user@gmail.com"); err != nil { + t.Fatalf("MAIL FROM failed: %s", err) + } + if err := c.Quit(); err != nil { + t.Fatalf("QUIT failed: %s", err) + } + + bcmdbuf.Flush() + actualcmds := cmdbuf.String() + client := strings.Join(strings.Split(basicClient, "\n"), "\r\n") + if client != actualcmds { + t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client) + } + }) + + t.Run("ehlo smtputf8", func(t *testing.T) { + const ( + basicServer = `250-mx.google.com at your service +250-SIZE 35651584 +250 SMTPUTF8 +250 Sender OK +221 Goodbye +` + + basicClient = `EHLO localhost +MAIL FROM: SMTPUTF8 +QUIT +` + ) + + c, bcmdbuf, cmdbuf := fake(basicServer) + + if err := c.Hello("localhost"); err != nil { + t.Fatalf("EHLO failed: %s", err) + } + if ok, _ := c.Extension("8BITMIME"); ok { + t.Fatalf("Shouldn't support 8BITMIME") + } + if ok, _ := c.Extension("SMTPUTF8"); !ok { + t.Fatalf("Should support SMTPUTF8") + } + if err := c.Mail("user+📧@gmail.com"); err != nil { + t.Fatalf("MAIL FROM failed: %s", err) + } + if err := c.Quit(); err != nil { + t.Fatalf("QUIT failed: %s", err) + } + + bcmdbuf.Flush() + actualcmds := cmdbuf.String() + client := strings.Join(strings.Split(basicClient, "\n"), "\r\n") + if client != actualcmds { + t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client) + } + }) + + t.Run("ehlo 8bitmime smtputf8", func(t *testing.T) { + const ( + basicServer = `250-mx.google.com at your service +250-SIZE 35651584 +250-8BITMIME +250 SMTPUTF8 +250 Sender OK +221 Goodbye + ` + + basicClient = `EHLO localhost +MAIL FROM: BODY=8BITMIME SMTPUTF8 +QUIT +` + ) + + c, bcmdbuf, cmdbuf := fake(basicServer) + + if err := c.Hello("localhost"); err != nil { + t.Fatalf("EHLO failed: %s", err) + } + c.didHello = true + if ok, _ := c.Extension("8BITMIME"); !ok { + t.Fatalf("Should support 8BITMIME") + } + if ok, _ := c.Extension("SMTPUTF8"); !ok { + t.Fatalf("Should support SMTPUTF8") + } + if err := c.Mail("user+📧@gmail.com"); err != nil { + t.Fatalf("MAIL FROM failed: %s", err) + } + if err := c.Quit(); err != nil { + t.Fatalf("QUIT failed: %s", err) + } + + bcmdbuf.Flush() + actualcmds := cmdbuf.String() + client := strings.Join(strings.Split(basicClient, "\n"), "\r\n") + if client != actualcmds { + t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client) + } + }) +} + +func TestNewClient(t *testing.T) { + server := strings.Join(strings.Split(newClientServer, "\n"), "\r\n") + client := strings.Join(strings.Split(newClientClient, "\n"), "\r\n") + + var cmdbuf strings.Builder + bcmdbuf := bufio.NewWriter(&cmdbuf) + out := func() string { + bcmdbuf.Flush() + return cmdbuf.String() + } + var fake faker + fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf) + c, err := NewClient(fake, "fake.host") + if err != nil { + t.Fatalf("NewClient: %v\n(after %v)", err, out()) + } + defer c.Close() + if ok, args := c.Extension("aUtH"); !ok || args != "LOGIN PLAIN" { + t.Fatalf("Expected AUTH supported") + } + if ok, _ := c.Extension("DSN"); ok { + t.Fatalf("Shouldn't support DSN") + } + if err := c.Quit(); err != nil { + t.Fatalf("QUIT failed: %s", err) + } + + actualcmds := out() + if client != actualcmds { + t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client) + } +} + +var newClientServer = `220 hello world +250-mx.google.com at your service +250-SIZE 35651584 +250-AUTH LOGIN PLAIN +250 8BITMIME +221 OK +` + +var newClientClient = `EHLO localhost +QUIT +` + +func TestNewClient2(t *testing.T) { + server := strings.Join(strings.Split(newClient2Server, "\n"), "\r\n") + client := strings.Join(strings.Split(newClient2Client, "\n"), "\r\n") + + var cmdbuf strings.Builder + bcmdbuf := bufio.NewWriter(&cmdbuf) + var fake faker + fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf) + c, err := NewClient(fake, "fake.host") + if err != nil { + t.Fatalf("NewClient: %v", err) + } + defer c.Close() + if ok, _ := c.Extension("DSN"); ok { + t.Fatalf("Shouldn't support DSN") + } + if err := c.Quit(); err != nil { + t.Fatalf("QUIT failed: %s", err) + } + + bcmdbuf.Flush() + actualcmds := cmdbuf.String() + if client != actualcmds { + t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client) + } +} + +var newClient2Server = `220 hello world +502 EH? +250-mx.google.com at your service +250-SIZE 35651584 +250-AUTH LOGIN PLAIN +250 8BITMIME +221 OK +` + +var newClient2Client = `EHLO localhost +HELO localhost +QUIT +` + +func TestNewClientWithTLS(t *testing.T) { + cert, err := tls.X509KeyPair(localhostCert, localhostKey) + if err != nil { + t.Fatalf("loadcert: %v", err) + } + + config := tls.Config{Certificates: []tls.Certificate{cert}} + + ln, err := tls.Listen("tcp", "127.0.0.1:0", &config) + if err != nil { + ln, err = tls.Listen("tcp", "[::1]:0", &config) + if err != nil { + t.Fatalf("server: listen: %v", err) + } + } + + go func() { + conn, err := ln.Accept() + if err != nil { + t.Errorf("server: accept: %v", err) + return + } + defer conn.Close() + + _, err = conn.Write([]byte("220 SIGNS\r\n")) + if err != nil { + t.Errorf("server: write: %v", err) + return + } + }() + + config.InsecureSkipVerify = true + conn, err := tls.Dial("tcp", ln.Addr().String(), &config) + if err != nil { + t.Fatalf("client: dial: %v", err) + } + defer conn.Close() + + client, err := NewClient(conn, ln.Addr().String()) + if err != nil { + t.Fatalf("smtp: newclient: %v", err) + } + if !client.tls { + t.Errorf("client.tls Got: %t Expected: %t", client.tls, true) + } +} + +func TestHello(t *testing.T) { + if len(helloServer) != len(helloClient) { + t.Fatalf("Hello server and client size mismatch") + } + + for i := 0; i < len(helloServer); i++ { + server := strings.Join(strings.Split(baseHelloServer+helloServer[i], "\n"), "\r\n") + client := strings.Join(strings.Split(baseHelloClient+helloClient[i], "\n"), "\r\n") + var cmdbuf strings.Builder + bcmdbuf := bufio.NewWriter(&cmdbuf) + var fake faker + fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf) + c, err := NewClient(fake, "fake.host") + if err != nil { + t.Fatalf("NewClient: %v", err) + } + defer c.Close() + c.localName = "customhost" + err = nil + + switch i { + case 0: + err = c.Hello("hostinjection>\n\rDATA\r\nInjected message body\r\n.\r\nQUIT\r\n") + if err == nil { + t.Errorf("Expected Hello to be rejected due to a message injection attempt") + } + err = c.Hello("customhost") + case 1: + err = c.StartTLS(nil) + if err.Error() == "502 Not implemented" { + err = nil + } + case 2: + err = c.Verify("test@example.com") + case 3: + c.tls = true + c.serverName = "smtp.google.com" + err = c.Auth(PlainAuth("", "user", "pass", "smtp.google.com")) + case 4: + err = c.Mail("test@example.com") + case 5: + ok, _ := c.Extension("feature") + if ok { + t.Errorf("Expected FEATURE not to be supported") + } + case 6: + err = c.Reset() + case 7: + err = c.Quit() + case 8: + err = c.Verify("test@example.com") + if err != nil { + err = c.Hello("customhost") + if err != nil { + t.Errorf("Want error, got none") + } + } + case 9: + err = c.Noop() + default: + t.Fatalf("Unhandled command") + } + + if err != nil { + t.Errorf("Command %d failed: %v", i, err) + } + + bcmdbuf.Flush() + actualcmds := cmdbuf.String() + if client != actualcmds { + t.Errorf("Got:\n%s\nExpected:\n%s", actualcmds, client) + } + } +} + +var baseHelloServer = `220 hello world +502 EH? +250-mx.google.com at your service +250 FEATURE +` + +var helloServer = []string{ + "", + "502 Not implemented\n", + "250 User is valid\n", + "235 Accepted\n", + "250 Sender ok\n", + "", + "250 Reset ok\n", + "221 Goodbye\n", + "250 Sender ok\n", + "250 ok\n", +} + +var baseHelloClient = `EHLO customhost +HELO customhost +` + +var helloClient = []string{ + "", + "STARTTLS\n", + "VRFY test@example.com\n", + "AUTH PLAIN AHVzZXIAcGFzcw==\n", + "MAIL FROM:\n", + "", + "RSET\n", + "QUIT\n", + "VRFY test@example.com\n", + "NOOP\n", +} + +func TestSendMail(t *testing.T) { + server := strings.Join(strings.Split(sendMailServer, "\n"), "\r\n") + client := strings.Join(strings.Split(sendMailClient, "\n"), "\r\n") + var cmdbuf strings.Builder + bcmdbuf := bufio.NewWriter(&cmdbuf) + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Unable to create listener: %v", err) + } + defer l.Close() + + // prevent data race on bcmdbuf + done := make(chan struct{}) + go func(data []string) { + defer close(done) + + conn, err := l.Accept() + if err != nil { + t.Errorf("Accept error: %v", err) + return + } + defer conn.Close() + + tc := textproto.NewConn(conn) + for i := 0; i < len(data) && data[i] != ""; i++ { + tc.PrintfLine(data[i]) + for len(data[i]) >= 4 && data[i][3] == '-' { + i++ + tc.PrintfLine(data[i]) + } + if data[i] == "221 Goodbye" { + return + } + read := false + for !read || data[i] == "354 Go ahead" { + msg, err := tc.ReadLine() + bcmdbuf.Write([]byte(msg + "\r\n")) + read = true + if err != nil { + t.Errorf("Read error: %v", err) + return + } + if data[i] == "354 Go ahead" && msg == "." { + break + } + } + } + }(strings.Split(server, "\r\n")) + + err = SendMail(l.Addr().String(), nil, "test@example.com", []string{"other@example.com>\n\rDATA\r\nInjected message body\r\n.\r\nQUIT\r\n"}, []byte(strings.Replace(`From: test@example.com +To: other@example.com +Subject: SendMail test + +SendMail is working for me. +`, "\n", "\r\n", -1))) + if err == nil { + t.Errorf("Expected SendMail to be rejected due to a message injection attempt") + } + + err = SendMail(l.Addr().String(), nil, "test@example.com", []string{"other@example.com"}, []byte(strings.Replace(`From: test@example.com +To: other@example.com +Subject: SendMail test + +SendMail is working for me. +`, "\n", "\r\n", -1))) + + if err != nil { + t.Errorf("%v", err) + } + + <-done + bcmdbuf.Flush() + actualcmds := cmdbuf.String() + if client != actualcmds { + t.Errorf("Got:\n%s\nExpected:\n%s", actualcmds, client) + } +} + +var sendMailServer = `220 hello world +502 EH? +250 mx.google.com at your service +250 Sender ok +250 Receiver ok +354 Go ahead +250 Data ok +221 Goodbye +` + +var sendMailClient = `EHLO localhost +HELO localhost +MAIL FROM: +RCPT TO: +DATA +From: test@example.com +To: other@example.com +Subject: SendMail test + +SendMail is working for me. +. +QUIT +` + +func TestSendMailWithAuth(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Unable to create listener: %v", err) + } + defer l.Close() + + errCh := make(chan error) + go func() { + defer close(errCh) + conn, err := l.Accept() + if err != nil { + errCh <- fmt.Errorf("Accept: %v", err) + return + } + defer conn.Close() + + tc := textproto.NewConn(conn) + tc.PrintfLine("220 hello world") + msg, err := tc.ReadLine() + if err != nil { + errCh <- fmt.Errorf("ReadLine error: %v", err) + return + } + const wantMsg = "EHLO localhost" + if msg != wantMsg { + errCh <- fmt.Errorf("unexpected response %q; want %q", msg, wantMsg) + return + } + err = tc.PrintfLine("250 mx.google.com at your service") + if err != nil { + errCh <- fmt.Errorf("PrintfLine: %v", err) + return + } + }() + + err = SendMail(l.Addr().String(), PlainAuth("", "user", "pass", "smtp.google.com"), "test@example.com", []string{"other@example.com"}, []byte(strings.Replace(`From: test@example.com +To: other@example.com +Subject: SendMail test + +SendMail is working for me. +`, "\n", "\r\n", -1))) + if err == nil { + t.Error("SendMail: Server doesn't support AUTH, expected to get an error, but got none ") + } + if err.Error() != "smtp: server doesn't support AUTH" { + t.Errorf("Expected: smtp: server doesn't support AUTH, got: %s", err) + } + err = <-errCh + if err != nil { + t.Fatalf("server error: %v", err) + } +} + +func TestAuthFailed(t *testing.T) { + server := strings.Join(strings.Split(authFailedServer, "\n"), "\r\n") + client := strings.Join(strings.Split(authFailedClient, "\n"), "\r\n") + var cmdbuf strings.Builder + bcmdbuf := bufio.NewWriter(&cmdbuf) + var fake faker + fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf) + c, err := NewClient(fake, "fake.host") + if err != nil { + t.Fatalf("NewClient: %v", err) + } + defer c.Close() + + c.tls = true + c.serverName = "smtp.google.com" + err = c.Auth(PlainAuth("", "user", "pass", "smtp.google.com")) + + if err == nil { + t.Error("Auth: expected error; got none") + } else if err.Error() != "535 Invalid credentials\nplease see www.example.com" { + t.Errorf("Auth: got error: %v, want: %s", err, "535 Invalid credentials\nplease see www.example.com") + } + + bcmdbuf.Flush() + actualcmds := cmdbuf.String() + if client != actualcmds { + t.Errorf("Got:\n%s\nExpected:\n%s", actualcmds, client) + } +} + +var authFailedServer = `220 hello world +250-mx.google.com at your service +250 AUTH LOGIN PLAIN +535-Invalid credentials +535 please see www.example.com +221 Goodbye +` + +var authFailedClient = `EHLO localhost +AUTH PLAIN AHVzZXIAcGFzcw== +* +QUIT +` + +func TestTLSClient(t *testing.T) { + /* + TODO: Check if we need this + if runtime.GOOS == "freebsd" || runtime.GOOS == "js" { + testenv.SkipFlaky(t, 19229) + } + */ + ln := newLocalListener(t) + defer ln.Close() + errc := make(chan error) + go func() { + errc <- sendMail(ln.Addr().String()) + }() + conn, err := ln.Accept() + if err != nil { + t.Fatalf("failed to accept connection: %v", err) + } + defer conn.Close() + if err := serverHandle(conn, t); err != nil { + t.Fatalf("failed to handle connection: %v", err) + } + if err := <-errc; err != nil { + t.Fatalf("client error: %v", err) + } +} + +func TestTLSConnState(t *testing.T) { + ln := newLocalListener(t) + defer ln.Close() + clientDone := make(chan bool) + serverDone := make(chan bool) + go func() { + defer close(serverDone) + c, err := ln.Accept() + if err != nil { + t.Errorf("Server accept: %v", err) + return + } + defer c.Close() + if err := serverHandle(c, t); err != nil { + t.Errorf("server error: %v", err) + } + }() + go func() { + defer close(clientDone) + c, err := Dial(ln.Addr().String()) + if err != nil { + t.Errorf("Client dial: %v", err) + return + } + defer c.Quit() + cfg := &tls.Config{ServerName: "example.com"} + testHookStartTLS(cfg) // set the RootCAs + if err := c.StartTLS(cfg); err != nil { + t.Errorf("StartTLS: %v", err) + return + } + cs, ok := c.TLSConnectionState() + if !ok { + t.Errorf("TLSConnectionState returned ok == false; want true") + return + } + if cs.Version == 0 || !cs.HandshakeComplete { + t.Errorf("ConnectionState = %#v; expect non-zero Version and HandshakeComplete", cs) + } + }() + <-clientDone + <-serverDone +} + +func newLocalListener(t *testing.T) net.Listener { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + ln, err = net.Listen("tcp6", "[::1]:0") + } + if err != nil { + t.Fatal(err) + } + return ln +} + +type smtpSender struct { + w io.Writer +} + +func (s smtpSender) send(f string) { + s.w.Write([]byte(f + "\r\n")) +} + +// smtp server, finely tailored to deal with our own client only! +func serverHandle(c net.Conn, t *testing.T) error { + send := smtpSender{c}.send + send("220 127.0.0.1 ESMTP service ready") + s := bufio.NewScanner(c) + for s.Scan() { + switch s.Text() { + case "EHLO localhost": + send("250-127.0.0.1 ESMTP offers a warm hug of welcome") + send("250-STARTTLS") + send("250 Ok") + case "STARTTLS": + send("220 Go ahead") + keypair, err := tls.X509KeyPair(localhostCert, localhostKey) + if err != nil { + return err + } + config := &tls.Config{Certificates: []tls.Certificate{keypair}} + c = tls.Server(c, config) + defer c.Close() + return serverHandleTLS(c, t) + default: + t.Fatalf("unrecognized command: %q", s.Text()) + } + } + return s.Err() +} + +func serverHandleTLS(c net.Conn, t *testing.T) error { + send := smtpSender{c}.send + s := bufio.NewScanner(c) + for s.Scan() { + switch s.Text() { + case "EHLO localhost": + send("250 Ok") + case "MAIL FROM:": + send("250 Ok") + case "RCPT TO:": + send("250 Ok") + case "DATA": + send("354 send the mail data, end with .") + send("250 Ok") + case "Subject: test": + case "": + case "howdy!": + case ".": + case "QUIT": + send("221 127.0.0.1 Service closing transmission channel") + return nil + default: + t.Fatalf("unrecognized command during TLS: %q", s.Text()) + } + } + return s.Err() +} + +func init() { + testRootCAs := x509.NewCertPool() + testRootCAs.AppendCertsFromPEM(localhostCert) + testHookStartTLS = func(config *tls.Config) { + config.RootCAs = testRootCAs + } +} + +func sendMail(hostPort string) error { + from := "joe1@example.com" + to := []string{"joe2@example.com"} + return SendMail(hostPort, nil, from, to, []byte("Subject: test\n\nhowdy!")) +} + +// localhostCert is a PEM-encoded TLS cert generated from src/crypto/tls: +// +// go run generate_cert.go --rsa-bits 1024 --host 127.0.0.1,::1,example.com \ +// --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h +var localhostCert = []byte(` +-----BEGIN CERTIFICATE----- +MIICFDCCAX2gAwIBAgIRAK0xjnaPuNDSreeXb+z+0u4wDQYJKoZIhvcNAQELBQAw +EjEQMA4GA1UEChMHQWNtZSBDbzAgFw03MDAxMDEwMDAwMDBaGA8yMDg0MDEyOTE2 +MDAwMFowEjEQMA4GA1UEChMHQWNtZSBDbzCBnzANBgkqhkiG9w0BAQEFAAOBjQAw +gYkCgYEA0nFbQQuOWsjbGtejcpWz153OlziZM4bVjJ9jYruNw5n2Ry6uYQAffhqa +JOInCmmcVe2siJglsyH9aRh6vKiobBbIUXXUU1ABd56ebAzlt0LobLlx7pZEMy30 +LqIi9E6zmL3YvdGzpYlkFRnRrqwEtWYbGBf3znO250S56CCWH2UCAwEAAaNoMGYw +DgYDVR0PAQH/BAQDAgKkMBMGA1UdJQQMMAoGCCsGAQUFBwMBMA8GA1UdEwEB/wQF +MAMBAf8wLgYDVR0RBCcwJYILZXhhbXBsZS5jb22HBH8AAAGHEAAAAAAAAAAAAAAA +AAAAAAEwDQYJKoZIhvcNAQELBQADgYEAbZtDS2dVuBYvb+MnolWnCNqvw1w5Gtgi +NmvQQPOMgM3m+oQSCPRTNGSg25e1Qbo7bgQDv8ZTnq8FgOJ/rbkyERw2JckkHpD4 +n4qcK27WkEDBtQFlPihIM8hLIuzWoi/9wygiElTy/tVL3y7fGCvY2/k1KBthtZGF +tN8URjVmyEo= +-----END CERTIFICATE-----`) + +// localhostKey is the private key for localhostCert. +var localhostKey = []byte(testingKey(` +-----BEGIN RSA TESTING KEY----- +MIICXgIBAAKBgQDScVtBC45ayNsa16NylbPXnc6XOJkzhtWMn2Niu43DmfZHLq5h +AB9+Gpok4icKaZxV7ayImCWzIf1pGHq8qKhsFshRddRTUAF3np5sDOW3QuhsuXHu +lkQzLfQuoiL0TrOYvdi90bOliWQVGdGurAS1ZhsYF/fOc7bnRLnoIJYfZQIDAQAB +AoGBAMst7OgpKyFV6c3JwyI/jWqxDySL3caU+RuTTBaodKAUx2ZEmNJIlx9eudLA +kucHvoxsM/eRxlxkhdFxdBcwU6J+zqooTnhu/FE3jhrT1lPrbhfGhyKnUrB0KKMM +VY3IQZyiehpxaeXAwoAou6TbWoTpl9t8ImAqAMY8hlULCUqlAkEA+9+Ry5FSYK/m +542LujIcCaIGoG1/Te6Sxr3hsPagKC2rH20rDLqXwEedSFOpSS0vpzlPAzy/6Rbb +PHTJUhNdwwJBANXkA+TkMdbJI5do9/mn//U0LfrCR9NkcoYohxfKz8JuhgRQxzF2 +6jpo3q7CdTuuRixLWVfeJzcrAyNrVcBq87cCQFkTCtOMNC7fZnCTPUv+9q1tcJyB +vNjJu3yvoEZeIeuzouX9TJE21/33FaeDdsXbRhQEj23cqR38qFHsF1qAYNMCQQDP +QXLEiJoClkR2orAmqjPLVhR3t2oB3INcnEjLNSq8LHyQEfXyaFfu4U9l5+fRPL2i +jiC0k/9L5dHUsF0XZothAkEA23ddgRs+Id/HxtojqqUT27B8MT/IGNrYsp4DvS/c +qgkeluku4GjxRlDMBuXk94xOBEinUs+p/hwP1Alll80Tpg== +-----END RSA TESTING KEY-----`)) + +func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") }