Compare commits

..

33 commits

Author SHA1 Message Date
Michael Fuchs
0db564ba6c
Merge 4b60557518 into 65a91a2711 2024-09-27 20:34:09 +02:00
65a91a2711
Merge pull request #307 from wneessen/feature/269_goroutineconcurrency-safety
go-mail goroutine-/thread-safety
2024-09-27 17:17:00 +02:00
c1f6ef07d4
Skip test cases when client creation fails
Updated the client creation check to skip test cases if the client cannot be created, instead of marking them as errors. This ensures tests dependent on a successful client creation do not fail unnecessarily but are instead skipped.
2024-09-27 17:09:00 +02:00
6e98d7e47d
Reduce message loop iterations and add XOAUTH2 tests
Loop iterations in `client_test.go` were reduced from 50 to 20 for efficiency. Added new tests to verify XOAUTH2 authentication support and error handling by simulating SMTP server responses.
2024-09-27 17:00:21 +02:00
8791ce5a33
Fix deferred mutex unlock in TLSConnectionState
Correct the sequence of mutex unlocking in TLSConnectionState to ensure the mutex is always released properly. This prevents potential deadlocks and ensures the function behaves as expected in a concurrent context.
2024-09-27 17:00:07 +02:00
2d98c40cb6
Add concurrent send tests for Client
Introduced TestClient_DialSendConcurrent_online and TestClient_DialSendConcurrent_local to validate concurrent sending of messages. These tests ensure that the Client's send functionality works correctly under concurrent conditions, both in an online environment and using a local test server.
2024-09-27 14:04:02 +02:00
253d065c83
Move mutex lock to sendSingleMsg method
Mutex locking was relocated from the Send method in client_120.go and client_119.go to sendSingleMsg in client.go. This ensures thread-safety specifically during the message transmission process.
2024-09-27 14:03:50 +02:00
6bd9a9c735
Refactor mutex usage for connection safety
This commit revises locking mechanism usage around connection operations to avoid potential deadlocks and improve code clarity. Specifically, defer statements were removed and explicit unlocks were added to ensure that mutexes are properly released after critical sections. This change affects several methods, including `Close`, `cmd`, `TLSConnectionState`, `UpdateDeadline`, and newly introduced locking for concurrent data writes and reads in `dataCloser`.
2024-09-27 14:03:26 +02:00
f59aa23ed8
Add mutex locking to SetTLSConfig
This change ensures that the SetTLSConfig method is thread-safe by adding a mutex lock. The lock is acquired before any changes to the TLS configuration and released afterward to prevent concurrent access issues.
2024-09-27 11:58:08 +02:00
2234f0c5bc
Remove connection field from Client struct
This commit removes the 'connection' field from the 'Client' struct and updates the related test logic accordingly. By using 'smtpClient.HasConnection()' to check for connections, code readability and maintainability are improved. All necessary test cases have been adjusted to reflect this change.
2024-09-27 11:43:22 +02:00
fdb80ad9dd
Add mutex to Client for thread-safe operations
This commit introduces a RWMutex to the Client struct in the smtp package to ensure thread-safe access to shared resources. Critical sections in methods like Close, StartTLS, and cmd are now protected with appropriate locking mechanisms. This change helps prevent potential race conditions, ensuring consistent and reliable behavior in concurrent environments.
2024-09-27 11:10:23 +02:00
fec2f2075a
Update build tags to support future Go versions
Modified the build tags to exclude Go 1.20 and above instead of targeting only Go 1.19. This change ensures the code is compatible with future versions of Go by not restricting it to a specific minor version.
2024-09-27 10:52:30 +02:00
2084526c77
Refactor Client struct to improve organization and clarity
Rearranged and grouped struct fields more logically within Client. Introduced the dialContextFunc and fallbackPort fields to enhance connection flexibility. Minor code style adjustments were also made for better readability.
2024-09-27 10:36:09 +02:00
23c71d608f
Lock mutex before checking connection in Send method
Added mutex locking in the `Send` method for both `client_120.go` and `client_119.go`. This ensures thread-safe access to the connection checks and prevents potential race conditions.
2024-09-27 10:33:28 +02:00
371b950bc7
Refactor Client struct for better readability and organization
Reordered and grouped fields in the Client struct for clarity. The reorganization separates logical groups of fields, making it easier to understand and maintain the code. This includes proper grouping of TLS parameters, DSN options, and debug settings.
2024-09-27 10:33:19 +02:00
b2c4b533d7
Merge branch 'main' into feature/269_goroutineconcurrency-safety 2024-09-27 10:26:39 +02:00
3871b2be44
Lock client connections and update deadline handling
Add mutex locking for client connections to ensure thread safety. Introduce `HasConnection` method to check active connections and `UpdateDeadline` method to handle timeout updates. Refactor connection handling in `checkConn` and `tls` methods accordingly.
2024-09-26 11:51:30 +02:00
8683917c3d
Delete connection pool implementation and tests
Remove `connpool.go` and `connpool_test.go`. This eliminates the connection pool feature from the codebase, including associated functionality and tests. The connection pool feature is much to complex and doesn't provide the benefits expected by the concurrency feature
2024-09-26 11:49:48 +02:00
fd115d5173
Remove typo from comment in smtp_ehlo_117.go
Fixed a typo in the backward compatibility comment for Go 1.16/1.17 in smtp_ehlo_117.go. This ensures clarity and correctness in documentation.
2024-09-23 14:15:43 +02:00
4f6224131e
Rename test for accurate context cancellation
Updated the test name from `TestConnPool_GetContextTimeout` to `TestConnPool_GetContextCancel` to better reflect its functionality. This change improves test readability and maintains consistency with the context usage in the test.
2024-09-23 13:46:41 +02:00
c8684886ed
Refactor Get method to include context argument
Updated the Get method in connpool.go and its usage in tests to include a context argument for better cancellation and timeout handling. Removed the redundant dialContext field from the connection pool struct and added a new test to validate context timeout behavior.
2024-09-23 13:44:03 +02:00
07e7b17ae8
Add tests for ConnPool close and concurrency issues
This commit introduces two new tests: `TestConnPool_Close` and `TestConnPool_Concurrency`. The former ensures the proper closing of connection pool resources, while the latter checks for concurrency issues by creating and closing multiple connections in parallel.
2024-09-23 11:45:22 +02:00
2abdee743d
Add unit test for marking connection as unusable
Introduces `TestPoolConn_MarkUnusable` to ensure the pool maintains its integrity when a connection is marked unusable. This test validates that the connection pool size adjusts correctly after marking a connection as unusable and closing it.
2024-09-23 11:27:05 +02:00
5503be8451
Remove redundant error print statements
Removed redundant fmt.Printf error print statements for connection read and write errors. This cleans up the test output and makes error handling more streamlined.
2024-09-23 11:26:54 +02:00
774925078a
Improve error handling in connection pool tests
Add error checks to Close() calls in connpool_test.go to ensure connection closures are handled properly, with descriptive error messages. Update comment in connpool.go to improve clarity on the source of code inspiration.
2024-09-23 11:17:58 +02:00
f1188bdad7
Add test for PoolConn Close method
This commit introduces a new test, `TestPoolConn_Close`, to verify that connections are correctly closed and returned to the pool. It sets up a simple SMTP server, creates a connection pool, tests writing to and closing connections, and checks the pool size to ensure proper behavior.
2024-09-23 11:15:56 +02:00
2cbd0c4aef
Add test cases for connection pool functionality
Added new test cases `TestConnPool_Get_Type` and `TestConnPool_Get` to verify connection pool operations. These tests ensure proper connection type and handling of pool size after connection retrieval and usage.
2024-09-23 11:09:11 +02:00
9a9e0c936d
Remove redundant error handling in test code
The check for io.EOF and the associated print statement were unnecessary because the loop breaks on any error. This change simplifies the error handling logic in the `client_test.go` file and avoids redundant code.
2024-09-23 11:09:03 +02:00
33d4eb5b21
Add unit tests for connection pool and rename Len to Size
Introduced unit tests for the connection pool to ensure robust functionality. Also, renamed the Len method to Size in the Pool interface and its implementation for better clarity and consistency.
2024-09-23 10:33:06 +02:00
d6e5034bba
Add new error handling and connection management in connpool
Introduce ErrClosed and ErrNilConn errors for better error handling. Implement Close and MarkUnusable methods for improved connection lifecycle management. Add put method to return connections to the pool or close them if necessary.
2024-09-23 10:09:38 +02:00
1394f1fc20
Add context management and error handling to connection pool
Introduced context support and enhanced error handling in connpool.go. Added detailed comments for better maintainability and introduced a wrapper for net.Conn to manage connection close behavior. The changes improve the robustness and clarity of the connection pool's operation.
2024-09-23 09:56:23 +02:00
26ff177fb0
Add connPool implementation and connection pool errors
Introduce connPool struct and implement the Pool interface. Add error handling for invalid pool capacity settings and provide a constructor for creating new connection pools with specified capacities.
2024-09-22 21:12:59 +02:00
5ec8a5a5fe
Add initial connection pool interface
Introduces a new `connpool.go` file implementing a connection pool interface for managing network connections. This interface includes methods to get and close connections, as well as to retrieve the current pool size. The implementation is initially based on a fork of code from the Fatih Arslan GitHub repository.
2024-09-22 20:48:09 +02:00
6 changed files with 389 additions and 101 deletions

