diff --git a/client.go b/client.go index fac9a34..6557913 100644 --- a/client.go +++ b/client.go @@ -12,6 +12,7 @@ import ( "net" "os" "strings" + "sync" "time" "github.com/wneessen/go-mail/log" @@ -87,12 +88,12 @@ type DialContextFunc func(ctx context.Context, network, address string) (net.Con // Client is the SMTP client struct type Client struct { - // connection is the net.Conn that the smtp.Client is based on - connection net.Conn - // Timeout for the SMTP server connection connTimeout time.Duration + // dialContextFunc is a custom DialContext function to dial target SMTP server + dialContextFunc DialContextFunc + // dsn indicates that we want to use DSN for the Client dsn bool @@ -102,11 +103,9 @@ type Client struct { // dsnrntype defines the DSNRcptNotifyOption in case DSN is enabled dsnrntype []string - // isEncrypted indicates if a Client connection is encrypted or not - isEncrypted bool - - // noNoop indicates the Noop is to be skipped - noNoop bool + // fallbackPort is used as an alternative port number in case the primary port is unavailable or + // fails to bind. + fallbackPort int // HELO/EHLO string for the greeting the target SMTP server helo string @@ -114,12 +113,24 @@ type Client struct { // Hostname of the target SMTP server to connect to host string + // isEncrypted indicates if a Client connection is encrypted or not + isEncrypted bool + + // logger is a logger that implements the log.Logger interface + logger log.Logger + + // mutex is used to synchronize access to shared resources, ensuring that only one goroutine can + // modify them at a time. + mutex sync.RWMutex + + // noNoop indicates the Noop is to be skipped + noNoop bool + // pass is the corresponding SMTP AUTH password pass string - // Port of the SMTP server to connect to - port int - fallbackPort int + // port specifies the network port number on which the server listens for incoming connections. + port int // smtpAuth is a pointer to smtp.Auth smtpAuth smtp.Auth @@ -130,26 +141,20 @@ type Client struct { // smtpClient is the smtp.Client that is set up when using the Dial*() methods smtpClient *smtp.Client - // Use SSL for the connection - useSSL bool - // tlspolicy sets the client to use the provided TLSPolicy for the STARTTLS protocol tlspolicy TLSPolicy // tlsconfig represents the tls.Config setting for the STARTTLS connection tlsconfig *tls.Config - // user is the SMTP AUTH username - user string - // useDebugLog enables the debug logging on the SMTP client useDebugLog bool - // logger is a logger that implements the log.Logger interface - logger log.Logger + // user is the SMTP AUTH username + user string - // dialContextFunc is a custom DialContext function to dial target SMTP server - dialContextFunc DialContextFunc + // Use SSL for the connection + useSSL bool } // Option returns a function that can be used for grouping Client options @@ -550,6 +555,9 @@ func (c *Client) SetLogger(logger log.Logger) { // SetTLSConfig overrides the current *tls.Config with the given *tls.Config value func (c *Client) SetTLSConfig(tlsconfig *tls.Config) error { + c.mutex.Lock() + defer c.mutex.Unlock() + if tlsconfig == nil { return ErrInvalidTLSConfig } @@ -589,6 +597,9 @@ func (c *Client) setDefaultHelo() error { // DialWithContext establishes a connection to the SMTP server with a given context.Context func (c *Client) DialWithContext(dialCtx context.Context) error { + c.mutex.Lock() + defer c.mutex.Unlock() + ctx, cancel := context.WithDeadline(dialCtx, time.Now().Add(c.connTimeout)) defer cancel() @@ -602,17 +613,16 @@ func (c *Client) DialWithContext(dialCtx context.Context) error { c.dialContextFunc = tlsDialer.DialContext } } - var err error - c.connection, err = c.dialContextFunc(ctx, "tcp", c.ServerAddr()) + connection, err := c.dialContextFunc(ctx, "tcp", c.ServerAddr()) if err != nil && c.fallbackPort != 0 { // TODO: should we somehow log or append the previous error? - c.connection, err = c.dialContextFunc(ctx, "tcp", c.serverFallbackAddr()) + connection, err = c.dialContextFunc(ctx, "tcp", c.serverFallbackAddr()) } if err != nil { return err } - client, err := smtp.NewClient(c.connection, c.host) + client, err := smtp.NewClient(connection, c.host) if err != nil { return err } @@ -691,7 +701,7 @@ func (c *Client) DialAndSendWithContext(ctx context.Context, messages ...*Msg) e // checkConn makes sure that a required server connection is available and extends the // connection deadline func (c *Client) checkConn() error { - if c.connection == nil { + if !c.smtpClient.HasConnection() { return ErrNoActiveConnection } @@ -701,7 +711,7 @@ func (c *Client) checkConn() error { } } - if err := c.connection.SetDeadline(time.Now().Add(c.connTimeout)); err != nil { + if err := c.smtpClient.UpdateDeadline(c.connTimeout); err != nil { return ErrDeadlineExtendFailed } return nil @@ -715,7 +725,7 @@ func (c *Client) serverFallbackAddr() string { // tls tries to make sure that the STARTTLS requirements are satisfied func (c *Client) tls() error { - if c.connection == nil { + if !c.smtpClient.HasConnection() { return ErrNoActiveConnection } if !c.useSSL && c.tlspolicy != NoTLS { @@ -791,6 +801,9 @@ func (c *Client) auth() error { // sendSingleMsg sends out a single message and returns an error if the transmission/delivery fails. // It is invoked by the public Send methods func (c *Client) sendSingleMsg(message *Msg) error { + c.mutex.Lock() + defer c.mutex.Unlock() + if message.encoding == NoEncoding { if ok, _ := c.smtpClient.Extension("8BITMIME"); !ok { return &SendError{Reason: ErrNoUnencoded, isTemp: false, affectedMsg: message} diff --git a/client_test.go b/client_test.go index 3d50a2d..2d37ce0 100644 --- a/client_test.go +++ b/client_test.go @@ -15,6 +15,7 @@ import ( "os" "strconv" "strings" + "sync" "testing" "time" @@ -623,11 +624,12 @@ func TestClient_DialWithContext(t *testing.T) { t.Errorf("failed to dial with context: %s", err) return } - if c.connection == nil { - t.Errorf("DialWithContext didn't fail but no connection found.") - } if c.smtpClient == nil { t.Errorf("DialWithContext didn't fail but no SMTP client found.") + return + } + if !c.smtpClient.HasConnection() { + t.Errorf("DialWithContext didn't fail but no connection found.") } if err := c.Close(); err != nil { t.Errorf("failed to close connection: %s", err) @@ -644,17 +646,18 @@ func TestClient_DialWithContext_Fallback(t *testing.T) { c.SetTLSPortPolicy(TLSOpportunistic) c.port = 999 ctx := context.Background() - if err := c.DialWithContext(ctx); err != nil { + if err = c.DialWithContext(ctx); err != nil { t.Errorf("failed to dial with context: %s", err) return } - if c.connection == nil { - t.Errorf("DialWithContext didn't fail but no connection found.") - } if c.smtpClient == nil { t.Errorf("DialWithContext didn't fail but no SMTP client found.") + return } - if err := c.Close(); err != nil { + if !c.smtpClient.HasConnection() { + t.Errorf("DialWithContext didn't fail but no connection found.") + } + if err = c.Close(); err != nil { t.Errorf("failed to close connection: %s", err) } @@ -674,18 +677,19 @@ func TestClient_DialWithContext_Debug(t *testing.T) { t.Skipf("failed to create test client: %s. Skipping tests", err) } ctx := context.Background() - if err := c.DialWithContext(ctx); err != nil { + if err = c.DialWithContext(ctx); err != nil { t.Errorf("failed to dial with context: %s", err) return } - if c.connection == nil { - t.Errorf("DialWithContext didn't fail but no connection found.") - } if c.smtpClient == nil { t.Errorf("DialWithContext didn't fail but no SMTP client found.") + return + } + if !c.smtpClient.HasConnection() { + t.Errorf("DialWithContext didn't fail but no connection found.") } c.SetDebugLog(true) - if err := c.Close(); err != nil { + if err = c.Close(); err != nil { t.Errorf("failed to close connection: %s", err) } } @@ -698,19 +702,20 @@ func TestClient_DialWithContext_Debug_custom(t *testing.T) { t.Skipf("failed to create test client: %s. Skipping tests", err) } ctx := context.Background() - if err := c.DialWithContext(ctx); err != nil { + if err = c.DialWithContext(ctx); err != nil { t.Errorf("failed to dial with context: %s", err) return } - if c.connection == nil { - t.Errorf("DialWithContext didn't fail but no connection found.") - } if c.smtpClient == nil { t.Errorf("DialWithContext didn't fail but no SMTP client found.") + return + } + if !c.smtpClient.HasConnection() { + t.Errorf("DialWithContext didn't fail but no connection found.") } c.SetDebugLog(true) c.SetLogger(log.New(os.Stderr, log.LevelDebug)) - if err := c.Close(); err != nil { + if err = c.Close(); err != nil { t.Errorf("failed to close connection: %s", err) } } @@ -722,10 +727,9 @@ func TestClient_DialWithContextInvalidHost(t *testing.T) { if err != nil { t.Skipf("failed to create test client: %s. Skipping tests", err) } - c.connection = nil c.host = "invalid.addr" ctx := context.Background() - if err := c.DialWithContext(ctx); err == nil { + if err = c.DialWithContext(ctx); err == nil { t.Errorf("dial succeeded but was supposed to fail") return } @@ -738,10 +742,9 @@ func TestClient_DialWithContextInvalidHELO(t *testing.T) { if err != nil { t.Skipf("failed to create test client: %s. Skipping tests", err) } - c.connection = nil c.helo = "" ctx := context.Background() - if err := c.DialWithContext(ctx); err == nil { + if err = c.DialWithContext(ctx); err == nil { t.Errorf("dial succeeded but was supposed to fail") return } @@ -758,7 +761,7 @@ func TestClient_DialWithContextInvalidAuth(t *testing.T) { c.pass = "invalid" c.SetSMTPAuthCustom(smtp.LoginAuth("invalid", "invalid", "invalid")) ctx := context.Background() - if err := c.DialWithContext(ctx); err == nil { + if err = c.DialWithContext(ctx); err == nil { t.Errorf("dial succeeded but was supposed to fail") return } @@ -770,8 +773,7 @@ func TestClient_checkConn(t *testing.T) { if err != nil { t.Skipf("failed to create test client: %s. Skipping tests", err) } - c.connection = nil - if err := c.checkConn(); err == nil { + if err = c.checkConn(); err == nil { t.Errorf("connCheck() should fail but succeeded") } } @@ -802,21 +804,23 @@ func TestClient_DialWithContextOptions(t *testing.T) { } ctx := context.Background() - if err := c.DialWithContext(ctx); err != nil && !tt.sf { + if err = c.DialWithContext(ctx); err != nil && !tt.sf { t.Errorf("failed to dial with context: %s", err) return } if !tt.sf { - if c.connection == nil && !tt.sf { - t.Errorf("DialWithContext didn't fail but no connection found.") - } if c.smtpClient == nil && !tt.sf { t.Errorf("DialWithContext didn't fail but no SMTP client found.") + return } - if err := c.Reset(); err != nil { + if !c.smtpClient.HasConnection() && !tt.sf { + t.Errorf("DialWithContext didn't fail but no connection found.") + return + } + if err = c.Reset(); err != nil { t.Errorf("failed to reset connection: %s", err) } - if err := c.Close(); err != nil { + if err = c.Close(); err != nil { t.Errorf("failed to close connection: %s", err) } } @@ -1011,17 +1015,15 @@ func TestClient_DialSendCloseBroken(t *testing.T) { } if tt.closestart { _ = c.smtpClient.Close() - _ = c.connection.Close() } - if err := c.Send(m); err != nil && !tt.sf { + if err = c.Send(m); err != nil && !tt.sf { t.Errorf("Send() failed: %s", err) return } if tt.closeearly { _ = c.smtpClient.Close() - _ = c.connection.Close() } - if err := c.Close(); err != nil && !tt.sf { + if err = c.Close(); err != nil && !tt.sf { t.Errorf("Close() failed: %s", err) return } @@ -1071,17 +1073,15 @@ func TestClient_DialSendCloseBrokenWithDSN(t *testing.T) { } if tt.closestart { _ = c.smtpClient.Close() - _ = c.connection.Close() } - if err := c.Send(m); err != nil && !tt.sf { + if err = c.Send(m); err != nil && !tt.sf { t.Errorf("Send() failed: %s", err) return } if tt.closeearly { _ = c.smtpClient.Close() - _ = c.connection.Close() } - if err := c.Close(); err != nil && !tt.sf { + if err = c.Close(); err != nil && !tt.sf { t.Errorf("Close() failed: %s", err) return } @@ -1728,6 +1728,114 @@ func TestClient_SendErrorReset(t *testing.T) { } } +func TestClient_DialSendConcurrent_online(t *testing.T) { + if os.Getenv("TEST_ALLOW_SEND") == "" { + t.Skipf("TEST_ALLOW_SEND is not set. Skipping mail sending test") + } + + client, err := getTestConnection(true) + if err != nil { + t.Skipf("failed to create test client: %s. Skipping tests", err) + } + + var messages []*Msg + for i := 0; i < 10; i++ { + message := NewMsg() + if err := message.FromFormat("go-mail Test Mailer", os.Getenv("TEST_FROM")); err != nil { + t.Errorf("failed to set FROM address: %s", err) + return + } + if err := message.To(TestRcpt); err != nil { + t.Errorf("failed to set TO address: %s", err) + return + } + message.Subject(fmt.Sprintf("Test subject for mail %d", i)) + message.SetBodyString(TypeTextPlain, fmt.Sprintf("This is the test body of the mail no. %d", i)) + message.SetMessageID() + messages = append(messages, message) + } + + if err = client.DialWithContext(context.Background()); err != nil { + t.Errorf("failed to dial to test server: %s", err) + } + + wg := sync.WaitGroup{} + for id, message := range messages { + wg.Add(1) + go func(curMsg *Msg, curID int) { + defer wg.Done() + if goroutineErr := client.Send(curMsg); err != nil { + t.Errorf("failed to send message with ID %d: %s", curID, goroutineErr) + } + }(message, id) + } + wg.Wait() + + if err = client.Close(); err != nil { + t.Errorf("failed to close server connection: %s", err) + } +} + +func TestClient_DialSendConcurrent_local(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + serverPort := TestServerPortBase + 20 + featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, featureSet, false, serverPort); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 500) + + client, err := NewClient(TestServerAddr, WithPort(serverPort), + WithTLSPortPolicy(NoTLS), WithSMTPAuth(SMTPAuthPlain), + WithUsername("toni@tester.com"), + WithPassword("V3ryS3cr3t+")) + if err != nil { + t.Errorf("unable to create new client: %s", err) + } + + var messages []*Msg + for i := 0; i < 20; i++ { + message := NewMsg() + if err := message.From("valid-from@domain.tld"); err != nil { + t.Errorf("failed to set FROM address: %s", err) + return + } + if err := message.To("valid-to@domain.tld"); err != nil { + t.Errorf("failed to set TO address: %s", err) + return + } + message.Subject("Test subject") + message.SetBodyString(TypeTextPlain, "Test body") + message.SetMessageIDWithValue("this.is.a.message.id") + messages = append(messages, message) + } + + if err = client.DialWithContext(context.Background()); err != nil { + t.Errorf("failed to dial to test server: %s", err) + } + + wg := sync.WaitGroup{} + for id, message := range messages { + wg.Add(1) + go func(curMsg *Msg, curID int) { + defer wg.Done() + if goroutineErr := client.Send(curMsg); err != nil { + t.Errorf("failed to send message with ID %d: %s", curID, goroutineErr) + } + }(message, id) + } + wg.Wait() + + if err = client.Close(); err != nil { + t.Errorf("failed to close server connection: %s", err) + } +} + // getTestConnection takes environment variables to establish a connection to a real // SMTP server to test all functionality that requires a connection func getTestConnection(auth bool) (*Client, error) { @@ -1913,6 +2021,72 @@ func getTestConnectionWithDSN(auth bool) (*Client, error) { } func TestXOAuth2OK(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + serverPort := TestServerPortBase + 30 + featureSet := "250-AUTH XOAUTH2\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, featureSet, false, serverPort); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 500) + + c, err := NewClient("127.0.0.1", + WithPort(serverPort), + WithTLSPortPolicy(TLSOpportunistic), + WithSMTPAuth(SMTPAuthXOAUTH2), + WithUsername("user"), + WithPassword("token")) + if err != nil { + t.Fatalf("unable to create new client: %v", err) + } + if err = c.DialWithContext(context.Background()); err != nil { + t.Fatalf("unexpected dial error: %v", err) + } + if err = c.Close(); err != nil { + t.Fatalf("disconnect from test server failed: %v", err) + } +} + +func TestXOAuth2Unsupported(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + serverPort := TestServerPortBase + 31 + featureSet := "250-AUTH LOGIN PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, featureSet, false, serverPort); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 500) + + c, err := NewClient("127.0.0.1", + WithPort(serverPort), + WithTLSPolicy(TLSOpportunistic), + WithSMTPAuth(SMTPAuthXOAUTH2), + WithUsername("user"), + WithPassword("token")) + if err != nil { + t.Fatalf("unable to create new client: %v", err) + } + if err = c.DialWithContext(context.Background()); err == nil { + t.Fatal("expected dial error got nil") + } else { + if !errors.Is(err, ErrXOauth2AuthNotSupported) { + t.Fatalf("expected %v; got %v", ErrXOauth2AuthNotSupported, err) + } + } + if err = c.Close(); err != nil { + t.Fatalf("disconnect from test server failed: %v", err) + } +} + +func TestXOAuth2OK_faker(t *testing.T) { server := []string{ "220 Fake server ready ESMTP", "250-fake.server", @@ -1952,7 +2126,7 @@ func TestXOAuth2OK(t *testing.T) { } } -func TestXOAuth2Unsupported(t *testing.T) { +func TestXOAuth2Unsupported_faker(t *testing.T) { server := []string{ "220 Fake server ready ESMTP", "250-fake.server", @@ -2085,7 +2259,6 @@ func handleTestServerConnection(connection net.Conn, featureSet string, failRese data, err := reader.ReadString('\n') if err != nil { - fmt.Printf("unable to read from connection: %s\n", err) return } if !strings.HasPrefix(data, "EHLO") && !strings.HasPrefix(data, "HELO") { @@ -2093,19 +2266,15 @@ func handleTestServerConnection(connection net.Conn, featureSet string, failRese return } if err = writeLine("250-localhost.localdomain\r\n" + featureSet); err != nil { - fmt.Printf("unable to write to connection: %s\n", err) return } for { data, err = reader.ReadString('\n') if err != nil { - if errors.Is(err, io.EOF) { - break - } - fmt.Println("Error reading data:", err) break } + time.Sleep(time.Millisecond) var datastring string data = strings.TrimSpace(data) @@ -2128,6 +2297,13 @@ func handleTestServerConnection(connection net.Conn, featureSet string, failRese break } writeOK() + case strings.HasPrefix(data, "AUTH XOAUTH2"): + auth := strings.TrimPrefix(data, "AUTH XOAUTH2 ") + if !strings.EqualFold(auth, "dXNlcj11c2VyAWF1dGg9QmVhcmVyIHRva2VuAQE=") { + _ = writeLine("535 5.7.8 Error: authentication failed") + break + } + _ = writeLine("235 2.7.0 Authentication successful") case strings.HasPrefix(data, "AUTH PLAIN"): auth := strings.TrimPrefix(data, "AUTH PLAIN ") if !strings.EqualFold(auth, "AHRvbmlAdGVzdGVyLmNvbQBWM3J5UzNjcjN0Kw==") { diff --git a/random_119.go b/random_119.go index b084305..4b45c55 100644 --- a/random_119.go +++ b/random_119.go @@ -2,8 +2,8 @@ // // SPDX-License-Identifier: MIT -//go:build go1.19 && !go1.20 -// +build go1.19,!go1.20 +//go:build !go1.20 +// +build !go1.20 package mail diff --git a/smtp/smtp.go b/smtp/smtp.go index d2a0e64..4ea1a3d 100644 --- a/smtp/smtp.go +++ b/smtp/smtp.go @@ -30,34 +30,57 @@ import ( "net/textproto" "os" "strings" + "sync" + "time" "github.com/wneessen/go-mail/log" ) // A Client represents a client connection to an SMTP server. type Client struct { - // Text is the textproto.Conn used by the Client. It is exported to allow for - // clients to add extensions. + // Text is the textproto.Conn used by the Client. It is exported to allow for clients to add extensions. Text *textproto.Conn - // keep a reference to the connection so it can be used to create a TLS - // connection later + + // auth supported auth mechanisms + auth []string + + // keep a reference to the connection so it can be used to create a TLS connection later conn net.Conn - // whether the Client is using TLS - tls bool - serverName string - // map of supported extensions + + // debug logging is enabled + debug bool + + // didHello indicates whether we've said HELO/EHLO + didHello bool + + // dsnmrtype defines the mail return option in case DSN is enabled + dsnmrtype string + + // dsnrntype defines the recipient notify option in case DSN is enabled + dsnrntype string + + // ext is a map of supported extensions ext map[string]string - // supported auth mechanisms - auth []string - localName string // the name to use in HELO/EHLO - didHello bool // whether we've said HELO/EHLO - helloError error // the error from the hello - // debug logging - debug bool // debug logging is enabled - logger log.Logger // logger will be used for debug logging - // DSN support - dsnmrtype string // dsnmrtype defines the mail return option in case DSN is enabled - dsnrntype string // dsnrntype defines the recipient notify option in case DSN is enabled + + // helloError is the error from the hello + helloError error + + // localName is the name to use in HELO/EHLO + localName string // the name to use in HELO/EHLO + + // logger will be used for debug logging + logger log.Logger + + // mutex is used to synchronize access to shared resources, ensuring that only one goroutine can access + // the resource at a time. + mutex sync.RWMutex + + // tls indicates whether the Client is using TLS + tls bool + + // serverName denotes the name of the server to which the application will connect. Used for + // identification and routing. + serverName string } // Dial returns a new [Client] connected to an SMTP server at addr. @@ -94,7 +117,10 @@ func NewClient(conn net.Conn, host string) (*Client, error) { // Close closes the connection. func (c *Client) Close() error { - return c.Text.Close() + c.mutex.Lock() + err := c.Text.Close() + c.mutex.Unlock() + return err } // hello runs a hello exchange if needed. @@ -121,28 +147,39 @@ func (c *Client) Hello(localName string) error { if c.didHello { return errors.New("smtp: Hello called after other methods") } + + c.mutex.Lock() c.localName = localName + c.mutex.Unlock() + return c.hello() } // cmd is a convenience function that sends a command and returns the response func (c *Client) cmd(expectCode int, format string, args ...interface{}) (int, string, error) { + c.mutex.Lock() + c.debugLog(log.DirClientToServer, format, args...) id, err := c.Text.Cmd(format, args...) if err != nil { + c.mutex.Unlock() return 0, "", err } c.Text.StartResponse(id) - defer c.Text.EndResponse(id) code, msg, err := c.Text.ReadResponse(expectCode) c.debugLog(log.DirServerToClient, "%d %s", code, msg) + c.Text.EndResponse(id) + c.mutex.Unlock() return code, msg, err } // helo sends the HELO greeting to the server. It should be used only when the // server does not support ehlo. func (c *Client) helo() error { + c.mutex.Lock() c.ext = nil + c.mutex.Unlock() + _, _, err := c.cmd(250, "HELO %s", c.localName) return err } @@ -157,9 +194,13 @@ func (c *Client) StartTLS(config *tls.Config) error { if err != nil { return err } + + c.mutex.Lock() c.conn = tls.Client(c.conn, config) c.Text = textproto.NewConn(c.conn) c.tls = true + c.mutex.Unlock() + return c.ehlo() } @@ -167,11 +208,15 @@ func (c *Client) StartTLS(config *tls.Config) error { // The return values are their zero values if [Client.StartTLS] did // not succeed. func (c *Client) TLSConnectionState() (state tls.ConnectionState, ok bool) { + c.mutex.RLock() + defer c.mutex.RUnlock() + tc, ok := c.conn.(*tls.Conn) if !ok { return } - return tc.ConnectionState(), true + state, ok = tc.ConnectionState(), true + return } // Verify checks the validity of an email address on the server. @@ -257,6 +302,8 @@ func (c *Client) Mail(from string) error { return err } cmdStr := "MAIL FROM:<%s>" + + c.mutex.RLock() if c.ext != nil { if _, ok := c.ext["8BITMIME"]; ok { cmdStr += " BODY=8BITMIME" @@ -269,6 +316,8 @@ func (c *Client) Mail(from string) error { cmdStr += fmt.Sprintf(" RET=%s", c.dsnmrtype) } } + c.mutex.RUnlock() + _, _, err := c.cmd(250, cmdStr, from) return err } @@ -280,7 +329,11 @@ func (c *Client) Rcpt(to string) error { if err := validateLine(to); err != nil { return err } + + c.mutex.RLock() _, ok := c.ext["DSN"] + c.mutex.RUnlock() + if ok && c.dsnrntype != "" { _, _, err := c.cmd(25, "RCPT TO:<%s> NOTIFY=%s", to, c.dsnrntype) return err @@ -294,12 +347,23 @@ type dataCloser struct { io.WriteCloser } +// Close releases the lock, closes the WriteCloser, waits for a response, and then returns any error encountered. func (d *dataCloser) Close() error { + d.c.mutex.Lock() _ = d.WriteCloser.Close() _, _, err := d.c.Text.ReadResponse(250) + d.c.mutex.Unlock() return err } +// Write writes data to the underlying WriteCloser while ensuring thread-safety by locking and unlocking a mutex. +func (d *dataCloser) Write(p []byte) (n int, err error) { + d.c.mutex.Lock() + n, err = d.WriteCloser.Write(p) + d.c.mutex.Unlock() + return +} + // Data issues a DATA command to the server and returns a writer that // can be used to write the mail headers and body. The caller should // close the writer before calling any more methods on c. A call to @@ -309,7 +373,14 @@ func (c *Client) Data() (io.WriteCloser, error) { if err != nil { return nil, err } - return &dataCloser{c, c.Text.DotWriter()}, nil + datacloser := &dataCloser{} + + c.mutex.Lock() + datacloser.c = c + datacloser.WriteCloser = c.Text.DotWriter() + c.mutex.Unlock() + + return datacloser, nil } var testHookStartTLS func(*tls.Config) // nil, except for tests @@ -405,7 +476,10 @@ func (c *Client) Extension(ext string) (bool, string) { return false, "" } ext = strings.ToUpper(ext) + + c.mutex.RLock() param, ok := c.ext[ext] + c.mutex.RUnlock() return ok, param } @@ -438,7 +512,11 @@ func (c *Client) Quit() error { if err != nil { return err } - return c.Text.Close() + c.mutex.Lock() + err = c.Text.Close() + c.mutex.Unlock() + + return err } // SetDebugLog enables the debug logging for incoming and outgoing SMTP messages @@ -472,6 +550,21 @@ func (c *Client) SetDSNRcptNotifyOption(d string) { c.dsnrntype = d } +// HasConnection checks if the client has an active connection. +// Returns true if the `conn` field is not nil, indicating an active connection. +func (c *Client) HasConnection() bool { + return c.conn != nil +} + +func (c *Client) UpdateDeadline(timeout time.Duration) error { + c.mutex.Lock() + if err := c.conn.SetDeadline(time.Now().Add(timeout)); err != nil { + return fmt.Errorf("smtp: failed to update deadline: %w", err) + } + c.mutex.Unlock() + return nil +} + // debugLog checks if the debug flag is set and if so logs the provided message to // the log.Logger interface func (c *Client) debugLog(d log.Direction, f string, a ...interface{}) { diff --git a/smtp/smtp_ehlo.go b/smtp/smtp_ehlo.go index ae80a62..457be57 100644 --- a/smtp/smtp_ehlo.go +++ b/smtp/smtp_ehlo.go @@ -25,6 +25,9 @@ func (c *Client) ehlo() error { if err != nil { return err } + + c.mutex.Lock() + defer c.mutex.Unlock() ext := make(map[string]string) extList := strings.Split(msg, "\n") if len(extList) > 1 { diff --git a/smtp/smtp_ehlo_117.go b/smtp/smtp_ehlo_117.go index 429f30a..c40297f 100644 --- a/smtp/smtp_ehlo_117.go +++ b/smtp/smtp_ehlo_117.go @@ -22,12 +22,15 @@ import "strings" // should be the preferred greeting for servers that support it. // // Backport of: https://github.com/golang/go/commit/4d8db00641cc9ff4f44de7df9b8c4f4a4f9416ee#diff-4f6f6bdb9891d4dd271f9f31430420a2e44018fe4ee539576faf458bebb3cee4 -// to guarantee backwards compatibility with Go 1.16/1.17:w +// to guarantee backwards compatibility with Go 1.16/1.17 func (c *Client) ehlo() error { _, msg, err := c.cmd(250, "EHLO %s", c.localName) if err != nil { return err } + + c.mutex.Lock() + defer c.mutex.Unlock() ext := make(map[string]string) extList := strings.Split(msg, "\n") if len(extList) > 1 {