View file

@ -12,6 +12,7 @@ import (
"net" "net"
"os" "os"
"strings" "strings"
"sync"
"time" "time"
"github.com/wneessen/go-mail/log" "github.com/wneessen/go-mail/log"
@ -87,12 +88,12 @@ type DialContextFunc func(ctx context.Context, network, address string) (net.Con
// Client is the SMTP client struct // Client is the SMTP client struct
type Client struct { type Client struct {
// connection is the net.Conn that the smtp.Client is based on
connection net.Conn
// Timeout for the SMTP server connection // Timeout for the SMTP server connection
connTimeout time.Duration connTimeout time.Duration
// dialContextFunc is a custom DialContext function to dial target SMTP server
dialContextFunc DialContextFunc
// dsn indicates that we want to use DSN for the Client // dsn indicates that we want to use DSN for the Client
dsn bool dsn bool
@ -102,11 +103,9 @@ type Client struct {
// dsnrntype defines the DSNRcptNotifyOption in case DSN is enabled // dsnrntype defines the DSNRcptNotifyOption in case DSN is enabled
dsnrntype []string dsnrntype []string
// isEncrypted indicates if a Client connection is encrypted or not // fallbackPort is used as an alternative port number in case the primary port is unavailable or
isEncrypted bool // fails to bind.
fallbackPort int
// noNoop indicates the Noop is to be skipped
noNoop bool
// HELO/EHLO string for the greeting the target SMTP server // HELO/EHLO string for the greeting the target SMTP server
helo string helo string
@ -114,12 +113,24 @@ type Client struct {
// Hostname of the target SMTP server to connect to // Hostname of the target SMTP server to connect to
host string host string
// isEncrypted indicates if a Client connection is encrypted or not
isEncrypted bool
// logger is a logger that implements the log.Logger interface
logger log.Logger
// mutex is used to synchronize access to shared resources, ensuring that only one goroutine can
// modify them at a time.
mutex sync.RWMutex
// noNoop indicates the Noop is to be skipped
noNoop bool
// pass is the corresponding SMTP AUTH password // pass is the corresponding SMTP AUTH password
pass string pass string
// Port of the SMTP server to connect to // port specifies the network port number on which the server listens for incoming connections.
port int port int
fallbackPort int
// smtpAuth is a pointer to smtp.Auth // smtpAuth is a pointer to smtp.Auth
smtpAuth smtp.Auth smtpAuth smtp.Auth
@ -130,26 +141,20 @@ type Client struct {
// smtpClient is the smtp.Client that is set up when using the Dial*() methods // smtpClient is the smtp.Client that is set up when using the Dial*() methods
smtpClient *smtp.Client smtpClient *smtp.Client
// Use SSL for the connection
useSSL bool
// tlspolicy sets the client to use the provided TLSPolicy for the STARTTLS protocol // tlspolicy sets the client to use the provided TLSPolicy for the STARTTLS protocol
tlspolicy TLSPolicy tlspolicy TLSPolicy
// tlsconfig represents the tls.Config setting for the STARTTLS connection // tlsconfig represents the tls.Config setting for the STARTTLS connection
tlsconfig *tls.Config tlsconfig *tls.Config
// user is the SMTP AUTH username
user string
// useDebugLog enables the debug logging on the SMTP client // useDebugLog enables the debug logging on the SMTP client
useDebugLog bool useDebugLog bool
// logger is a logger that implements the log.Logger interface // user is the SMTP AUTH username
logger log.Logger user string
// dialContextFunc is a custom DialContext function to dial target SMTP server // Use SSL for the connection
dialContextFunc DialContextFunc useSSL bool
} }
// Option returns a function that can be used for grouping Client options // Option returns a function that can be used for grouping Client options
@ -550,6 +555,9 @@ func (c *Client) SetLogger(logger log.Logger) {
// SetTLSConfig overrides the current *tls.Config with the given *tls.Config value // SetTLSConfig overrides the current *tls.Config with the given *tls.Config value
func (c *Client) SetTLSConfig(tlsconfig *tls.Config) error { func (c *Client) SetTLSConfig(tlsconfig *tls.Config) error {
c.mutex.Lock()
defer c.mutex.Unlock()
if tlsconfig == nil { if tlsconfig == nil {
return ErrInvalidTLSConfig return ErrInvalidTLSConfig
} }
@ -589,6 +597,9 @@ func (c *Client) setDefaultHelo() error {
// DialWithContext establishes a connection to the SMTP server with a given context.Context // DialWithContext establishes a connection to the SMTP server with a given context.Context
func (c *Client) DialWithContext(dialCtx context.Context) error { func (c *Client) DialWithContext(dialCtx context.Context) error {
c.mutex.Lock()
defer c.mutex.Unlock()
ctx, cancel := context.WithDeadline(dialCtx, time.Now().Add(c.connTimeout)) ctx, cancel := context.WithDeadline(dialCtx, time.Now().Add(c.connTimeout))
defer cancel() defer cancel()
@ -602,17 +613,16 @@ func (c *Client) DialWithContext(dialCtx context.Context) error {
c.dialContextFunc = tlsDialer.DialContext c.dialContextFunc = tlsDialer.DialContext
} }
} }
var err error connection, err := c.dialContextFunc(ctx, "tcp", c.ServerAddr())
c.connection, err = c.dialContextFunc(ctx, "tcp", c.ServerAddr())
if err != nil && c.fallbackPort != 0 { if err != nil && c.fallbackPort != 0 {
// TODO: should we somehow log or append the previous error? // TODO: should we somehow log or append the previous error?
c.connection, err = c.dialContextFunc(ctx, "tcp", c.serverFallbackAddr()) connection, err = c.dialContextFunc(ctx, "tcp", c.serverFallbackAddr())
} }
if err != nil { if err != nil {
return err return err
} }
client, err := smtp.NewClient(c.connection, c.host) client, err := smtp.NewClient(connection, c.host)
if err != nil { if err != nil {
return err return err
} }
@ -691,7 +701,7 @@ func (c *Client) DialAndSendWithContext(ctx context.Context, messages ...*Msg) e
// checkConn makes sure that a required server connection is available and extends the // checkConn makes sure that a required server connection is available and extends the
// connection deadline // connection deadline
func (c *Client) checkConn() error { func (c *Client) checkConn() error {
if c.connection == nil { if !c.smtpClient.HasConnection() {
return ErrNoActiveConnection return ErrNoActiveConnection
} }
@ -701,7 +711,7 @@ func (c *Client) checkConn() error {
} }
} }
if err := c.connection.SetDeadline(time.Now().Add(c.connTimeout)); err != nil { if err := c.smtpClient.UpdateDeadline(c.connTimeout); err != nil {
return ErrDeadlineExtendFailed return ErrDeadlineExtendFailed
} }
return nil return nil
@ -715,7 +725,7 @@ func (c *Client) serverFallbackAddr() string {
// tls tries to make sure that the STARTTLS requirements are satisfied // tls tries to make sure that the STARTTLS requirements are satisfied
func (c *Client) tls() error { func (c *Client) tls() error {
if c.connection == nil { if !c.smtpClient.HasConnection() {
return ErrNoActiveConnection return ErrNoActiveConnection
} }
if !c.useSSL && c.tlspolicy != NoTLS { if !c.useSSL && c.tlspolicy != NoTLS {
@ -791,6 +801,9 @@ func (c *Client) auth() error {
// sendSingleMsg sends out a single message and returns an error if the transmission/delivery fails. // sendSingleMsg sends out a single message and returns an error if the transmission/delivery fails.
// It is invoked by the public Send methods // It is invoked by the public Send methods
func (c *Client) sendSingleMsg(message *Msg) error { func (c *Client) sendSingleMsg(message *Msg) error {
c.mutex.Lock()
defer c.mutex.Unlock()
if message.encoding == NoEncoding { if message.encoding == NoEncoding {
if ok, _ := c.smtpClient.Extension("8BITMIME"); !ok { if ok, _ := c.smtpClient.Extension("8BITMIME"); !ok {
return &SendError{Reason: ErrNoUnencoded, isTemp: false, affectedMsg: message} return &SendError{Reason: ErrNoUnencoded, isTemp: false, affectedMsg: message}

View file

@ -15,6 +15,7 @@ import (
"os" "os"
"strconv" "strconv"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
@ -623,11 +624,12 @@ func TestClient_DialWithContext(t *testing.T) {
t.Errorf("failed to dial with context: %s", err) t.Errorf("failed to dial with context: %s", err)
return return
} }
if c.connection == nil {
t.Errorf("DialWithContext didn't fail but no connection found.")
}
if c.smtpClient == nil { if c.smtpClient == nil {
t.Errorf("DialWithContext didn't fail but no SMTP client found.") t.Errorf("DialWithContext didn't fail but no SMTP client found.")
return
}
if !c.smtpClient.HasConnection() {
t.Errorf("DialWithContext didn't fail but no connection found.")
} }
if err := c.Close(); err != nil { if err := c.Close(); err != nil {
t.Errorf("failed to close connection: %s", err) t.Errorf("failed to close connection: %s", err)
@ -644,17 +646,18 @@ func TestClient_DialWithContext_Fallback(t *testing.T) {
c.SetTLSPortPolicy(TLSOpportunistic) c.SetTLSPortPolicy(TLSOpportunistic)
c.port = 999 c.port = 999
ctx := context.Background() ctx := context.Background()
if err := c.DialWithContext(ctx); err != nil { if err = c.DialWithContext(ctx); err != nil {
t.Errorf("failed to dial with context: %s", err) t.Errorf("failed to dial with context: %s", err)
return return
} }
if c.connection == nil {
t.Errorf("DialWithContext didn't fail but no connection found.")
}
if c.smtpClient == nil { if c.smtpClient == nil {
t.Errorf("DialWithContext didn't fail but no SMTP client found.") t.Errorf("DialWithContext didn't fail but no SMTP client found.")
return
} }
if err := c.Close(); err != nil { if !c.smtpClient.HasConnection() {
t.Errorf("DialWithContext didn't fail but no connection found.")
}
if err = c.Close(); err != nil {
t.Errorf("failed to close connection: %s", err) t.Errorf("failed to close connection: %s", err)
} }
@ -674,18 +677,19 @@ func TestClient_DialWithContext_Debug(t *testing.T) {
t.Skipf("failed to create test client: %s. Skipping tests", err) t.Skipf("failed to create test client: %s. Skipping tests", err)
} }
ctx := context.Background() ctx := context.Background()
if err := c.DialWithContext(ctx); err != nil { if err = c.DialWithContext(ctx); err != nil {
t.Errorf("failed to dial with context: %s", err) t.Errorf("failed to dial with context: %s", err)
return return
} }
if c.connection == nil {
t.Errorf("DialWithContext didn't fail but no connection found.")
}
if c.smtpClient == nil { if c.smtpClient == nil {
t.Errorf("DialWithContext didn't fail but no SMTP client found.") t.Errorf("DialWithContext didn't fail but no SMTP client found.")
return
}
if !c.smtpClient.HasConnection() {
t.Errorf("DialWithContext didn't fail but no connection found.")
} }
c.SetDebugLog(true) c.SetDebugLog(true)
if err := c.Close(); err != nil { if err = c.Close(); err != nil {
t.Errorf("failed to close connection: %s", err) t.Errorf("failed to close connection: %s", err)
} }
} }
@ -698,19 +702,20 @@ func TestClient_DialWithContext_Debug_custom(t *testing.T) {
t.Skipf("failed to create test client: %s. Skipping tests", err) t.Skipf("failed to create test client: %s. Skipping tests", err)
} }
ctx := context.Background() ctx := context.Background()
if err := c.DialWithContext(ctx); err != nil { if err = c.DialWithContext(ctx); err != nil {
t.Errorf("failed to dial with context: %s", err) t.Errorf("failed to dial with context: %s", err)
return return
} }
if c.connection == nil {
t.Errorf("DialWithContext didn't fail but no connection found.")
}
if c.smtpClient == nil { if c.smtpClient == nil {
t.Errorf("DialWithContext didn't fail but no SMTP client found.") t.Errorf("DialWithContext didn't fail but no SMTP client found.")
return
}
if !c.smtpClient.HasConnection() {
t.Errorf("DialWithContext didn't fail but no connection found.")
} }
c.SetDebugLog(true) c.SetDebugLog(true)
c.SetLogger(log.New(os.Stderr, log.LevelDebug)) c.SetLogger(log.New(os.Stderr, log.LevelDebug))
if err := c.Close(); err != nil { if err = c.Close(); err != nil {
t.Errorf("failed to close connection: %s", err) t.Errorf("failed to close connection: %s", err)
} }
} }
@ -722,10 +727,9 @@ func TestClient_DialWithContextInvalidHost(t *testing.T) {
if err != nil { if err != nil {
t.Skipf("failed to create test client: %s. Skipping tests", err) t.Skipf("failed to create test client: %s. Skipping tests", err)
} }
c.connection = nil
c.host = "invalid.addr" c.host = "invalid.addr"
ctx := context.Background() ctx := context.Background()
if err := c.DialWithContext(ctx); err == nil { if err = c.DialWithContext(ctx); err == nil {
t.Errorf("dial succeeded but was supposed to fail") t.Errorf("dial succeeded but was supposed to fail")
return return
} }
@ -738,10 +742,9 @@ func TestClient_DialWithContextInvalidHELO(t *testing.T) {
if err != nil { if err != nil {
t.Skipf("failed to create test client: %s. Skipping tests", err) t.Skipf("failed to create test client: %s. Skipping tests", err)
} }
c.connection = nil
c.helo = "" c.helo = ""
ctx := context.Background() ctx := context.Background()
if err := c.DialWithContext(ctx); err == nil { if err = c.DialWithContext(ctx); err == nil {
t.Errorf("dial succeeded but was supposed to fail") t.Errorf("dial succeeded but was supposed to fail")
return return
} }
@ -758,7 +761,7 @@ func TestClient_DialWithContextInvalidAuth(t *testing.T) {
c.pass = "invalid" c.pass = "invalid"
c.SetSMTPAuthCustom(smtp.LoginAuth("invalid", "invalid", "invalid")) c.SetSMTPAuthCustom(smtp.LoginAuth("invalid", "invalid", "invalid"))
ctx := context.Background() ctx := context.Background()
if err := c.DialWithContext(ctx); err == nil { if err = c.DialWithContext(ctx); err == nil {
t.Errorf("dial succeeded but was supposed to fail") t.Errorf("dial succeeded but was supposed to fail")
return return
} }
@ -770,8 +773,7 @@ func TestClient_checkConn(t *testing.T) {
if err != nil { if err != nil {
t.Skipf("failed to create test client: %s. Skipping tests", err) t.Skipf("failed to create test client: %s. Skipping tests", err)
} }
c.connection = nil if err = c.checkConn(); err == nil {
if err := c.checkConn(); err == nil {
t.Errorf("connCheck() should fail but succeeded") t.Errorf("connCheck() should fail but succeeded")
} }
} }
@ -802,21 +804,23 @@ func TestClient_DialWithContextOptions(t *testing.T) {
} }
ctx := context.Background() ctx := context.Background()
if err := c.DialWithContext(ctx); err != nil && !tt.sf { if err = c.DialWithContext(ctx); err != nil && !tt.sf {
t.Errorf("failed to dial with context: %s", err) t.Errorf("failed to dial with context: %s", err)
return return
} }
if !tt.sf { if !tt.sf {
if c.connection == nil && !tt.sf {
t.Errorf("DialWithContext didn't fail but no connection found.")
}
if c.smtpClient == nil && !tt.sf { if c.smtpClient == nil && !tt.sf {
t.Errorf("DialWithContext didn't fail but no SMTP client found.") t.Errorf("DialWithContext didn't fail but no SMTP client found.")
return
} }
if err := c.Reset(); err != nil { if !c.smtpClient.HasConnection() && !tt.sf {
t.Errorf("DialWithContext didn't fail but no connection found.")
return
}
if err = c.Reset(); err != nil {
t.Errorf("failed to reset connection: %s", err) t.Errorf("failed to reset connection: %s", err)
} }
if err := c.Close(); err != nil { if err = c.Close(); err != nil {
t.Errorf("failed to close connection: %s", err) t.Errorf("failed to close connection: %s", err)
} }
} }
@ -1011,17 +1015,15 @@ func TestClient_DialSendCloseBroken(t *testing.T) {
} }
if tt.closestart { if tt.closestart {
_ = c.smtpClient.Close() _ = c.smtpClient.Close()
_ = c.connection.Close()
} }
if err := c.Send(m); err != nil && !tt.sf { if err = c.Send(m); err != nil && !tt.sf {
t.Errorf("Send() failed: %s", err) t.Errorf("Send() failed: %s", err)
return return
} }
if tt.closeearly { if tt.closeearly {
_ = c.smtpClient.Close() _ = c.smtpClient.Close()
_ = c.connection.Close()
} }
if err := c.Close(); err != nil && !tt.sf { if err = c.Close(); err != nil && !tt.sf {
t.Errorf("Close() failed: %s", err) t.Errorf("Close() failed: %s", err)
return return
} }
@ -1071,17 +1073,15 @@ func TestClient_DialSendCloseBrokenWithDSN(t *testing.T) {
} }
if tt.closestart { if tt.closestart {
_ = c.smtpClient.Close() _ = c.smtpClient.Close()
_ = c.connection.Close()
} }
if err := c.Send(m); err != nil && !tt.sf { if err = c.Send(m); err != nil && !tt.sf {
t.Errorf("Send() failed: %s", err) t.Errorf("Send() failed: %s", err)
return return
} }
if tt.closeearly { if tt.closeearly {
_ = c.smtpClient.Close() _ = c.smtpClient.Close()
_ = c.connection.Close()
} }
if err := c.Close(); err != nil && !tt.sf { if err = c.Close(); err != nil && !tt.sf {
t.Errorf("Close() failed: %s", err) t.Errorf("Close() failed: %s", err)
return return
} }
@ -1728,6 +1728,114 @@ func TestClient_SendErrorReset(t *testing.T) {
} }
} }
func TestClient_DialSendConcurrent_online(t *testing.T) {
if os.Getenv("TEST_ALLOW_SEND") == "" {
t.Skipf("TEST_ALLOW_SEND is not set. Skipping mail sending test")
}
client, err := getTestConnection(true)
if err != nil {
t.Skipf("failed to create test client: %s. Skipping tests", err)
}
var messages []*Msg
for i := 0; i < 10; i++ {
message := NewMsg()
if err := message.FromFormat("go-mail Test Mailer", os.Getenv("TEST_FROM")); err != nil {
t.Errorf("failed to set FROM address: %s", err)
return
}
if err := message.To(TestRcpt); err != nil {
t.Errorf("failed to set TO address: %s", err)
return
}
message.Subject(fmt.Sprintf("Test subject for mail %d", i))
message.SetBodyString(TypeTextPlain, fmt.Sprintf("This is the test body of the mail no. %d", i))
message.SetMessageID()
messages = append(messages, message)
}
if err = client.DialWithContext(context.Background()); err != nil {
t.Errorf("failed to dial to test server: %s", err)
}
wg := sync.WaitGroup{}
for id, message := range messages {
wg.Add(1)
go func(curMsg *Msg, curID int) {
defer wg.Done()
if goroutineErr := client.Send(curMsg); err != nil {
t.Errorf("failed to send message with ID %d: %s", curID, goroutineErr)
}
}(message, id)
}
wg.Wait()
if err = client.Close(); err != nil {
t.Errorf("failed to close server connection: %s", err)
}
}
func TestClient_DialSendConcurrent_local(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
serverPort := TestServerPortBase + 20
featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
go func() {
if err := simpleSMTPServer(ctx, featureSet, false, serverPort); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 500)
client, err := NewClient(TestServerAddr, WithPort(serverPort),
WithTLSPortPolicy(NoTLS), WithSMTPAuth(SMTPAuthPlain),
WithUsername("toni@tester.com"),
WithPassword("V3ryS3cr3t+"))
if err != nil {
t.Errorf("unable to create new client: %s", err)
}
var messages []*Msg
for i := 0; i < 20; i++ {
message := NewMsg()
if err := message.From("valid-from@domain.tld"); err != nil {
t.Errorf("failed to set FROM address: %s", err)
return
}
if err := message.To("valid-to@domain.tld"); err != nil {
t.Errorf("failed to set TO address: %s", err)
return
}
message.Subject("Test subject")
message.SetBodyString(TypeTextPlain, "Test body")
message.SetMessageIDWithValue("this.is.a.message.id")
messages = append(messages, message)
}
if err = client.DialWithContext(context.Background()); err != nil {
t.Errorf("failed to dial to test server: %s", err)
}
wg := sync.WaitGroup{}
for id, message := range messages {
wg.Add(1)
go func(curMsg *Msg, curID int) {
defer wg.Done()
if goroutineErr := client.Send(curMsg); err != nil {
t.Errorf("failed to send message with ID %d: %s", curID, goroutineErr)
}
}(message, id)
}
wg.Wait()
if err = client.Close(); err != nil {
t.Errorf("failed to close server connection: %s", err)
}
}
// getTestConnection takes environment variables to establish a connection to a real // getTestConnection takes environment variables to establish a connection to a real
// SMTP server to test all functionality that requires a connection // SMTP server to test all functionality that requires a connection
func getTestConnection(auth bool) (*Client, error) { func getTestConnection(auth bool) (*Client, error) {
@ -1913,6 +2021,72 @@ func getTestConnectionWithDSN(auth bool) (*Client, error) {
} }
func TestXOAuth2OK(t *testing.T) { func TestXOAuth2OK(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
serverPort := TestServerPortBase + 30
featureSet := "250-AUTH XOAUTH2\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
go func() {
if err := simpleSMTPServer(ctx, featureSet, false, serverPort); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 500)
c, err := NewClient("127.0.0.1",
WithPort(serverPort),
WithTLSPortPolicy(TLSOpportunistic),
WithSMTPAuth(SMTPAuthXOAUTH2),
WithUsername("user"),
WithPassword("token"))
if err != nil {
t.Fatalf("unable to create new client: %v", err)
}
if err = c.DialWithContext(context.Background()); err != nil {
t.Fatalf("unexpected dial error: %v", err)
}
if err = c.Close(); err != nil {
t.Fatalf("disconnect from test server failed: %v", err)
}
}
func TestXOAuth2Unsupported(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
serverPort := TestServerPortBase + 31
featureSet := "250-AUTH LOGIN PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
go func() {
if err := simpleSMTPServer(ctx, featureSet, false, serverPort); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 500)
c, err := NewClient("127.0.0.1",
WithPort(serverPort),
WithTLSPolicy(TLSOpportunistic),
WithSMTPAuth(SMTPAuthXOAUTH2),
WithUsername("user"),
WithPassword("token"))
if err != nil {
t.Fatalf("unable to create new client: %v", err)
}
if err = c.DialWithContext(context.Background()); err == nil {
t.Fatal("expected dial error got nil")
} else {
if !errors.Is(err, ErrXOauth2AuthNotSupported) {
t.Fatalf("expected %v; got %v", ErrXOauth2AuthNotSupported, err)
}
}
if err = c.Close(); err != nil {
t.Fatalf("disconnect from test server failed: %v", err)
}
}
func TestXOAuth2OK_faker(t *testing.T) {
server := []string{ server := []string{
"220 Fake server ready ESMTP", "220 Fake server ready ESMTP",
"250-fake.server", "250-fake.server",
@ -1952,7 +2126,7 @@ func TestXOAuth2OK(t *testing.T) {
} }
} }
func TestXOAuth2Unsupported(t *testing.T) { func TestXOAuth2Unsupported_faker(t *testing.T) {
server := []string{ server := []string{
"220 Fake server ready ESMTP", "220 Fake server ready ESMTP",
"250-fake.server", "250-fake.server",
@ -2085,7 +2259,6 @@ func handleTestServerConnection(connection net.Conn, featureSet string, failRese
data, err := reader.ReadString('\n') data, err := reader.ReadString('\n')
if err != nil { if err != nil {
fmt.Printf("unable to read from connection: %s\n", err)
return return
} }
if !strings.HasPrefix(data, "EHLO") && !strings.HasPrefix(data, "HELO") { if !strings.HasPrefix(data, "EHLO") && !strings.HasPrefix(data, "HELO") {
@ -2093,19 +2266,15 @@ func handleTestServerConnection(connection net.Conn, featureSet string, failRese
return return
} }
if err = writeLine("250-localhost.localdomain\r\n" + featureSet); err != nil { if err = writeLine("250-localhost.localdomain\r\n" + featureSet); err != nil {
fmt.Printf("unable to write to connection: %s\n", err)
return return
} }
for { for {
data, err = reader.ReadString('\n') data, err = reader.ReadString('\n')
if err != nil { if err != nil {
if errors.Is(err, io.EOF) {
break
}
fmt.Println("Error reading data:", err)
break break
} }
time.Sleep(time.Millisecond)
var datastring string var datastring string
data = strings.TrimSpace(data) data = strings.TrimSpace(data)
@ -2128,6 +2297,13 @@ func handleTestServerConnection(connection net.Conn, featureSet string, failRese
break break
} }
writeOK() writeOK()
case strings.HasPrefix(data, "AUTH XOAUTH2"):
auth := strings.TrimPrefix(data, "AUTH XOAUTH2 ")
if !strings.EqualFold(auth, "dXNlcj11c2VyAWF1dGg9QmVhcmVyIHRva2VuAQE=") {
_ = writeLine("535 5.7.8 Error: authentication failed")
break
}
_ = writeLine("235 2.7.0 Authentication successful")
case strings.HasPrefix(data, "AUTH PLAIN"): case strings.HasPrefix(data, "AUTH PLAIN"):
auth := strings.TrimPrefix(data, "AUTH PLAIN ") auth := strings.TrimPrefix(data, "AUTH PLAIN ")
if !strings.EqualFold(auth, "AHRvbmlAdGVzdGVyLmNvbQBWM3J5UzNjcjN0Kw==") { if !strings.EqualFold(auth, "AHRvbmlAdGVzdGVyLmNvbQBWM3J5UzNjcjN0Kw==") {

View file

@ -2,8 +2,8 @@
// //
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
//go:build go1.19 && !go1.20 //go:build !go1.20
// +build go1.19,!go1.20 // +build !go1.20
package mail package mail

View file

@ -30,34 +30,57 @@ import (
"net/textproto" "net/textproto"
"os" "os"
"strings" "strings"
"sync"
"time"
"github.com/wneessen/go-mail/log" "github.com/wneessen/go-mail/log"
) )
// A Client represents a client connection to an SMTP server. // A Client represents a client connection to an SMTP server.
type Client struct { type Client struct {
// Text is the textproto.Conn used by the Client. It is exported to allow for // Text is the textproto.Conn used by the Client. It is exported to allow for clients to add extensions.
// clients to add extensions.
Text *textproto.Conn Text *textproto.Conn
// keep a reference to the connection so it can be used to create a TLS
// connection later // auth supported auth mechanisms
conn net.Conn
// whether the Client is using TLS
tls bool
serverName string
// map of supported extensions
ext map[string]string
// supported auth mechanisms
auth []string auth []string
// keep a reference to the connection so it can be used to create a TLS connection later
conn net.Conn
// debug logging is enabled
debug bool
// didHello indicates whether we've said HELO/EHLO
didHello bool
// dsnmrtype defines the mail return option in case DSN is enabled
dsnmrtype string
// dsnrntype defines the recipient notify option in case DSN is enabled
dsnrntype string
// ext is a map of supported extensions
ext map[string]string
// helloError is the error from the hello
helloError error
// localName is the name to use in HELO/EHLO
localName string // the name to use in HELO/EHLO localName string // the name to use in HELO/EHLO
didHello bool // whether we've said HELO/EHLO
helloError error // the error from the hello // logger will be used for debug logging
// debug logging logger log.Logger
debug bool // debug logging is enabled
logger log.Logger // logger will be used for debug logging // mutex is used to synchronize access to shared resources, ensuring that only one goroutine can access
// DSN support // the resource at a time.
dsnmrtype string // dsnmrtype defines the mail return option in case DSN is enabled mutex sync.RWMutex
dsnrntype string // dsnrntype defines the recipient notify option in case DSN is enabled
// tls indicates whether the Client is using TLS
tls bool
// serverName denotes the name of the server to which the application will connect. Used for
// identification and routing.
serverName string
} }
// Dial returns a new [Client] connected to an SMTP server at addr. // Dial returns a new [Client] connected to an SMTP server at addr.
@ -94,7 +117,10 @@ func NewClient(conn net.Conn, host string) (*Client, error) {
// Close closes the connection. // Close closes the connection.
func (c *Client) Close() error { func (c *Client) Close() error {
return c.Text.Close() c.mutex.Lock()
err := c.Text.Close()
c.mutex.Unlock()
return err
} }
// hello runs a hello exchange if needed. // hello runs a hello exchange if needed.
@ -121,28 +147,39 @@ func (c *Client) Hello(localName string) error {
if c.didHello { if c.didHello {
return errors.New("smtp: Hello called after other methods") return errors.New("smtp: Hello called after other methods")
} }
c.mutex.Lock()
c.localName = localName c.localName = localName
c.mutex.Unlock()
return c.hello() return c.hello()
} }
// cmd is a convenience function that sends a command and returns the response // cmd is a convenience function that sends a command and returns the response
func (c *Client) cmd(expectCode int, format string, args ...interface{}) (int, string, error) { func (c *Client) cmd(expectCode int, format string, args ...interface{}) (int, string, error) {
c.mutex.Lock()
c.debugLog(log.DirClientToServer, format, args...) c.debugLog(log.DirClientToServer, format, args...)
id, err := c.Text.Cmd(format, args...) id, err := c.Text.Cmd(format, args...)
if err != nil { if err != nil {
c.mutex.Unlock()
return 0, "", err return 0, "", err
} }
c.Text.StartResponse(id) c.Text.StartResponse(id)
defer c.Text.EndResponse(id)
code, msg, err := c.Text.ReadResponse(expectCode) code, msg, err := c.Text.ReadResponse(expectCode)
c.debugLog(log.DirServerToClient, "%d %s", code, msg) c.debugLog(log.DirServerToClient, "%d %s", code, msg)
c.Text.EndResponse(id)
c.mutex.Unlock()
return code, msg, err return code, msg, err
} }
// helo sends the HELO greeting to the server. It should be used only when the // helo sends the HELO greeting to the server. It should be used only when the
// server does not support ehlo. // server does not support ehlo.
func (c *Client) helo() error { func (c *Client) helo() error {
c.mutex.Lock()
c.ext = nil c.ext = nil
c.mutex.Unlock()
_, _, err := c.cmd(250, "HELO %s", c.localName) _, _, err := c.cmd(250, "HELO %s", c.localName)
return err return err
} }
@ -157,9 +194,13 @@ func (c *Client) StartTLS(config *tls.Config) error {
if err != nil { if err != nil {
return err return err
} }
c.mutex.Lock()
c.conn = tls.Client(c.conn, config) c.conn = tls.Client(c.conn, config)
c.Text = textproto.NewConn(c.conn) c.Text = textproto.NewConn(c.conn)
c.tls = true c.tls = true
c.mutex.Unlock()
return c.ehlo() return c.ehlo()
} }
@ -167,11 +208,15 @@ func (c *Client) StartTLS(config *tls.Config) error {
// The return values are their zero values if [Client.StartTLS] did // The return values are their zero values if [Client.StartTLS] did
// not succeed. // not succeed.
func (c *Client) TLSConnectionState() (state tls.ConnectionState, ok bool) { func (c *Client) TLSConnectionState() (state tls.ConnectionState, ok bool) {
c.mutex.RLock()
defer c.mutex.RUnlock()
tc, ok := c.conn.(*tls.Conn) tc, ok := c.conn.(*tls.Conn)
if !ok { if !ok {
return return
} }
return tc.ConnectionState(), true state, ok = tc.ConnectionState(), true
return
} }
// Verify checks the validity of an email address on the server. // Verify checks the validity of an email address on the server.
@ -257,6 +302,8 @@ func (c *Client) Mail(from string) error {
return err return err
} }
cmdStr := "MAIL FROM:<%s>" cmdStr := "MAIL FROM:<%s>"
c.mutex.RLock()
if c.ext != nil { if c.ext != nil {
if _, ok := c.ext["8BITMIME"]; ok { if _, ok := c.ext["8BITMIME"]; ok {
cmdStr += " BODY=8BITMIME" cmdStr += " BODY=8BITMIME"
@ -269,6 +316,8 @@ func (c *Client) Mail(from string) error {
cmdStr += fmt.Sprintf(" RET=%s", c.dsnmrtype) cmdStr += fmt.Sprintf(" RET=%s", c.dsnmrtype)
} }
} }
c.mutex.RUnlock()
_, _, err := c.cmd(250, cmdStr, from) _, _, err := c.cmd(250, cmdStr, from)
return err return err
} }
@ -280,7 +329,11 @@ func (c *Client) Rcpt(to string) error {
if err := validateLine(to); err != nil { if err := validateLine(to); err != nil {
return err return err
} }
c.mutex.RLock()
_, ok := c.ext["DSN"] _, ok := c.ext["DSN"]
c.mutex.RUnlock()
if ok && c.dsnrntype != "" { if ok && c.dsnrntype != "" {
_, _, err := c.cmd(25, "RCPT TO:<%s> NOTIFY=%s", to, c.dsnrntype) _, _, err := c.cmd(25, "RCPT TO:<%s> NOTIFY=%s", to, c.dsnrntype)
return err return err
@ -294,12 +347,23 @@ type dataCloser struct {
io.WriteCloser io.WriteCloser
} }
// Close releases the lock, closes the WriteCloser, waits for a response, and then returns any error encountered.
func (d *dataCloser) Close() error { func (d *dataCloser) Close() error {
d.c.mutex.Lock()
_ = d.WriteCloser.Close() _ = d.WriteCloser.Close()
_, _, err := d.c.Text.ReadResponse(250) _, _, err := d.c.Text.ReadResponse(250)
d.c.mutex.Unlock()
return err return err
} }
// Write writes data to the underlying WriteCloser while ensuring thread-safety by locking and unlocking a mutex.
func (d *dataCloser) Write(p []byte) (n int, err error) {
d.c.mutex.Lock()
n, err = d.WriteCloser.Write(p)
d.c.mutex.Unlock()
return
}
// Data issues a DATA command to the server and returns a writer that // Data issues a DATA command to the server and returns a writer that
// can be used to write the mail headers and body. The caller should // can be used to write the mail headers and body. The caller should
// close the writer before calling any more methods on c. A call to // close the writer before calling any more methods on c. A call to
@ -309,7 +373,14 @@ func (c *Client) Data() (io.WriteCloser, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &dataCloser{c, c.Text.DotWriter()}, nil datacloser := &dataCloser{}
c.mutex.Lock()
datacloser.c = c
datacloser.WriteCloser = c.Text.DotWriter()
c.mutex.Unlock()
return datacloser, nil
} }
var testHookStartTLS func(*tls.Config) // nil, except for tests var testHookStartTLS func(*tls.Config) // nil, except for tests
@ -405,7 +476,10 @@ func (c *Client) Extension(ext string) (bool, string) {
return false, "" return false, ""
} }
ext = strings.ToUpper(ext) ext = strings.ToUpper(ext)
c.mutex.RLock()
param, ok := c.ext[ext] param, ok := c.ext[ext]
c.mutex.RUnlock()
return ok, param return ok, param
} }
@ -438,7 +512,11 @@ func (c *Client) Quit() error {
if err != nil { if err != nil {
return err return err
} }
return c.Text.Close() c.mutex.Lock()
err = c.Text.Close()
c.mutex.Unlock()
return err
} }
// SetDebugLog enables the debug logging for incoming and outgoing SMTP messages // SetDebugLog enables the debug logging for incoming and outgoing SMTP messages
@ -472,6 +550,21 @@ func (c *Client) SetDSNRcptNotifyOption(d string) {
c.dsnrntype = d c.dsnrntype = d
} }
// HasConnection checks if the client has an active connection.
// Returns true if the `conn` field is not nil, indicating an active connection.
func (c *Client) HasConnection() bool {
return c.conn != nil
}
func (c *Client) UpdateDeadline(timeout time.Duration) error {
c.mutex.Lock()
if err := c.conn.SetDeadline(time.Now().Add(timeout)); err != nil {
return fmt.Errorf("smtp: failed to update deadline: %w", err)
}
c.mutex.Unlock()
return nil
}
// debugLog checks if the debug flag is set and if so logs the provided message to // debugLog checks if the debug flag is set and if so logs the provided message to
// the log.Logger interface // the log.Logger interface
func (c *Client) debugLog(d log.Direction, f string, a ...interface{}) { func (c *Client) debugLog(d log.Direction, f string, a ...interface{}) {

View file

@ -25,6 +25,9 @@ func (c *Client) ehlo() error {
if err != nil { if err != nil {
return err return err
} }
c.mutex.Lock()
defer c.mutex.Unlock()
ext := make(map[string]string) ext := make(map[string]string)
extList := strings.Split(msg, "\n") extList := strings.Split(msg, "\n")
if len(extList) > 1 { if len(extList) > 1 {

View file

@ -22,12 +22,15 @@ import "strings"
// should be the preferred greeting for servers that support it. // should be the preferred greeting for servers that support it.
// //
// Backport of: https://github.com/golang/go/commit/4d8db00641cc9ff4f44de7df9b8c4f4a4f9416ee#diff-4f6f6bdb9891d4dd271f9f31430420a2e44018fe4ee539576faf458bebb3cee4 // Backport of: https://github.com/golang/go/commit/4d8db00641cc9ff4f44de7df9b8c4f4a4f9416ee#diff-4f6f6bdb9891d4dd271f9f31430420a2e44018fe4ee539576faf458bebb3cee4
// to guarantee backwards compatibility with Go 1.16/1.17:w // to guarantee backwards compatibility with Go 1.16/1.17
func (c *Client) ehlo() error { func (c *Client) ehlo() error {
_, msg, err := c.cmd(250, "EHLO %s", c.localName) _, msg, err := c.cmd(250, "EHLO %s", c.localName)
if err != nil { if err != nil {
return err return err
} }
c.mutex.Lock()
defer c.mutex.Unlock()
ext := make(map[string]string) ext := make(map[string]string)
extList := strings.Split(msg, "\n") extList := strings.Split(msg, "\n")
if len(extList) > 1 { if len(extList) > 1 {