From 5ec8a5a5feef5a66cf4617f1e94097952a1e15ae Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Sun, 22 Sep 2024 20:48:09 +0200 Subject: [PATCH 01/30] Add initial connection pool interface Introduces a new `connpool.go` file implementing a connection pool interface for managing network connections. This interface includes methods to get and close connections, as well as to retrieve the current pool size. The implementation is initially based on a fork of code from the Fatih Arslan GitHub repository. --- connpool.go | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 connpool.go diff --git a/connpool.go b/connpool.go new file mode 100644 index 0000000..8eabbad --- /dev/null +++ b/connpool.go @@ -0,0 +1,27 @@ +// SPDX-FileCopyrightText: 2022-2024 The go-mail Authors +// +// SPDX-License-Identifier: MIT + +package mail + +import "net" + +// Parts of the connection pool code is forked from https://github.com/fatih/pool/ +// Thanks to Fatih Arslan and the project contributors for providing this great +// concurrency template. + +// Pool interface describes a connection pool implementation. A Pool is +// thread-/go-routine safe. +type Pool interface { + // Get returns a new connection from the pool. Closing the connections returns + // it back into the Pool. Closing a connection when the Pool is destroyed or + // full will be counted as an error. + Get() (net.Conn, error) + + // Close closes the pool and all its connections. After Close() the pool is + // no longer usable. + Close() + + // Len returns the current number of connections of the pool. + Len() int +} From 26ff177fb0f91f2afc71c6692b95726e477ec02a Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Sun, 22 Sep 2024 21:12:59 +0200 Subject: [PATCH 02/30] Add connPool implementation and connection pool errors Introduce connPool struct and implement the Pool interface. Add error handling for invalid pool capacity settings and provide a constructor for creating new connection pools with specified capacities. --- connpool.go | 64 ++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 63 insertions(+), 1 deletion(-) diff --git a/connpool.go b/connpool.go index 8eabbad..96ba915 100644 --- a/connpool.go +++ b/connpool.go @@ -4,12 +4,20 @@ package mail -import "net" +import ( + "errors" + "net" + "sync" +) // Parts of the connection pool code is forked from https://github.com/fatih/pool/ // Thanks to Fatih Arslan and the project contributors for providing this great // concurrency template. +var ( + ErrPoolInvalidCap = errors.New("invalid connection pool capacity settings") +) + // Pool interface describes a connection pool implementation. A Pool is // thread-/go-routine safe. type Pool interface { @@ -25,3 +33,57 @@ type Pool interface { // Len returns the current number of connections of the pool. Len() int } + +// connPool implements the Pool interface +type connPool struct { + // mutex is used to synchronize access to the connection pool to ensure thread-safe operations + mutex sync.RWMutex + // conns is a channel used to manage and distribute net.Conn objects within the connection pool + conns chan net.Conn + // dialCtx represents the actual net.Conn returned by the DialContextFunc + dialCtx DialContextFunc +} + +// NewConnPool returns a new pool based on buffered channels with an initial +// capacity and maximum capacity. The DialContextFunc is used when the initial +// capacity is greater than zero to fill the pool. A zero initialCap doesn't +// fill the Pool until a new Get() is called. During a Get(), if there is no +// new connection available in the pool, a new connection will be created via +// the corresponding DialContextFunc() method. +func NewConnPool(initialCap, maxCap int, dialCtxFunc DialContextFunc) (Pool, error) { + if initialCap < 0 || maxCap <= 0 || initialCap > maxCap { + return nil, ErrPoolInvalidCap + } + + pool := &connPool{ + conns: make(chan net.Conn, maxCap), + dialCtx: dialCtxFunc, + } + + // create initial connections, if something goes wrong, + // just close the pool error out. + for i := 0; i < initialCap; i++ { + /* + conn, err := dialCtxFunc() + if err != nil { + pool.Close() + return nil, fmt.Errorf("factory is not able to fill the pool: %s", err) + } + c.conns <- conn + + */ + } + + return pool, nil +} + +func (c *connPool) Get() (net.Conn, error) { + return nil, nil +} +func (c *connPool) Close() { + return +} + +func (c *connPool) Len() int { + return 0 +} From 1394f1fc2099fc9d8f1476f406e0346ef24075f0 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Mon, 23 Sep 2024 09:56:23 +0200 Subject: [PATCH 03/30] Add context management and error handling to connection pool Introduced context support and enhanced error handling in connpool.go. Added detailed comments for better maintainability and introduced a wrapper for net.Conn to manage connection close behavior. The changes improve the robustness and clarity of the connection pool's operation. --- connpool.go | 134 ++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 110 insertions(+), 24 deletions(-) diff --git a/connpool.go b/connpool.go index 96ba915..23ed5d0 100644 --- a/connpool.go +++ b/connpool.go @@ -5,7 +5,9 @@ package mail import ( + "context" "errors" + "fmt" "net" "sync" ) @@ -15,7 +17,11 @@ import ( // concurrency template. var ( + // ErrPoolInvalidCap is returned when the connection pool's capacity settings are + // invalid (e.g., initial capacity is negative). ErrPoolInvalidCap = errors.New("invalid connection pool capacity settings") + // ErrClosed is returned when an operation is attempted on a closed connection pool. + ErrClosed = errors.New("connection pool is closed") ) // Pool interface describes a connection pool implementation. A Pool is @@ -36,12 +42,28 @@ type Pool interface { // connPool implements the Pool interface type connPool struct { - // mutex is used to synchronize access to the connection pool to ensure thread-safe operations + // mutex is used to synchronize access to the connection pool to ensure thread-safe operations. mutex sync.RWMutex - // conns is a channel used to manage and distribute net.Conn objects within the connection pool + // conns is a channel used to manage and distribute net.Conn objects within the connection pool. conns chan net.Conn - // dialCtx represents the actual net.Conn returned by the DialContextFunc - dialCtx DialContextFunc + + // dialCtxFunc represents the actual net.Conn returned by the DialContextFunc. + dialCtxFunc DialContextFunc + // dialContext is the context used for dialing new network connections within the connection pool. + dialContext context.Context + // dialNetwork specifies the network type (e.g., "tcp", "udp") used to establish connections in + // the connection pool. + dialNetwork string + // dialAddress specifies the address used to establish network connections within the connection pool. + dialAddress string +} + +// PoolConn is a wrapper around net.Conn to modify the the behavior of net.Conn's Close() method. +type PoolConn struct { + net.Conn + mutex sync.RWMutex + pool *connPool + unusable bool } // NewConnPool returns a new pool based on buffered channels with an initial @@ -50,40 +72,104 @@ type connPool struct { // fill the Pool until a new Get() is called. During a Get(), if there is no // new connection available in the pool, a new connection will be created via // the corresponding DialContextFunc() method. -func NewConnPool(initialCap, maxCap int, dialCtxFunc DialContextFunc) (Pool, error) { +func NewConnPool(ctx context.Context, initialCap, maxCap int, dialCtxFunc DialContextFunc, + network, address string) (Pool, error) { if initialCap < 0 || maxCap <= 0 || initialCap > maxCap { return nil, ErrPoolInvalidCap } pool := &connPool{ - conns: make(chan net.Conn, maxCap), - dialCtx: dialCtxFunc, + conns: make(chan net.Conn, maxCap), + dialCtxFunc: dialCtxFunc, + dialContext: ctx, + dialAddress: address, + dialNetwork: network, } - // create initial connections, if something goes wrong, - // just close the pool error out. + // Initial connections for the pool. Pool will be closed on connection error for i := 0; i < initialCap; i++ { - /* - conn, err := dialCtxFunc() - if err != nil { - pool.Close() - return nil, fmt.Errorf("factory is not able to fill the pool: %s", err) - } - c.conns <- conn + conn, err := dialCtxFunc(ctx, network, address) + if err != nil { + pool.Close() + return nil, fmt.Errorf("dialContextFunc is not able to fill the connection pool: %s", err) + } + pool.conns <- conn - */ } return pool, nil } -func (c *connPool) Get() (net.Conn, error) { - return nil, nil -} -func (c *connPool) Close() { - return +// Get satisfies the Get() method of the Pool inteface. If there is no new +// connection available in the Pool, a new connection will be created via the +// DialContextFunc() method. +func (p *connPool) Get() (net.Conn, error) { + ctx, conns, dialCtxFunc := p.getConnsAndDialContext() + if conns == nil { + return nil, ErrClosed + } + + // wrap the connections into the custom net.Conn implementation that puts + // connections back to the pool + select { + case <-ctx.Done(): + return nil, ctx.Err() + case conn := <-conns: + if conn == nil { + return nil, ErrClosed + } + return p.wrapConn(conn), nil + default: + conn, err := dialCtxFunc(ctx, p.dialNetwork, p.dialAddress) + if err != nil { + return nil, err + } + return p.wrapConn(conn), nil + } } -func (c *connPool) Len() int { - return 0 +// Close terminates all connections in the pool and frees associated resources. Once closed, +// the pool is no longer usable. +func (p *connPool) Close() { + p.mutex.Lock() + conns := p.conns + p.conns = nil + p.dialCtxFunc = nil + p.dialContext = nil + p.dialAddress = "" + p.dialNetwork = "" + p.mutex.Unlock() + + if conns == nil { + return + } + + close(conns) + for conn := range conns { + _ = conn.Close() + } +} + +// Len returns the current number of connections in the connection pool. +func (p *connPool) Len() int { + _, conns, _ := p.getConnsAndDialContext() + return len(conns) +} + +// getConnsAndDialContext returns the connection channel and the DialContext function for the +// connection pool. +func (p *connPool) getConnsAndDialContext() (context.Context, chan net.Conn, DialContextFunc) { + p.mutex.RLock() + conns := p.conns + dialCtxFunc := p.dialCtxFunc + ctx := p.dialContext + p.mutex.RUnlock() + return ctx, conns, dialCtxFunc +} + +// wrapConn wraps a given net.Conn with a PoolConn, modifying the net.Conn's Close() method. +func (p *connPool) wrapConn(conn net.Conn) net.Conn { + poolconn := &PoolConn{pool: p} + poolconn.Conn = conn + return poolconn } From d6e5034bba6b439353ccde76899649996ced08c5 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Mon, 23 Sep 2024 10:09:38 +0200 Subject: [PATCH 04/30] Add new error handling and connection management in connpool Introduce ErrClosed and ErrNilConn errors for better error handling. Implement Close and MarkUnusable methods for improved connection lifecycle management. Add put method to return connections to the pool or close them if necessary. --- connpool.go | 50 ++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 48 insertions(+), 2 deletions(-) diff --git a/connpool.go b/connpool.go index 23ed5d0..b22019c 100644 --- a/connpool.go +++ b/connpool.go @@ -17,11 +17,13 @@ import ( // concurrency template. var ( + // ErrClosed is returned when an operation is attempted on a closed connection pool. + ErrClosed = errors.New("connection pool is closed") + // ErrNilConn is returned when a nil connection is passed back to the connection pool. + ErrNilConn = errors.New("connection is nil") // ErrPoolInvalidCap is returned when the connection pool's capacity settings are // invalid (e.g., initial capacity is negative). ErrPoolInvalidCap = errors.New("invalid connection pool capacity settings") - // ErrClosed is returned when an operation is attempted on a closed connection pool. - ErrClosed = errors.New("connection pool is closed") ) // Pool interface describes a connection pool implementation. A Pool is @@ -66,6 +68,28 @@ type PoolConn struct { unusable bool } +// Close puts a given pool connection back to the pool instead of closing it. +func (c *PoolConn) Close() error { + c.mutex.RLock() + defer c.mutex.RUnlock() + + if c.unusable { + if c.Conn != nil { + return c.Conn.Close() + } + return nil + } + return c.pool.put(c.Conn) +} + +// MarkUnusable marks the connection not usable any more, to let the pool close it instead +// of returning it to pool. +func (c *PoolConn) MarkUnusable() { + c.mutex.Lock() + c.unusable = true + c.mutex.Unlock() +} + // NewConnPool returns a new pool based on buffered channels with an initial // capacity and maximum capacity. The DialContextFunc is used when the initial // capacity is greater than zero to fill the pool. A zero initialCap doesn't @@ -167,6 +191,28 @@ func (p *connPool) getConnsAndDialContext() (context.Context, chan net.Conn, Dia return ctx, conns, dialCtxFunc } +// put puts a passed connection back into the pool. If the pool is full or closed, +// conn is simply closed. A nil conn will be rejected with an error. +func (p *connPool) put(conn net.Conn) error { + if conn == nil { + return ErrNilConn + } + + p.mutex.RLock() + defer p.mutex.RUnlock() + + if p.conns == nil { + return conn.Close() + } + + select { + case p.conns <- conn: + return nil + default: + return conn.Close() + } +} + // wrapConn wraps a given net.Conn with a PoolConn, modifying the net.Conn's Close() method. func (p *connPool) wrapConn(conn net.Conn) net.Conn { poolconn := &PoolConn{pool: p} From 33d4eb5b21342dfacd99f12f2a6da2b8b36d362f Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Mon, 23 Sep 2024 10:33:06 +0200 Subject: [PATCH 05/30] Add unit tests for connection pool and rename Len to Size Introduced unit tests for the connection pool to ensure robust functionality. Also, renamed the Len method to Size in the Pool interface and its implementation for better clarity and consistency. --- connpool.go | 8 +++---- connpool_test.go | 57 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 4 deletions(-) create mode 100644 connpool_test.go diff --git a/connpool.go b/connpool.go index b22019c..f774806 100644 --- a/connpool.go +++ b/connpool.go @@ -38,8 +38,8 @@ type Pool interface { // no longer usable. Close() - // Len returns the current number of connections of the pool. - Len() int + // Size returns the current number of connections of the pool. + Size() int } // connPool implements the Pool interface @@ -174,8 +174,8 @@ func (p *connPool) Close() { } } -// Len returns the current number of connections in the connection pool. -func (p *connPool) Len() int { +// Size returns the current number of connections in the connection pool. +func (p *connPool) Size() int { _, conns, _ := p.getConnsAndDialContext() return len(conns) } diff --git a/connpool_test.go b/connpool_test.go new file mode 100644 index 0000000..589ad68 --- /dev/null +++ b/connpool_test.go @@ -0,0 +1,57 @@ +// SPDX-FileCopyrightText: 2022-2024 The go-mail Authors +// +// SPDX-License-Identifier: MIT + +package mail + +import ( + "context" + "fmt" + "net" + "testing" + "time" +) + +func TestNewConnPool(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + serverPort := TestServerPortBase + 10 + featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 300) + + pool, err := newConnPool(serverPort) + if err != nil { + t.Errorf("failed to create connection pool: %s", err) + } + if pool == nil { + t.Errorf("connection pool is nil") + return + } + if pool.Size() != 5 { + t.Errorf("expected 5 connections, got %d", pool.Size()) + } + for i := 0; i < 5; i++ { + go func() { + conn, err := pool.Get() + if err != nil { + t.Errorf("failed to get connection: %s", err) + } + if _, err := conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil { + t.Errorf("failed to write quit command to first connection: %s", err) + } + }() + } +} + +func newConnPool(port int) (Pool, error) { + netDialer := net.Dialer{} + return NewConnPool(context.Background(), 5, 30, netDialer.DialContext, "tcp", + fmt.Sprintf("127.0.0.1:%d", port)) +} From 9a9e0c936d23bca6f1789f2add62d4a3666452b6 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Mon, 23 Sep 2024 11:09:03 +0200 Subject: [PATCH 06/30] Remove redundant error handling in test code The check for io.EOF and the associated print statement were unnecessary because the loop breaks on any error. This change simplifies the error handling logic in the `client_test.go` file and avoids redundant code. --- client_test.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/client_test.go b/client_test.go index 3d50a2d..fdc486a 100644 --- a/client_test.go +++ b/client_test.go @@ -2100,10 +2100,6 @@ func handleTestServerConnection(connection net.Conn, featureSet string, failRese for { data, err = reader.ReadString('\n') if err != nil { - if errors.Is(err, io.EOF) { - break - } - fmt.Println("Error reading data:", err) break } From 2cbd0c4aefa628093584e9b608044c0f58e9373e Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Mon, 23 Sep 2024 11:09:11 +0200 Subject: [PATCH 07/30] Add test cases for connection pool functionality Added new test cases `TestConnPool_Get_Type` and `TestConnPool_Get` to verify connection pool operations. These tests ensure proper connection type and handling of pool size after connection retrieval and usage. --- connpool_test.go | 102 +++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 98 insertions(+), 4 deletions(-) diff --git a/connpool_test.go b/connpool_test.go index 589ad68..67a988f 100644 --- a/connpool_test.go +++ b/connpool_test.go @@ -8,6 +8,7 @@ import ( "context" "fmt" "net" + "sync" "testing" "time" ) @@ -30,6 +31,7 @@ func TestNewConnPool(t *testing.T) { if err != nil { t.Errorf("failed to create connection pool: %s", err) } + defer pool.Close() if pool == nil { t.Errorf("connection pool is nil") return @@ -37,17 +39,109 @@ func TestNewConnPool(t *testing.T) { if pool.Size() != 5 { t.Errorf("expected 5 connections, got %d", pool.Size()) } - for i := 0; i < 5; i++ { + conn, err := pool.Get() + if err != nil { + t.Errorf("failed to get connection: %s", err) + } + if _, err := conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil { + t.Errorf("failed to write quit command to first connection: %s", err) + } +} + +func TestConnPool_Get_Type(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + serverPort := TestServerPortBase + 11 + featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 300) + + pool, err := newConnPool(serverPort) + if err != nil { + t.Errorf("failed to create connection pool: %s", err) + } + defer pool.Close() + + conn, err := pool.Get() + if err != nil { + t.Errorf("failed to get new connection from pool: %s", err) + return + } + + _, ok := conn.(*PoolConn) + if !ok { + t.Error("received connection from pool is not of type PoolConn") + return + } + if _, err := conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil { + t.Errorf("failed to write quit command to first connection: %s", err) + } +} + +func TestConnPool_Get(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + serverPort := TestServerPortBase + 12 + featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 300) + + p, _ := newConnPool(serverPort) + defer p.Close() + + conn, err := p.Get() + if err != nil { + t.Errorf("failed to get new connection from pool: %s", err) + return + } + if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil { + t.Errorf("failed to write quit command to first connection: %s", err) + } + + if p.Size() != 4 { + t.Errorf("getting new connection from pool failed. Expected pool size: 4, got %d", p.Size()) + } + + var wg sync.WaitGroup + for i := 0; i < 4; i++ { + wg.Add(1) go func() { - conn, err := pool.Get() + defer wg.Done() + wgconn, err := p.Get() if err != nil { - t.Errorf("failed to get connection: %s", err) + t.Errorf("failed to get new connection from pool: %s", err) } - if _, err := conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil { + if _, err = wgconn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil { t.Errorf("failed to write quit command to first connection: %s", err) } }() } + wg.Wait() + + if p.Size() != 0 { + t.Errorf("Get error. Expecting 0, got %d", p.Size()) + } + + conn, err = p.Get() + if err != nil { + t.Errorf("failed to get new connection from pool: %s", err) + } + if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil { + t.Errorf("failed to write quit command to first connection: %s", err) + } + p.Close() } func newConnPool(port int) (Pool, error) { From f1188bdad7d2a494597e96b5d651b07e8c522a15 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Mon, 23 Sep 2024 11:15:56 +0200 Subject: [PATCH 08/30] Add test for PoolConn Close method This commit introduces a new test, `TestPoolConn_Close`, to verify that connections are correctly closed and returned to the pool. It sets up a simple SMTP server, creates a connection pool, tests writing to and closing connections, and checks the pool size to ensure proper behavior. --- connpool_test.go | 53 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/connpool_test.go b/connpool_test.go index 67a988f..536dffc 100644 --- a/connpool_test.go +++ b/connpool_test.go @@ -144,6 +144,59 @@ func TestConnPool_Get(t *testing.T) { p.Close() } +func TestPoolConn_Close(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + serverPort := TestServerPortBase + 13 + featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 300) + + netDialer := net.Dialer{} + p, err := NewConnPool(context.Background(), 0, 30, netDialer.DialContext, "tcp", + fmt.Sprintf("127.0.0.1:%d", serverPort)) + if err != nil { + t.Errorf("failed to create connection pool: %s", err) + } + defer p.Close() + + conns := make([]net.Conn, 30) + for i := 0; i < 30; i++ { + conn, _ := p.Get() + if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil { + t.Errorf("failed to write quit command to first connection: %s", err) + } + conns[i] = conn + } + for _, conn := range conns { + conn.Close() + } + + if p.Size() != 30 { + t.Errorf("failed to return all connections to pool. Expected pool size: 30, got %d", p.Size()) + } + + conn, err := p.Get() + if err != nil { + t.Errorf("failed to get new connection from pool: %s", err) + } + if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil { + t.Errorf("failed to write quit command to first connection: %s", err) + } + p.Close() + + conn.Close() + if p.Size() != 0 { + t.Errorf("closed pool shouldn't allow to put connections.") + } +} + func newConnPool(port int) (Pool, error) { netDialer := net.Dialer{} return NewConnPool(context.Background(), 5, 30, netDialer.DialContext, "tcp", From 774925078a7af5342186960eddb0c60ea1000cdd Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Mon, 23 Sep 2024 11:17:58 +0200 Subject: [PATCH 09/30] Improve error handling in connection pool tests Add error checks to Close() calls in connpool_test.go to ensure connection closures are handled properly, with descriptive error messages. Update comment in connpool.go to improve clarity on the source of code inspiration. --- connpool.go | 5 ++--- connpool_test.go | 8 ++++++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/connpool.go b/connpool.go index f774806..d875793 100644 --- a/connpool.go +++ b/connpool.go @@ -12,9 +12,8 @@ import ( "sync" ) -// Parts of the connection pool code is forked from https://github.com/fatih/pool/ -// Thanks to Fatih Arslan and the project contributors for providing this great -// concurrency template. +// Parts of the connection pool code is forked/took inspiration from https://github.com/fatih/pool/ +// Thanks to Fatih Arslan and the project contributors for providing this great concurrency template. var ( // ErrClosed is returned when an operation is attempted on a closed connection pool. diff --git a/connpool_test.go b/connpool_test.go index 536dffc..4baf084 100644 --- a/connpool_test.go +++ b/connpool_test.go @@ -175,7 +175,9 @@ func TestPoolConn_Close(t *testing.T) { conns[i] = conn } for _, conn := range conns { - conn.Close() + if err = conn.Close(); err != nil { + t.Errorf("failed to close connection: %s", err) + } } if p.Size() != 30 { @@ -191,7 +193,9 @@ func TestPoolConn_Close(t *testing.T) { } p.Close() - conn.Close() + if err = conn.Close(); err != nil { + t.Errorf("failed to close connection: %s", err) + } if p.Size() != 0 { t.Errorf("closed pool shouldn't allow to put connections.") } From 5503be8451fae20a55db907b174c5dde1fbc6b84 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Mon, 23 Sep 2024 11:26:54 +0200 Subject: [PATCH 10/30] Remove redundant error print statements Removed redundant fmt.Printf error print statements for connection read and write errors. This cleans up the test output and makes error handling more streamlined. --- client_test.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/client_test.go b/client_test.go index fdc486a..c891a10 100644 --- a/client_test.go +++ b/client_test.go @@ -2085,7 +2085,6 @@ func handleTestServerConnection(connection net.Conn, featureSet string, failRese data, err := reader.ReadString('\n') if err != nil { - fmt.Printf("unable to read from connection: %s\n", err) return } if !strings.HasPrefix(data, "EHLO") && !strings.HasPrefix(data, "HELO") { @@ -2093,7 +2092,6 @@ func handleTestServerConnection(connection net.Conn, featureSet string, failRese return } if err = writeLine("250-localhost.localdomain\r\n" + featureSet); err != nil { - fmt.Printf("unable to write to connection: %s\n", err) return } From 2abdee743d5b3fbfad694574259a4a49bf6849e7 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Mon, 23 Sep 2024 11:27:05 +0200 Subject: [PATCH 11/30] Add unit test for marking connection as unusable Introduces `TestPoolConn_MarkUnusable` to ensure the pool maintains its integrity when a connection is marked unusable. This test validates that the connection pool size adjusts correctly after marking a connection as unusable and closing it. --- connpool_test.go | 54 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/connpool_test.go b/connpool_test.go index 4baf084..d03062f 100644 --- a/connpool_test.go +++ b/connpool_test.go @@ -201,6 +201,60 @@ func TestPoolConn_Close(t *testing.T) { } } +func TestPoolConn_MarkUnusable(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + serverPort := TestServerPortBase + 14 + featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 300) + + pool, _ := newConnPool(serverPort) + defer pool.Close() + + conn, err := pool.Get() + if err != nil { + t.Errorf("failed to get new connection from pool: %s", err) + } + if err = conn.Close(); err != nil { + t.Errorf("failed to close connection: %s", err) + } + + poolSize := pool.Size() + conn, err = pool.Get() + if err != nil { + t.Errorf("failed to get new connection from pool: %s", err) + } + if err = conn.Close(); err != nil { + t.Errorf("failed to close connection: %s", err) + } + if pool.Size() != poolSize { + t.Errorf("pool size is expected to be equal to initial size") + } + + conn, err = pool.Get() + if err != nil { + t.Errorf("failed to get new connection from pool: %s", err) + } + if pc, ok := conn.(*PoolConn); !ok { + t.Errorf("this should never happen") + } else { + pc.MarkUnusable() + } + if err = conn.Close(); err != nil { + t.Errorf("failed to close connection: %s", err) + } + if pool.Size() != poolSize-1 { + t.Errorf("pool size is expected to be: %d but got: %d", poolSize-1, pool.Size()) + } +} + func newConnPool(port int) (Pool, error) { netDialer := net.Dialer{} return NewConnPool(context.Background(), 5, 30, netDialer.DialContext, "tcp", From 07e7b17ae8a073feda2288cf38e8d865975e3a89 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Mon, 23 Sep 2024 11:45:22 +0200 Subject: [PATCH 12/30] Add tests for ConnPool close and concurrency issues This commit introduces two new tests: `TestConnPool_Close` and `TestConnPool_Concurrency`. The former ensures the proper closing of connection pool resources, while the latter checks for concurrency issues by creating and closing multiple connections in parallel. --- connpool_test.go | 100 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) diff --git a/connpool_test.go b/connpool_test.go index d03062f..4450e8c 100644 --- a/connpool_test.go +++ b/connpool_test.go @@ -255,6 +255,106 @@ func TestPoolConn_MarkUnusable(t *testing.T) { } } +func TestConnPool_Close(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + serverPort := TestServerPortBase + 15 + featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 300) + + pool, err := newConnPool(serverPort) + if err != nil { + t.Errorf("failed to create connection pool: %s", err) + } + pool.Close() + + castPool := pool.(*connPool) + + if castPool.conns != nil { + t.Error("closing pool failed: conns channel should be nil") + } + if castPool.dialCtxFunc != nil { + t.Error("closing pool failed: dialCtxFunc should be nil") + } + if castPool.dialContext != nil { + t.Error("closing pool failed: dialContext should be nil") + } + if castPool.dialAddress != "" { + t.Error("closing pool failed: dialAddress should be empty") + } + if castPool.dialNetwork != "" { + t.Error("closing pool failed: dialNetwork should be empty") + } + + conn, err := pool.Get() + if err == nil { + t.Errorf("closing pool failed: getting new connection should return an error") + } + if conn != nil { + t.Errorf("closing pool failed: getting new connection should return a nil-connection") + } + if pool.Size() != 0 { + t.Errorf("closing pool failed: pool size should be 0, but got: %d", pool.Size()) + } +} + +func TestConnPool_Concurrency(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + serverPort := TestServerPortBase + 16 + featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 300) + + pool, err := newConnPool(serverPort) + if err != nil { + t.Errorf("failed to create connection pool: %s", err) + } + defer pool.Close() + pipe := make(chan net.Conn) + + getWg := sync.WaitGroup{} + closeWg := sync.WaitGroup{} + for i := 0; i < 30; i++ { + getWg.Add(1) + closeWg.Add(1) + go func() { + conn, err := pool.Get() + if err != nil { + t.Errorf("failed to get new connection from pool: %s", err) + } + pipe <- conn + getWg.Done() + }() + + go func() { + conn := <-pipe + if conn == nil { + return + } + if err = conn.Close(); err != nil { + t.Errorf("failed to close connection: %s", err) + } + closeWg.Done() + }() + getWg.Wait() + closeWg.Wait() + } +} + func newConnPool(port int) (Pool, error) { netDialer := net.Dialer{} return NewConnPool(context.Background(), 5, 30, netDialer.DialContext, "tcp", From c8684886ed23d2a42b1ab178400c7a7633650e47 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Mon, 23 Sep 2024 13:44:03 +0200 Subject: [PATCH 13/30] Refactor Get method to include context argument Updated the Get method in connpool.go and its usage in tests to include a context argument for better cancellation and timeout handling. Removed the redundant dialContext field from the connection pool struct and added a new test to validate context timeout behavior. --- connpool.go | 27 ++++++-------- connpool_test.go | 91 ++++++++++++++++++++++++++++++++++++++++-------- 2 files changed, 87 insertions(+), 31 deletions(-) diff --git a/connpool.go b/connpool.go index d875793..f50e73d 100644 --- a/connpool.go +++ b/connpool.go @@ -31,7 +31,7 @@ type Pool interface { // Get returns a new connection from the pool. Closing the connections returns // it back into the Pool. Closing a connection when the Pool is destroyed or // full will be counted as an error. - Get() (net.Conn, error) + Get(ctx context.Context) (net.Conn, error) // Close closes the pool and all its connections. After Close() the pool is // no longer usable. @@ -50,8 +50,6 @@ type connPool struct { // dialCtxFunc represents the actual net.Conn returned by the DialContextFunc. dialCtxFunc DialContextFunc - // dialContext is the context used for dialing new network connections within the connection pool. - dialContext context.Context // dialNetwork specifies the network type (e.g., "tcp", "udp") used to establish connections in // the connection pool. dialNetwork string @@ -96,7 +94,8 @@ func (c *PoolConn) MarkUnusable() { // new connection available in the pool, a new connection will be created via // the corresponding DialContextFunc() method. func NewConnPool(ctx context.Context, initialCap, maxCap int, dialCtxFunc DialContextFunc, - network, address string) (Pool, error) { + network, address string, +) (Pool, error) { if initialCap < 0 || maxCap <= 0 || initialCap > maxCap { return nil, ErrPoolInvalidCap } @@ -104,7 +103,6 @@ func NewConnPool(ctx context.Context, initialCap, maxCap int, dialCtxFunc DialCo pool := &connPool{ conns: make(chan net.Conn, maxCap), dialCtxFunc: dialCtxFunc, - dialContext: ctx, dialAddress: address, dialNetwork: network, } @@ -114,10 +112,9 @@ func NewConnPool(ctx context.Context, initialCap, maxCap int, dialCtxFunc DialCo conn, err := dialCtxFunc(ctx, network, address) if err != nil { pool.Close() - return nil, fmt.Errorf("dialContextFunc is not able to fill the connection pool: %s", err) + return nil, fmt.Errorf("dialContextFunc is not able to fill the connection pool: %w", err) } pool.conns <- conn - } return pool, nil @@ -126,8 +123,8 @@ func NewConnPool(ctx context.Context, initialCap, maxCap int, dialCtxFunc DialCo // Get satisfies the Get() method of the Pool inteface. If there is no new // connection available in the Pool, a new connection will be created via the // DialContextFunc() method. -func (p *connPool) Get() (net.Conn, error) { - ctx, conns, dialCtxFunc := p.getConnsAndDialContext() +func (p *connPool) Get(ctx context.Context) (net.Conn, error) { + conns, dialCtxFunc := p.getConnsAndDialContext() if conns == nil { return nil, ErrClosed } @@ -136,7 +133,7 @@ func (p *connPool) Get() (net.Conn, error) { // connections back to the pool select { case <-ctx.Done(): - return nil, ctx.Err() + return nil, fmt.Errorf("failed to get connection: %w", ctx.Err()) case conn := <-conns: if conn == nil { return nil, ErrClosed @@ -145,7 +142,7 @@ func (p *connPool) Get() (net.Conn, error) { default: conn, err := dialCtxFunc(ctx, p.dialNetwork, p.dialAddress) if err != nil { - return nil, err + return nil, fmt.Errorf("dialContextFunc failed: %w", err) } return p.wrapConn(conn), nil } @@ -158,7 +155,6 @@ func (p *connPool) Close() { conns := p.conns p.conns = nil p.dialCtxFunc = nil - p.dialContext = nil p.dialAddress = "" p.dialNetwork = "" p.mutex.Unlock() @@ -175,19 +171,18 @@ func (p *connPool) Close() { // Size returns the current number of connections in the connection pool. func (p *connPool) Size() int { - _, conns, _ := p.getConnsAndDialContext() + conns, _ := p.getConnsAndDialContext() return len(conns) } // getConnsAndDialContext returns the connection channel and the DialContext function for the // connection pool. -func (p *connPool) getConnsAndDialContext() (context.Context, chan net.Conn, DialContextFunc) { +func (p *connPool) getConnsAndDialContext() (chan net.Conn, DialContextFunc) { p.mutex.RLock() conns := p.conns dialCtxFunc := p.dialCtxFunc - ctx := p.dialContext p.mutex.RUnlock() - return ctx, conns, dialCtxFunc + return conns, dialCtxFunc } // put puts a passed connection back into the pool. If the pool is full or closed, diff --git a/connpool_test.go b/connpool_test.go index 4450e8c..1f42ff6 100644 --- a/connpool_test.go +++ b/connpool_test.go @@ -39,7 +39,7 @@ func TestNewConnPool(t *testing.T) { if pool.Size() != 5 { t.Errorf("expected 5 connections, got %d", pool.Size()) } - conn, err := pool.Get() + conn, err := pool.Get(context.Background()) if err != nil { t.Errorf("failed to get connection: %s", err) } @@ -68,7 +68,7 @@ func TestConnPool_Get_Type(t *testing.T) { } defer pool.Close() - conn, err := pool.Get() + conn, err := pool.Get(context.Background()) if err != nil { t.Errorf("failed to get new connection from pool: %s", err) return @@ -101,7 +101,7 @@ func TestConnPool_Get(t *testing.T) { p, _ := newConnPool(serverPort) defer p.Close() - conn, err := p.Get() + conn, err := p.Get(context.Background()) if err != nil { t.Errorf("failed to get new connection from pool: %s", err) return @@ -119,7 +119,7 @@ func TestConnPool_Get(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - wgconn, err := p.Get() + wgconn, err := p.Get(context.Background()) if err != nil { t.Errorf("failed to get new connection from pool: %s", err) } @@ -134,7 +134,7 @@ func TestConnPool_Get(t *testing.T) { t.Errorf("Get error. Expecting 0, got %d", p.Size()) } - conn, err = p.Get() + conn, err = p.Get(context.Background()) if err != nil { t.Errorf("failed to get new connection from pool: %s", err) } @@ -168,7 +168,7 @@ func TestPoolConn_Close(t *testing.T) { conns := make([]net.Conn, 30) for i := 0; i < 30; i++ { - conn, _ := p.Get() + conn, _ := p.Get(context.Background()) if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil { t.Errorf("failed to write quit command to first connection: %s", err) } @@ -184,7 +184,7 @@ func TestPoolConn_Close(t *testing.T) { t.Errorf("failed to return all connections to pool. Expected pool size: 30, got %d", p.Size()) } - conn, err := p.Get() + conn, err := p.Get(context.Background()) if err != nil { t.Errorf("failed to get new connection from pool: %s", err) } @@ -218,7 +218,7 @@ func TestPoolConn_MarkUnusable(t *testing.T) { pool, _ := newConnPool(serverPort) defer pool.Close() - conn, err := pool.Get() + conn, err := pool.Get(context.Background()) if err != nil { t.Errorf("failed to get new connection from pool: %s", err) } @@ -227,7 +227,7 @@ func TestPoolConn_MarkUnusable(t *testing.T) { } poolSize := pool.Size() - conn, err = pool.Get() + conn, err = pool.Get(context.Background()) if err != nil { t.Errorf("failed to get new connection from pool: %s", err) } @@ -238,7 +238,7 @@ func TestPoolConn_MarkUnusable(t *testing.T) { t.Errorf("pool size is expected to be equal to initial size") } - conn, err = pool.Get() + conn, err = pool.Get(context.Background()) if err != nil { t.Errorf("failed to get new connection from pool: %s", err) } @@ -283,9 +283,6 @@ func TestConnPool_Close(t *testing.T) { if castPool.dialCtxFunc != nil { t.Error("closing pool failed: dialCtxFunc should be nil") } - if castPool.dialContext != nil { - t.Error("closing pool failed: dialContext should be nil") - } if castPool.dialAddress != "" { t.Error("closing pool failed: dialAddress should be empty") } @@ -293,7 +290,7 @@ func TestConnPool_Close(t *testing.T) { t.Error("closing pool failed: dialNetwork should be empty") } - conn, err := pool.Get() + conn, err := pool.Get(context.Background()) if err == nil { t.Errorf("closing pool failed: getting new connection should return an error") } @@ -332,7 +329,7 @@ func TestConnPool_Concurrency(t *testing.T) { getWg.Add(1) closeWg.Add(1) go func() { - conn, err := pool.Get() + conn, err := pool.Get(context.Background()) if err != nil { t.Errorf("failed to get new connection from pool: %s", err) } @@ -355,6 +352,70 @@ func TestConnPool_Concurrency(t *testing.T) { } } +func TestConnPool_GetContextTimeout(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + serverPort := TestServerPortBase + 17 + featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 300) + + p, err := newConnPool(serverPort) + if err != nil { + t.Errorf("failed to create connection pool: %s", err) + } + defer p.Close() + + connCtx, connCancel := context.WithCancel(context.Background()) + defer connCancel() + + conn, err := p.Get(connCtx) + if err != nil { + t.Errorf("failed to get new connection from pool: %s", err) + return + } + if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil { + t.Errorf("failed to write quit command to first connection: %s", err) + } + + if p.Size() != 4 { + t.Errorf("getting new connection from pool failed. Expected pool size: 4, got %d", p.Size()) + } + + var wg sync.WaitGroup + for i := 0; i < 4; i++ { + wg.Add(1) + go func() { + defer wg.Done() + wgconn, err := p.Get(connCtx) + if err != nil { + t.Errorf("failed to get new connection from pool: %s", err) + } + if _, err = wgconn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil { + t.Errorf("failed to write quit command to first connection: %s", err) + } + }() + } + wg.Wait() + + if p.Size() != 0 { + t.Errorf("Get error. Expecting 0, got %d", p.Size()) + } + + connCancel() + _, err = p.Get(connCtx) + if err == nil { + t.Errorf("getting new connection on canceled context should fail, but didn't") + } + p.Close() +} + func newConnPool(port int) (Pool, error) { netDialer := net.Dialer{} return NewConnPool(context.Background(), 5, 30, netDialer.DialContext, "tcp", From 4f6224131ef450d62ece29e47dc1fd0b6ba8c52d Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Mon, 23 Sep 2024 13:46:41 +0200 Subject: [PATCH 14/30] Rename test for accurate context cancellation Updated the test name from `TestConnPool_GetContextTimeout` to `TestConnPool_GetContextCancel` to better reflect its functionality. This change improves test readability and maintains consistency with the context usage in the test. --- connpool_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/connpool_test.go b/connpool_test.go index 1f42ff6..42b26c8 100644 --- a/connpool_test.go +++ b/connpool_test.go @@ -352,7 +352,7 @@ func TestConnPool_Concurrency(t *testing.T) { } } -func TestConnPool_GetContextTimeout(t *testing.T) { +func TestConnPool_GetContextCancel(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() From fd115d5173d7811397ef32a1aac158a44b00c3fb Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Mon, 23 Sep 2024 14:15:43 +0200 Subject: [PATCH 15/30] Remove typo from comment in smtp_ehlo_117.go Fixed a typo in the backward compatibility comment for Go 1.16/1.17 in smtp_ehlo_117.go. This ensures clarity and correctness in documentation. --- smtp/smtp_ehlo_117.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/smtp/smtp_ehlo_117.go b/smtp/smtp_ehlo_117.go index 429f30a..c516a36 100644 --- a/smtp/smtp_ehlo_117.go +++ b/smtp/smtp_ehlo_117.go @@ -22,7 +22,7 @@ import "strings" // should be the preferred greeting for servers that support it. // // Backport of: https://github.com/golang/go/commit/4d8db00641cc9ff4f44de7df9b8c4f4a4f9416ee#diff-4f6f6bdb9891d4dd271f9f31430420a2e44018fe4ee539576faf458bebb3cee4 -// to guarantee backwards compatibility with Go 1.16/1.17:w +// to guarantee backwards compatibility with Go 1.16/1.17 func (c *Client) ehlo() error { _, msg, err := c.cmd(250, "EHLO %s", c.localName) if err != nil { From 8683917c3dec1316b9349ca8546eeb9d5d40d3f6 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Thu, 26 Sep 2024 11:49:48 +0200 Subject: [PATCH 16/30] Delete connection pool implementation and tests Remove `connpool.go` and `connpool_test.go`. This eliminates the connection pool feature from the codebase, including associated functionality and tests. The connection pool feature is much to complex and doesn't provide the benefits expected by the concurrency feature --- connpool.go | 215 ------------------------ connpool_test.go | 423 ----------------------------------------------- 2 files changed, 638 deletions(-) delete mode 100644 connpool.go delete mode 100644 connpool_test.go diff --git a/connpool.go b/connpool.go deleted file mode 100644 index f50e73d..0000000 --- a/connpool.go +++ /dev/null @@ -1,215 +0,0 @@ -// SPDX-FileCopyrightText: 2022-2024 The go-mail Authors -// -// SPDX-License-Identifier: MIT - -package mail - -import ( - "context" - "errors" - "fmt" - "net" - "sync" -) - -// Parts of the connection pool code is forked/took inspiration from https://github.com/fatih/pool/ -// Thanks to Fatih Arslan and the project contributors for providing this great concurrency template. - -var ( - // ErrClosed is returned when an operation is attempted on a closed connection pool. - ErrClosed = errors.New("connection pool is closed") - // ErrNilConn is returned when a nil connection is passed back to the connection pool. - ErrNilConn = errors.New("connection is nil") - // ErrPoolInvalidCap is returned when the connection pool's capacity settings are - // invalid (e.g., initial capacity is negative). - ErrPoolInvalidCap = errors.New("invalid connection pool capacity settings") -) - -// Pool interface describes a connection pool implementation. A Pool is -// thread-/go-routine safe. -type Pool interface { - // Get returns a new connection from the pool. Closing the connections returns - // it back into the Pool. Closing a connection when the Pool is destroyed or - // full will be counted as an error. - Get(ctx context.Context) (net.Conn, error) - - // Close closes the pool and all its connections. After Close() the pool is - // no longer usable. - Close() - - // Size returns the current number of connections of the pool. - Size() int -} - -// connPool implements the Pool interface -type connPool struct { - // mutex is used to synchronize access to the connection pool to ensure thread-safe operations. - mutex sync.RWMutex - // conns is a channel used to manage and distribute net.Conn objects within the connection pool. - conns chan net.Conn - - // dialCtxFunc represents the actual net.Conn returned by the DialContextFunc. - dialCtxFunc DialContextFunc - // dialNetwork specifies the network type (e.g., "tcp", "udp") used to establish connections in - // the connection pool. - dialNetwork string - // dialAddress specifies the address used to establish network connections within the connection pool. - dialAddress string -} - -// PoolConn is a wrapper around net.Conn to modify the the behavior of net.Conn's Close() method. -type PoolConn struct { - net.Conn - mutex sync.RWMutex - pool *connPool - unusable bool -} - -// Close puts a given pool connection back to the pool instead of closing it. -func (c *PoolConn) Close() error { - c.mutex.RLock() - defer c.mutex.RUnlock() - - if c.unusable { - if c.Conn != nil { - return c.Conn.Close() - } - return nil - } - return c.pool.put(c.Conn) -} - -// MarkUnusable marks the connection not usable any more, to let the pool close it instead -// of returning it to pool. -func (c *PoolConn) MarkUnusable() { - c.mutex.Lock() - c.unusable = true - c.mutex.Unlock() -} - -// NewConnPool returns a new pool based on buffered channels with an initial -// capacity and maximum capacity. The DialContextFunc is used when the initial -// capacity is greater than zero to fill the pool. A zero initialCap doesn't -// fill the Pool until a new Get() is called. During a Get(), if there is no -// new connection available in the pool, a new connection will be created via -// the corresponding DialContextFunc() method. -func NewConnPool(ctx context.Context, initialCap, maxCap int, dialCtxFunc DialContextFunc, - network, address string, -) (Pool, error) { - if initialCap < 0 || maxCap <= 0 || initialCap > maxCap { - return nil, ErrPoolInvalidCap - } - - pool := &connPool{ - conns: make(chan net.Conn, maxCap), - dialCtxFunc: dialCtxFunc, - dialAddress: address, - dialNetwork: network, - } - - // Initial connections for the pool. Pool will be closed on connection error - for i := 0; i < initialCap; i++ { - conn, err := dialCtxFunc(ctx, network, address) - if err != nil { - pool.Close() - return nil, fmt.Errorf("dialContextFunc is not able to fill the connection pool: %w", err) - } - pool.conns <- conn - } - - return pool, nil -} - -// Get satisfies the Get() method of the Pool inteface. If there is no new -// connection available in the Pool, a new connection will be created via the -// DialContextFunc() method. -func (p *connPool) Get(ctx context.Context) (net.Conn, error) { - conns, dialCtxFunc := p.getConnsAndDialContext() - if conns == nil { - return nil, ErrClosed - } - - // wrap the connections into the custom net.Conn implementation that puts - // connections back to the pool - select { - case <-ctx.Done(): - return nil, fmt.Errorf("failed to get connection: %w", ctx.Err()) - case conn := <-conns: - if conn == nil { - return nil, ErrClosed - } - return p.wrapConn(conn), nil - default: - conn, err := dialCtxFunc(ctx, p.dialNetwork, p.dialAddress) - if err != nil { - return nil, fmt.Errorf("dialContextFunc failed: %w", err) - } - return p.wrapConn(conn), nil - } -} - -// Close terminates all connections in the pool and frees associated resources. Once closed, -// the pool is no longer usable. -func (p *connPool) Close() { - p.mutex.Lock() - conns := p.conns - p.conns = nil - p.dialCtxFunc = nil - p.dialAddress = "" - p.dialNetwork = "" - p.mutex.Unlock() - - if conns == nil { - return - } - - close(conns) - for conn := range conns { - _ = conn.Close() - } -} - -// Size returns the current number of connections in the connection pool. -func (p *connPool) Size() int { - conns, _ := p.getConnsAndDialContext() - return len(conns) -} - -// getConnsAndDialContext returns the connection channel and the DialContext function for the -// connection pool. -func (p *connPool) getConnsAndDialContext() (chan net.Conn, DialContextFunc) { - p.mutex.RLock() - conns := p.conns - dialCtxFunc := p.dialCtxFunc - p.mutex.RUnlock() - return conns, dialCtxFunc -} - -// put puts a passed connection back into the pool. If the pool is full or closed, -// conn is simply closed. A nil conn will be rejected with an error. -func (p *connPool) put(conn net.Conn) error { - if conn == nil { - return ErrNilConn - } - - p.mutex.RLock() - defer p.mutex.RUnlock() - - if p.conns == nil { - return conn.Close() - } - - select { - case p.conns <- conn: - return nil - default: - return conn.Close() - } -} - -// wrapConn wraps a given net.Conn with a PoolConn, modifying the net.Conn's Close() method. -func (p *connPool) wrapConn(conn net.Conn) net.Conn { - poolconn := &PoolConn{pool: p} - poolconn.Conn = conn - return poolconn -} diff --git a/connpool_test.go b/connpool_test.go deleted file mode 100644 index 42b26c8..0000000 --- a/connpool_test.go +++ /dev/null @@ -1,423 +0,0 @@ -// SPDX-FileCopyrightText: 2022-2024 The go-mail Authors -// -// SPDX-License-Identifier: MIT - -package mail - -import ( - "context" - "fmt" - "net" - "sync" - "testing" - "time" -) - -func TestNewConnPool(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - serverPort := TestServerPortBase + 10 - featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" - go func() { - if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil { - t.Errorf("failed to start test server: %s", err) - return - } - }() - time.Sleep(time.Millisecond * 300) - - pool, err := newConnPool(serverPort) - if err != nil { - t.Errorf("failed to create connection pool: %s", err) - } - defer pool.Close() - if pool == nil { - t.Errorf("connection pool is nil") - return - } - if pool.Size() != 5 { - t.Errorf("expected 5 connections, got %d", pool.Size()) - } - conn, err := pool.Get(context.Background()) - if err != nil { - t.Errorf("failed to get connection: %s", err) - } - if _, err := conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil { - t.Errorf("failed to write quit command to first connection: %s", err) - } -} - -func TestConnPool_Get_Type(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - serverPort := TestServerPortBase + 11 - featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" - go func() { - if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil { - t.Errorf("failed to start test server: %s", err) - return - } - }() - time.Sleep(time.Millisecond * 300) - - pool, err := newConnPool(serverPort) - if err != nil { - t.Errorf("failed to create connection pool: %s", err) - } - defer pool.Close() - - conn, err := pool.Get(context.Background()) - if err != nil { - t.Errorf("failed to get new connection from pool: %s", err) - return - } - - _, ok := conn.(*PoolConn) - if !ok { - t.Error("received connection from pool is not of type PoolConn") - return - } - if _, err := conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil { - t.Errorf("failed to write quit command to first connection: %s", err) - } -} - -func TestConnPool_Get(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - serverPort := TestServerPortBase + 12 - featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" - go func() { - if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil { - t.Errorf("failed to start test server: %s", err) - return - } - }() - time.Sleep(time.Millisecond * 300) - - p, _ := newConnPool(serverPort) - defer p.Close() - - conn, err := p.Get(context.Background()) - if err != nil { - t.Errorf("failed to get new connection from pool: %s", err) - return - } - if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil { - t.Errorf("failed to write quit command to first connection: %s", err) - } - - if p.Size() != 4 { - t.Errorf("getting new connection from pool failed. Expected pool size: 4, got %d", p.Size()) - } - - var wg sync.WaitGroup - for i := 0; i < 4; i++ { - wg.Add(1) - go func() { - defer wg.Done() - wgconn, err := p.Get(context.Background()) - if err != nil { - t.Errorf("failed to get new connection from pool: %s", err) - } - if _, err = wgconn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil { - t.Errorf("failed to write quit command to first connection: %s", err) - } - }() - } - wg.Wait() - - if p.Size() != 0 { - t.Errorf("Get error. Expecting 0, got %d", p.Size()) - } - - conn, err = p.Get(context.Background()) - if err != nil { - t.Errorf("failed to get new connection from pool: %s", err) - } - if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil { - t.Errorf("failed to write quit command to first connection: %s", err) - } - p.Close() -} - -func TestPoolConn_Close(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - serverPort := TestServerPortBase + 13 - featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" - go func() { - if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil { - t.Errorf("failed to start test server: %s", err) - return - } - }() - time.Sleep(time.Millisecond * 300) - - netDialer := net.Dialer{} - p, err := NewConnPool(context.Background(), 0, 30, netDialer.DialContext, "tcp", - fmt.Sprintf("127.0.0.1:%d", serverPort)) - if err != nil { - t.Errorf("failed to create connection pool: %s", err) - } - defer p.Close() - - conns := make([]net.Conn, 30) - for i := 0; i < 30; i++ { - conn, _ := p.Get(context.Background()) - if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil { - t.Errorf("failed to write quit command to first connection: %s", err) - } - conns[i] = conn - } - for _, conn := range conns { - if err = conn.Close(); err != nil { - t.Errorf("failed to close connection: %s", err) - } - } - - if p.Size() != 30 { - t.Errorf("failed to return all connections to pool. Expected pool size: 30, got %d", p.Size()) - } - - conn, err := p.Get(context.Background()) - if err != nil { - t.Errorf("failed to get new connection from pool: %s", err) - } - if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil { - t.Errorf("failed to write quit command to first connection: %s", err) - } - p.Close() - - if err = conn.Close(); err != nil { - t.Errorf("failed to close connection: %s", err) - } - if p.Size() != 0 { - t.Errorf("closed pool shouldn't allow to put connections.") - } -} - -func TestPoolConn_MarkUnusable(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - serverPort := TestServerPortBase + 14 - featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" - go func() { - if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil { - t.Errorf("failed to start test server: %s", err) - return - } - }() - time.Sleep(time.Millisecond * 300) - - pool, _ := newConnPool(serverPort) - defer pool.Close() - - conn, err := pool.Get(context.Background()) - if err != nil { - t.Errorf("failed to get new connection from pool: %s", err) - } - if err = conn.Close(); err != nil { - t.Errorf("failed to close connection: %s", err) - } - - poolSize := pool.Size() - conn, err = pool.Get(context.Background()) - if err != nil { - t.Errorf("failed to get new connection from pool: %s", err) - } - if err = conn.Close(); err != nil { - t.Errorf("failed to close connection: %s", err) - } - if pool.Size() != poolSize { - t.Errorf("pool size is expected to be equal to initial size") - } - - conn, err = pool.Get(context.Background()) - if err != nil { - t.Errorf("failed to get new connection from pool: %s", err) - } - if pc, ok := conn.(*PoolConn); !ok { - t.Errorf("this should never happen") - } else { - pc.MarkUnusable() - } - if err = conn.Close(); err != nil { - t.Errorf("failed to close connection: %s", err) - } - if pool.Size() != poolSize-1 { - t.Errorf("pool size is expected to be: %d but got: %d", poolSize-1, pool.Size()) - } -} - -func TestConnPool_Close(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - serverPort := TestServerPortBase + 15 - featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" - go func() { - if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil { - t.Errorf("failed to start test server: %s", err) - return - } - }() - time.Sleep(time.Millisecond * 300) - - pool, err := newConnPool(serverPort) - if err != nil { - t.Errorf("failed to create connection pool: %s", err) - } - pool.Close() - - castPool := pool.(*connPool) - - if castPool.conns != nil { - t.Error("closing pool failed: conns channel should be nil") - } - if castPool.dialCtxFunc != nil { - t.Error("closing pool failed: dialCtxFunc should be nil") - } - if castPool.dialAddress != "" { - t.Error("closing pool failed: dialAddress should be empty") - } - if castPool.dialNetwork != "" { - t.Error("closing pool failed: dialNetwork should be empty") - } - - conn, err := pool.Get(context.Background()) - if err == nil { - t.Errorf("closing pool failed: getting new connection should return an error") - } - if conn != nil { - t.Errorf("closing pool failed: getting new connection should return a nil-connection") - } - if pool.Size() != 0 { - t.Errorf("closing pool failed: pool size should be 0, but got: %d", pool.Size()) - } -} - -func TestConnPool_Concurrency(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - serverPort := TestServerPortBase + 16 - featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" - go func() { - if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil { - t.Errorf("failed to start test server: %s", err) - return - } - }() - time.Sleep(time.Millisecond * 300) - - pool, err := newConnPool(serverPort) - if err != nil { - t.Errorf("failed to create connection pool: %s", err) - } - defer pool.Close() - pipe := make(chan net.Conn) - - getWg := sync.WaitGroup{} - closeWg := sync.WaitGroup{} - for i := 0; i < 30; i++ { - getWg.Add(1) - closeWg.Add(1) - go func() { - conn, err := pool.Get(context.Background()) - if err != nil { - t.Errorf("failed to get new connection from pool: %s", err) - } - pipe <- conn - getWg.Done() - }() - - go func() { - conn := <-pipe - if conn == nil { - return - } - if err = conn.Close(); err != nil { - t.Errorf("failed to close connection: %s", err) - } - closeWg.Done() - }() - getWg.Wait() - closeWg.Wait() - } -} - -func TestConnPool_GetContextCancel(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - serverPort := TestServerPortBase + 17 - featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" - go func() { - if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil { - t.Errorf("failed to start test server: %s", err) - return - } - }() - time.Sleep(time.Millisecond * 300) - - p, err := newConnPool(serverPort) - if err != nil { - t.Errorf("failed to create connection pool: %s", err) - } - defer p.Close() - - connCtx, connCancel := context.WithCancel(context.Background()) - defer connCancel() - - conn, err := p.Get(connCtx) - if err != nil { - t.Errorf("failed to get new connection from pool: %s", err) - return - } - if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil { - t.Errorf("failed to write quit command to first connection: %s", err) - } - - if p.Size() != 4 { - t.Errorf("getting new connection from pool failed. Expected pool size: 4, got %d", p.Size()) - } - - var wg sync.WaitGroup - for i := 0; i < 4; i++ { - wg.Add(1) - go func() { - defer wg.Done() - wgconn, err := p.Get(connCtx) - if err != nil { - t.Errorf("failed to get new connection from pool: %s", err) - } - if _, err = wgconn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil { - t.Errorf("failed to write quit command to first connection: %s", err) - } - }() - } - wg.Wait() - - if p.Size() != 0 { - t.Errorf("Get error. Expecting 0, got %d", p.Size()) - } - - connCancel() - _, err = p.Get(connCtx) - if err == nil { - t.Errorf("getting new connection on canceled context should fail, but didn't") - } - p.Close() -} - -func newConnPool(port int) (Pool, error) { - netDialer := net.Dialer{} - return NewConnPool(context.Background(), 5, 30, netDialer.DialContext, "tcp", - fmt.Sprintf("127.0.0.1:%d", port)) -} From 3871b2be44124484903057d4973ddf05488be071 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Thu, 26 Sep 2024 11:51:30 +0200 Subject: [PATCH 17/30] Lock client connections and update deadline handling Add mutex locking for client connections to ensure thread safety. Introduce `HasConnection` method to check active connections and `UpdateDeadline` method to handle timeout updates. Refactor connection handling in `checkConn` and `tls` methods accordingly. --- client.go | 18 +++++++++++------- client_120.go | 2 ++ smtp/smtp.go | 14 ++++++++++++++ 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/client.go b/client.go index fac9a34..77f9751 100644 --- a/client.go +++ b/client.go @@ -12,6 +12,7 @@ import ( "net" "os" "strings" + "sync" "time" "github.com/wneessen/go-mail/log" @@ -87,6 +88,7 @@ type DialContextFunc func(ctx context.Context, network, address string) (net.Con // Client is the SMTP client struct type Client struct { + mutex sync.RWMutex // connection is the net.Conn that the smtp.Client is based on connection net.Conn @@ -589,6 +591,9 @@ func (c *Client) setDefaultHelo() error { // DialWithContext establishes a connection to the SMTP server with a given context.Context func (c *Client) DialWithContext(dialCtx context.Context) error { + c.mutex.Lock() + defer c.mutex.Unlock() + ctx, cancel := context.WithDeadline(dialCtx, time.Now().Add(c.connTimeout)) defer cancel() @@ -602,17 +607,16 @@ func (c *Client) DialWithContext(dialCtx context.Context) error { c.dialContextFunc = tlsDialer.DialContext } } - var err error - c.connection, err = c.dialContextFunc(ctx, "tcp", c.ServerAddr()) + connection, err := c.dialContextFunc(ctx, "tcp", c.ServerAddr()) if err != nil && c.fallbackPort != 0 { // TODO: should we somehow log or append the previous error? - c.connection, err = c.dialContextFunc(ctx, "tcp", c.serverFallbackAddr()) + connection, err = c.dialContextFunc(ctx, "tcp", c.serverFallbackAddr()) } if err != nil { return err } - client, err := smtp.NewClient(c.connection, c.host) + client, err := smtp.NewClient(connection, c.host) if err != nil { return err } @@ -691,7 +695,7 @@ func (c *Client) DialAndSendWithContext(ctx context.Context, messages ...*Msg) e // checkConn makes sure that a required server connection is available and extends the // connection deadline func (c *Client) checkConn() error { - if c.connection == nil { + if !c.smtpClient.HasConnection() { return ErrNoActiveConnection } @@ -701,7 +705,7 @@ func (c *Client) checkConn() error { } } - if err := c.connection.SetDeadline(time.Now().Add(c.connTimeout)); err != nil { + if err := c.smtpClient.UpdateDeadline(c.connTimeout); err != nil { return ErrDeadlineExtendFailed } return nil @@ -715,7 +719,7 @@ func (c *Client) serverFallbackAddr() string { // tls tries to make sure that the STARTTLS requirements are satisfied func (c *Client) tls() error { - if c.connection == nil { + if !c.smtpClient.HasConnection() { return ErrNoActiveConnection } if !c.useSSL && c.tlspolicy != NoTLS { diff --git a/client_120.go b/client_120.go index 4f82aa7..729069b 100644 --- a/client_120.go +++ b/client_120.go @@ -13,6 +13,8 @@ import ( // Send sends out the mail message func (c *Client) Send(messages ...*Msg) (returnErr error) { + c.mutex.Lock() + defer c.mutex.Unlock() if err := c.checkConn(); err != nil { returnErr = &SendError{Reason: ErrConnCheck, errlist: []error{err}, isTemp: isTempError(err)} return diff --git a/smtp/smtp.go b/smtp/smtp.go index d2a0e64..5f5484a 100644 --- a/smtp/smtp.go +++ b/smtp/smtp.go @@ -30,6 +30,7 @@ import ( "net/textproto" "os" "strings" + "time" "github.com/wneessen/go-mail/log" ) @@ -472,6 +473,19 @@ func (c *Client) SetDSNRcptNotifyOption(d string) { c.dsnrntype = d } +// HasConnection checks if the client has an active connection. +// Returns true if the `conn` field is not nil, indicating an active connection. +func (c *Client) HasConnection() bool { + return c.conn != nil +} + +func (c *Client) UpdateDeadline(timeout time.Duration) error { + if err := c.conn.SetDeadline(time.Now().Add(timeout)); err != nil { + return fmt.Errorf("smtp: failed to update deadline: %w", err) + } + return nil +} + // debugLog checks if the debug flag is set and if so logs the provided message to // the log.Logger interface func (c *Client) debugLog(d log.Direction, f string, a ...interface{}) { From 371b950bc76884ca6b6700afad29f9a876ec3752 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Fri, 27 Sep 2024 10:33:19 +0200 Subject: [PATCH 18/30] Refactor Client struct for better readability and organization Reordered and grouped fields in the Client struct for clarity. The reorganization separates logical groups of fields, making it easier to understand and maintain the code. This includes proper grouping of TLS parameters, DSN options, and debug settings. --- smtp/smtp.go | 55 ++++++++++++++++++++++++++++++++++------------------ 1 file changed, 36 insertions(+), 19 deletions(-) diff --git a/smtp/smtp.go b/smtp/smtp.go index 5f5484a..787006c 100644 --- a/smtp/smtp.go +++ b/smtp/smtp.go @@ -37,28 +37,45 @@ import ( // A Client represents a client connection to an SMTP server. type Client struct { - // Text is the textproto.Conn used by the Client. It is exported to allow for - // clients to add extensions. + // Text is the textproto.Conn used by the Client. It is exported to allow for clients to add extensions. Text *textproto.Conn - // keep a reference to the connection so it can be used to create a TLS - // connection later + + // auth supported auth mechanisms + auth []string + + // keep a reference to the connection so it can be used to create a TLS connection later conn net.Conn - // whether the Client is using TLS - tls bool - serverName string - // map of supported extensions + + // debug logging is enabled + debug bool + + // didHello indicates whether we've said HELO/EHLO + didHello bool + + // dsnmrtype defines the mail return option in case DSN is enabled + dsnmrtype string + + // dsnrntype defines the recipient notify option in case DSN is enabled + dsnrntype string + + // ext is a map of supported extensions ext map[string]string - // supported auth mechanisms - auth []string - localName string // the name to use in HELO/EHLO - didHello bool // whether we've said HELO/EHLO - helloError error // the error from the hello - // debug logging - debug bool // debug logging is enabled - logger log.Logger // logger will be used for debug logging - // DSN support - dsnmrtype string // dsnmrtype defines the mail return option in case DSN is enabled - dsnrntype string // dsnrntype defines the recipient notify option in case DSN is enabled + + // helloError is the error from the hello + helloError error + + // localName is the name to use in HELO/EHLO + localName string // the name to use in HELO/EHLO + + // logger will be used for debug logging + logger log.Logger + + // tls indicates whether the Client is using TLS + tls bool + + // serverName denotes the name of the server to which the application will connect. Used for + // identification and routing. + serverName string } // Dial returns a new [Client] connected to an SMTP server at addr. From 23c71d608f5540da3120176e004448cf833241e6 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Fri, 27 Sep 2024 10:33:28 +0200 Subject: [PATCH 19/30] Lock mutex before checking connection in Send method Added mutex locking in the `Send` method for both `client_120.go` and `client_119.go`. This ensures thread-safe access to the connection checks and prevents potential race conditions. --- client_119.go | 3 +++ client_120.go | 1 + 2 files changed, 4 insertions(+) diff --git a/client_119.go b/client_119.go index 7de5d59..0b05061 100644 --- a/client_119.go +++ b/client_119.go @@ -11,6 +11,9 @@ import "errors" // Send sends out the mail message func (c *Client) Send(messages ...*Msg) error { + c.mutex.Lock() + defer c.mutex.Unlock() + if err := c.checkConn(); err != nil { return &SendError{Reason: ErrConnCheck, errlist: []error{err}, isTemp: isTempError(err)} } diff --git a/client_120.go b/client_120.go index 729069b..5bb291d 100644 --- a/client_120.go +++ b/client_120.go @@ -15,6 +15,7 @@ import ( func (c *Client) Send(messages ...*Msg) (returnErr error) { c.mutex.Lock() defer c.mutex.Unlock() + if err := c.checkConn(); err != nil { returnErr = &SendError{Reason: ErrConnCheck, errlist: []error{err}, isTemp: isTempError(err)} return From 2084526c772f02a3e29ce5dac3f69c5b10f4afdf Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Fri, 27 Sep 2024 10:36:09 +0200 Subject: [PATCH 20/30] Refactor Client struct to improve organization and clarity Rearranged and grouped struct fields more logically within Client. Introduced the dialContextFunc and fallbackPort fields to enhance connection flexibility. Minor code style adjustments were also made for better readability. --- client.go | 44 +++++++++++++++++++++++++------------------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/client.go b/client.go index 77f9751..d898708 100644 --- a/client.go +++ b/client.go @@ -88,13 +88,15 @@ type DialContextFunc func(ctx context.Context, network, address string) (net.Con // Client is the SMTP client struct type Client struct { - mutex sync.RWMutex // connection is the net.Conn that the smtp.Client is based on connection net.Conn // Timeout for the SMTP server connection connTimeout time.Duration + // dialContextFunc is a custom DialContext function to dial target SMTP server + dialContextFunc DialContextFunc + // dsn indicates that we want to use DSN for the Client dsn bool @@ -104,11 +106,9 @@ type Client struct { // dsnrntype defines the DSNRcptNotifyOption in case DSN is enabled dsnrntype []string - // isEncrypted indicates if a Client connection is encrypted or not - isEncrypted bool - - // noNoop indicates the Noop is to be skipped - noNoop bool + // fallbackPort is used as an alternative port number in case the primary port is unavailable or + // fails to bind. + fallbackPort int // HELO/EHLO string for the greeting the target SMTP server helo string @@ -116,12 +116,24 @@ type Client struct { // Hostname of the target SMTP server to connect to host string + // isEncrypted indicates if a Client connection is encrypted or not + isEncrypted bool + + // logger is a logger that implements the log.Logger interface + logger log.Logger + + // mutex is used to synchronize access to shared resources, ensuring that only one goroutine can + // modify them at a time. + mutex sync.RWMutex + + // noNoop indicates the Noop is to be skipped + noNoop bool + // pass is the corresponding SMTP AUTH password pass string - // Port of the SMTP server to connect to - port int - fallbackPort int + // port specifies the network port number on which the server listens for incoming connections. + port int // smtpAuth is a pointer to smtp.Auth smtpAuth smtp.Auth @@ -132,26 +144,20 @@ type Client struct { // smtpClient is the smtp.Client that is set up when using the Dial*() methods smtpClient *smtp.Client - // Use SSL for the connection - useSSL bool - // tlspolicy sets the client to use the provided TLSPolicy for the STARTTLS protocol tlspolicy TLSPolicy // tlsconfig represents the tls.Config setting for the STARTTLS connection tlsconfig *tls.Config - // user is the SMTP AUTH username - user string - // useDebugLog enables the debug logging on the SMTP client useDebugLog bool - // logger is a logger that implements the log.Logger interface - logger log.Logger + // user is the SMTP AUTH username + user string - // dialContextFunc is a custom DialContext function to dial target SMTP server - dialContextFunc DialContextFunc + // Use SSL for the connection + useSSL bool } // Option returns a function that can be used for grouping Client options From fec2f2075aefe22e399921b89fc57f9f4a8e18ae Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Fri, 27 Sep 2024 10:52:30 +0200 Subject: [PATCH 21/30] Update build tags to support future Go versions Modified the build tags to exclude Go 1.20 and above instead of targeting only Go 1.19. This change ensures the code is compatible with future versions of Go by not restricting it to a specific minor version. --- random_119.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/random_119.go b/random_119.go index b084305..4b45c55 100644 --- a/random_119.go +++ b/random_119.go @@ -2,8 +2,8 @@ // // SPDX-License-Identifier: MIT -//go:build go1.19 && !go1.20 -// +build go1.19,!go1.20 +//go:build !go1.20 +// +build !go1.20 package mail From fdb80ad9ddc2dc267805fe3201645dc4ca0b72c2 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Fri, 27 Sep 2024 11:10:23 +0200 Subject: [PATCH 22/30] Add mutex to Client for thread-safe operations This commit introduces a RWMutex to the Client struct in the smtp package to ensure thread-safe access to shared resources. Critical sections in methods like Close, StartTLS, and cmd are now protected with appropriate locking mechanisms. This change helps prevent potential race conditions, ensuring consistent and reliable behavior in concurrent environments. --- smtp/smtp.go | 41 +++++++++++++++++++++++++++++++++++++++++ smtp/smtp_ehlo.go | 3 +++ smtp/smtp_ehlo_117.go | 3 +++ 3 files changed, 47 insertions(+) diff --git a/smtp/smtp.go b/smtp/smtp.go index 787006c..1f1d603 100644 --- a/smtp/smtp.go +++ b/smtp/smtp.go @@ -30,6 +30,7 @@ import ( "net/textproto" "os" "strings" + "sync" "time" "github.com/wneessen/go-mail/log" @@ -70,6 +71,10 @@ type Client struct { // logger will be used for debug logging logger log.Logger + // mutex is used to synchronize access to shared resources, ensuring that only one goroutine can access + // the resource at a time. + mutex sync.RWMutex + // tls indicates whether the Client is using TLS tls bool @@ -112,6 +117,9 @@ func NewClient(conn net.Conn, host string) (*Client, error) { // Close closes the connection. func (c *Client) Close() error { + c.mutex.Lock() + defer c.mutex.Unlock() + return c.Text.Close() } @@ -139,12 +147,19 @@ func (c *Client) Hello(localName string) error { if c.didHello { return errors.New("smtp: Hello called after other methods") } + + c.mutex.Lock() c.localName = localName + c.mutex.Unlock() + return c.hello() } // cmd is a convenience function that sends a command and returns the response func (c *Client) cmd(expectCode int, format string, args ...interface{}) (int, string, error) { + c.mutex.Lock() + defer c.mutex.Unlock() + c.debugLog(log.DirClientToServer, format, args...) id, err := c.Text.Cmd(format, args...) if err != nil { @@ -160,7 +175,10 @@ func (c *Client) cmd(expectCode int, format string, args ...interface{}) (int, s // helo sends the HELO greeting to the server. It should be used only when the // server does not support ehlo. func (c *Client) helo() error { + c.mutex.Lock() c.ext = nil + c.mutex.Unlock() + _, _, err := c.cmd(250, "HELO %s", c.localName) return err } @@ -175,9 +193,13 @@ func (c *Client) StartTLS(config *tls.Config) error { if err != nil { return err } + + c.mutex.Lock() c.conn = tls.Client(c.conn, config) c.Text = textproto.NewConn(c.conn) c.tls = true + c.mutex.Unlock() + return c.ehlo() } @@ -185,6 +207,9 @@ func (c *Client) StartTLS(config *tls.Config) error { // The return values are their zero values if [Client.StartTLS] did // not succeed. func (c *Client) TLSConnectionState() (state tls.ConnectionState, ok bool) { + c.mutex.RLock() + defer c.mutex.RUnlock() + tc, ok := c.conn.(*tls.Conn) if !ok { return @@ -249,7 +274,9 @@ func (c *Client) Auth(a Auth) error { // abort the AUTH. Not required for XOAUTH2 _, _, _ = c.cmd(501, "*") } + c.mutex.Lock() _ = c.Quit() + c.mutex.Unlock() break } if resp == nil { @@ -275,6 +302,8 @@ func (c *Client) Mail(from string) error { return err } cmdStr := "MAIL FROM:<%s>" + + c.mutex.RLock() if c.ext != nil { if _, ok := c.ext["8BITMIME"]; ok { cmdStr += " BODY=8BITMIME" @@ -287,6 +316,8 @@ func (c *Client) Mail(from string) error { cmdStr += fmt.Sprintf(" RET=%s", c.dsnmrtype) } } + c.mutex.RUnlock() + _, _, err := c.cmd(250, cmdStr, from) return err } @@ -298,7 +329,11 @@ func (c *Client) Rcpt(to string) error { if err := validateLine(to); err != nil { return err } + + c.mutex.RLock() _, ok := c.ext["DSN"] + c.mutex.RUnlock() + if ok && c.dsnrntype != "" { _, _, err := c.cmd(25, "RCPT TO:<%s> NOTIFY=%s", to, c.dsnrntype) return err @@ -423,6 +458,9 @@ func (c *Client) Extension(ext string) (bool, string) { return false, "" } ext = strings.ToUpper(ext) + + c.mutex.RLock() + defer c.mutex.RUnlock() param, ok := c.ext[ext] return ok, param } @@ -497,6 +535,9 @@ func (c *Client) HasConnection() bool { } func (c *Client) UpdateDeadline(timeout time.Duration) error { + c.mutex.Lock() + defer c.mutex.Unlock() + if err := c.conn.SetDeadline(time.Now().Add(timeout)); err != nil { return fmt.Errorf("smtp: failed to update deadline: %w", err) } diff --git a/smtp/smtp_ehlo.go b/smtp/smtp_ehlo.go index ae80a62..457be57 100644 --- a/smtp/smtp_ehlo.go +++ b/smtp/smtp_ehlo.go @@ -25,6 +25,9 @@ func (c *Client) ehlo() error { if err != nil { return err } + + c.mutex.Lock() + defer c.mutex.Unlock() ext := make(map[string]string) extList := strings.Split(msg, "\n") if len(extList) > 1 { diff --git a/smtp/smtp_ehlo_117.go b/smtp/smtp_ehlo_117.go index c516a36..c40297f 100644 --- a/smtp/smtp_ehlo_117.go +++ b/smtp/smtp_ehlo_117.go @@ -28,6 +28,9 @@ func (c *Client) ehlo() error { if err != nil { return err } + + c.mutex.Lock() + defer c.mutex.Unlock() ext := make(map[string]string) extList := strings.Split(msg, "\n") if len(extList) > 1 { From 2234f0c5bc67916863de1858834bf26a72432a25 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Fri, 27 Sep 2024 11:43:22 +0200 Subject: [PATCH 23/30] Remove connection field from Client struct This commit removes the 'connection' field from the 'Client' struct and updates the related test logic accordingly. By using 'smtpClient.HasConnection()' to check for connections, code readability and maintainability are improved. All necessary test cases have been adjusted to reflect this change. --- client.go | 3 -- client_test.go | 77 +++++++++++++++++++++++++------------------------- 2 files changed, 38 insertions(+), 42 deletions(-) diff --git a/client.go b/client.go index d898708..81ff065 100644 --- a/client.go +++ b/client.go @@ -88,9 +88,6 @@ type DialContextFunc func(ctx context.Context, network, address string) (net.Con // Client is the SMTP client struct type Client struct { - // connection is the net.Conn that the smtp.Client is based on - connection net.Conn - // Timeout for the SMTP server connection connTimeout time.Duration diff --git a/client_test.go b/client_test.go index c891a10..7706b0f 100644 --- a/client_test.go +++ b/client_test.go @@ -623,11 +623,12 @@ func TestClient_DialWithContext(t *testing.T) { t.Errorf("failed to dial with context: %s", err) return } - if c.connection == nil { - t.Errorf("DialWithContext didn't fail but no connection found.") - } if c.smtpClient == nil { t.Errorf("DialWithContext didn't fail but no SMTP client found.") + return + } + if !c.smtpClient.HasConnection() { + t.Errorf("DialWithContext didn't fail but no connection found.") } if err := c.Close(); err != nil { t.Errorf("failed to close connection: %s", err) @@ -644,17 +645,18 @@ func TestClient_DialWithContext_Fallback(t *testing.T) { c.SetTLSPortPolicy(TLSOpportunistic) c.port = 999 ctx := context.Background() - if err := c.DialWithContext(ctx); err != nil { + if err = c.DialWithContext(ctx); err != nil { t.Errorf("failed to dial with context: %s", err) return } - if c.connection == nil { - t.Errorf("DialWithContext didn't fail but no connection found.") - } if c.smtpClient == nil { t.Errorf("DialWithContext didn't fail but no SMTP client found.") + return } - if err := c.Close(); err != nil { + if !c.smtpClient.HasConnection() { + t.Errorf("DialWithContext didn't fail but no connection found.") + } + if err = c.Close(); err != nil { t.Errorf("failed to close connection: %s", err) } @@ -674,18 +676,19 @@ func TestClient_DialWithContext_Debug(t *testing.T) { t.Skipf("failed to create test client: %s. Skipping tests", err) } ctx := context.Background() - if err := c.DialWithContext(ctx); err != nil { + if err = c.DialWithContext(ctx); err != nil { t.Errorf("failed to dial with context: %s", err) return } - if c.connection == nil { - t.Errorf("DialWithContext didn't fail but no connection found.") - } if c.smtpClient == nil { t.Errorf("DialWithContext didn't fail but no SMTP client found.") + return + } + if !c.smtpClient.HasConnection() { + t.Errorf("DialWithContext didn't fail but no connection found.") } c.SetDebugLog(true) - if err := c.Close(); err != nil { + if err = c.Close(); err != nil { t.Errorf("failed to close connection: %s", err) } } @@ -698,19 +701,20 @@ func TestClient_DialWithContext_Debug_custom(t *testing.T) { t.Skipf("failed to create test client: %s. Skipping tests", err) } ctx := context.Background() - if err := c.DialWithContext(ctx); err != nil { + if err = c.DialWithContext(ctx); err != nil { t.Errorf("failed to dial with context: %s", err) return } - if c.connection == nil { - t.Errorf("DialWithContext didn't fail but no connection found.") - } if c.smtpClient == nil { t.Errorf("DialWithContext didn't fail but no SMTP client found.") + return + } + if !c.smtpClient.HasConnection() { + t.Errorf("DialWithContext didn't fail but no connection found.") } c.SetDebugLog(true) c.SetLogger(log.New(os.Stderr, log.LevelDebug)) - if err := c.Close(); err != nil { + if err = c.Close(); err != nil { t.Errorf("failed to close connection: %s", err) } } @@ -722,10 +726,9 @@ func TestClient_DialWithContextInvalidHost(t *testing.T) { if err != nil { t.Skipf("failed to create test client: %s. Skipping tests", err) } - c.connection = nil c.host = "invalid.addr" ctx := context.Background() - if err := c.DialWithContext(ctx); err == nil { + if err = c.DialWithContext(ctx); err == nil { t.Errorf("dial succeeded but was supposed to fail") return } @@ -738,10 +741,9 @@ func TestClient_DialWithContextInvalidHELO(t *testing.T) { if err != nil { t.Skipf("failed to create test client: %s. Skipping tests", err) } - c.connection = nil c.helo = "" ctx := context.Background() - if err := c.DialWithContext(ctx); err == nil { + if err = c.DialWithContext(ctx); err == nil { t.Errorf("dial succeeded but was supposed to fail") return } @@ -758,7 +760,7 @@ func TestClient_DialWithContextInvalidAuth(t *testing.T) { c.pass = "invalid" c.SetSMTPAuthCustom(smtp.LoginAuth("invalid", "invalid", "invalid")) ctx := context.Background() - if err := c.DialWithContext(ctx); err == nil { + if err = c.DialWithContext(ctx); err == nil { t.Errorf("dial succeeded but was supposed to fail") return } @@ -770,8 +772,7 @@ func TestClient_checkConn(t *testing.T) { if err != nil { t.Skipf("failed to create test client: %s. Skipping tests", err) } - c.connection = nil - if err := c.checkConn(); err == nil { + if err = c.checkConn(); err == nil { t.Errorf("connCheck() should fail but succeeded") } } @@ -802,21 +803,23 @@ func TestClient_DialWithContextOptions(t *testing.T) { } ctx := context.Background() - if err := c.DialWithContext(ctx); err != nil && !tt.sf { + if err = c.DialWithContext(ctx); err != nil && !tt.sf { t.Errorf("failed to dial with context: %s", err) return } if !tt.sf { - if c.connection == nil && !tt.sf { - t.Errorf("DialWithContext didn't fail but no connection found.") - } if c.smtpClient == nil && !tt.sf { t.Errorf("DialWithContext didn't fail but no SMTP client found.") + return } - if err := c.Reset(); err != nil { + if !c.smtpClient.HasConnection() && !tt.sf { + t.Errorf("DialWithContext didn't fail but no connection found.") + return + } + if err = c.Reset(); err != nil { t.Errorf("failed to reset connection: %s", err) } - if err := c.Close(); err != nil { + if err = c.Close(); err != nil { t.Errorf("failed to close connection: %s", err) } } @@ -1011,17 +1014,15 @@ func TestClient_DialSendCloseBroken(t *testing.T) { } if tt.closestart { _ = c.smtpClient.Close() - _ = c.connection.Close() } - if err := c.Send(m); err != nil && !tt.sf { + if err = c.Send(m); err != nil && !tt.sf { t.Errorf("Send() failed: %s", err) return } if tt.closeearly { _ = c.smtpClient.Close() - _ = c.connection.Close() } - if err := c.Close(); err != nil && !tt.sf { + if err = c.Close(); err != nil && !tt.sf { t.Errorf("Close() failed: %s", err) return } @@ -1071,17 +1072,15 @@ func TestClient_DialSendCloseBrokenWithDSN(t *testing.T) { } if tt.closestart { _ = c.smtpClient.Close() - _ = c.connection.Close() } - if err := c.Send(m); err != nil && !tt.sf { + if err = c.Send(m); err != nil && !tt.sf { t.Errorf("Send() failed: %s", err) return } if tt.closeearly { _ = c.smtpClient.Close() - _ = c.connection.Close() } - if err := c.Close(); err != nil && !tt.sf { + if err = c.Close(); err != nil && !tt.sf { t.Errorf("Close() failed: %s", err) return } From f59aa23ed8658c9758cd3fe17165cd1fa18e0aed Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Fri, 27 Sep 2024 11:58:08 +0200 Subject: [PATCH 24/30] Add mutex locking to SetTLSConfig This change ensures that the SetTLSConfig method is thread-safe by adding a mutex lock. The lock is acquired before any changes to the TLS configuration and released afterward to prevent concurrent access issues. --- client.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/client.go b/client.go index 81ff065..48548ce 100644 --- a/client.go +++ b/client.go @@ -555,6 +555,9 @@ func (c *Client) SetLogger(logger log.Logger) { // SetTLSConfig overrides the current *tls.Config with the given *tls.Config value func (c *Client) SetTLSConfig(tlsconfig *tls.Config) error { + c.mutex.Lock() + defer c.mutex.Unlock() + if tlsconfig == nil { return ErrInvalidTLSConfig } From 6bd9a9c73584d3a9faee41f7f870718e9e28f94d Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Fri, 27 Sep 2024 14:03:26 +0200 Subject: [PATCH 25/30] Refactor mutex usage for connection safety This commit revises locking mechanism usage around connection operations to avoid potential deadlocks and improve code clarity. Specifically, defer statements were removed and explicit unlocks were added to ensure that mutexes are properly released after critical sections. This change affects several methods, including `Close`, `cmd`, `TLSConnectionState`, `UpdateDeadline`, and newly introduced locking for concurrent data writes and reads in `dataCloser`. --- smtp/smtp.go | 49 +++++++++++++++++++++++++++++++++++-------------- 1 file changed, 35 insertions(+), 14 deletions(-) diff --git a/smtp/smtp.go b/smtp/smtp.go index 1f1d603..379f5fe 100644 --- a/smtp/smtp.go +++ b/smtp/smtp.go @@ -118,9 +118,9 @@ func NewClient(conn net.Conn, host string) (*Client, error) { // Close closes the connection. func (c *Client) Close() error { c.mutex.Lock() - defer c.mutex.Unlock() - - return c.Text.Close() + err := c.Text.Close() + c.mutex.Unlock() + return err } // hello runs a hello exchange if needed. @@ -158,17 +158,18 @@ func (c *Client) Hello(localName string) error { // cmd is a convenience function that sends a command and returns the response func (c *Client) cmd(expectCode int, format string, args ...interface{}) (int, string, error) { c.mutex.Lock() - defer c.mutex.Unlock() c.debugLog(log.DirClientToServer, format, args...) id, err := c.Text.Cmd(format, args...) if err != nil { + c.mutex.Unlock() return 0, "", err } c.Text.StartResponse(id) - defer c.Text.EndResponse(id) code, msg, err := c.Text.ReadResponse(expectCode) c.debugLog(log.DirServerToClient, "%d %s", code, msg) + c.Text.EndResponse(id) + c.mutex.Unlock() return code, msg, err } @@ -208,13 +209,14 @@ func (c *Client) StartTLS(config *tls.Config) error { // not succeed. func (c *Client) TLSConnectionState() (state tls.ConnectionState, ok bool) { c.mutex.RLock() - defer c.mutex.RUnlock() tc, ok := c.conn.(*tls.Conn) if !ok { return } - return tc.ConnectionState(), true + state, ok = tc.ConnectionState(), true + c.mutex.RUnlock() + return } // Verify checks the validity of an email address on the server. @@ -274,9 +276,7 @@ func (c *Client) Auth(a Auth) error { // abort the AUTH. Not required for XOAUTH2 _, _, _ = c.cmd(501, "*") } - c.mutex.Lock() _ = c.Quit() - c.mutex.Unlock() break } if resp == nil { @@ -347,12 +347,23 @@ type dataCloser struct { io.WriteCloser } +// Close releases the lock, closes the WriteCloser, waits for a response, and then returns any error encountered. func (d *dataCloser) Close() error { + d.c.mutex.Lock() _ = d.WriteCloser.Close() _, _, err := d.c.Text.ReadResponse(250) + d.c.mutex.Unlock() return err } +// Write writes data to the underlying WriteCloser while ensuring thread-safety by locking and unlocking a mutex. +func (d *dataCloser) Write(p []byte) (n int, err error) { + d.c.mutex.Lock() + n, err = d.WriteCloser.Write(p) + d.c.mutex.Unlock() + return +} + // Data issues a DATA command to the server and returns a writer that // can be used to write the mail headers and body. The caller should // close the writer before calling any more methods on c. A call to @@ -362,7 +373,14 @@ func (c *Client) Data() (io.WriteCloser, error) { if err != nil { return nil, err } - return &dataCloser{c, c.Text.DotWriter()}, nil + datacloser := &dataCloser{} + + c.mutex.Lock() + datacloser.c = c + datacloser.WriteCloser = c.Text.DotWriter() + c.mutex.Unlock() + + return datacloser, nil } var testHookStartTLS func(*tls.Config) // nil, except for tests @@ -460,8 +478,8 @@ func (c *Client) Extension(ext string) (bool, string) { ext = strings.ToUpper(ext) c.mutex.RLock() - defer c.mutex.RUnlock() param, ok := c.ext[ext] + c.mutex.RUnlock() return ok, param } @@ -494,7 +512,11 @@ func (c *Client) Quit() error { if err != nil { return err } - return c.Text.Close() + c.mutex.Lock() + err = c.Text.Close() + c.mutex.Unlock() + + return err } // SetDebugLog enables the debug logging for incoming and outgoing SMTP messages @@ -536,11 +558,10 @@ func (c *Client) HasConnection() bool { func (c *Client) UpdateDeadline(timeout time.Duration) error { c.mutex.Lock() - defer c.mutex.Unlock() - if err := c.conn.SetDeadline(time.Now().Add(timeout)); err != nil { return fmt.Errorf("smtp: failed to update deadline: %w", err) } + c.mutex.Unlock() return nil } From 253d065c83bf5e1011509b73e3020a9d4e3ade76 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Fri, 27 Sep 2024 14:03:50 +0200 Subject: [PATCH 26/30] Move mutex lock to sendSingleMsg method Mutex locking was relocated from the Send method in client_120.go and client_119.go to sendSingleMsg in client.go. This ensures thread-safety specifically during the message transmission process. --- client.go | 3 +++ client_119.go | 3 --- client_120.go | 3 --- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/client.go b/client.go index 48548ce..6557913 100644 --- a/client.go +++ b/client.go @@ -801,6 +801,9 @@ func (c *Client) auth() error { // sendSingleMsg sends out a single message and returns an error if the transmission/delivery fails. // It is invoked by the public Send methods func (c *Client) sendSingleMsg(message *Msg) error { + c.mutex.Lock() + defer c.mutex.Unlock() + if message.encoding == NoEncoding { if ok, _ := c.smtpClient.Extension("8BITMIME"); !ok { return &SendError{Reason: ErrNoUnencoded, isTemp: false, affectedMsg: message} diff --git a/client_119.go b/client_119.go index 0b05061..7de5d59 100644 --- a/client_119.go +++ b/client_119.go @@ -11,9 +11,6 @@ import "errors" // Send sends out the mail message func (c *Client) Send(messages ...*Msg) error { - c.mutex.Lock() - defer c.mutex.Unlock() - if err := c.checkConn(); err != nil { return &SendError{Reason: ErrConnCheck, errlist: []error{err}, isTemp: isTempError(err)} } diff --git a/client_120.go b/client_120.go index 5bb291d..4f82aa7 100644 --- a/client_120.go +++ b/client_120.go @@ -13,9 +13,6 @@ import ( // Send sends out the mail message func (c *Client) Send(messages ...*Msg) (returnErr error) { - c.mutex.Lock() - defer c.mutex.Unlock() - if err := c.checkConn(); err != nil { returnErr = &SendError{Reason: ErrConnCheck, errlist: []error{err}, isTemp: isTempError(err)} return From 2d98c40cb6de7629a65136a0508ea4d4317612bd Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Fri, 27 Sep 2024 14:04:02 +0200 Subject: [PATCH 27/30] Add concurrent send tests for Client Introduced TestClient_DialSendConcurrent_online and TestClient_DialSendConcurrent_local to validate concurrent sending of messages. These tests ensure that the Client's send functionality works correctly under concurrent conditions, both in an online environment and using a local test server. --- client_test.go | 110 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) diff --git a/client_test.go b/client_test.go index 7706b0f..7055a10 100644 --- a/client_test.go +++ b/client_test.go @@ -15,6 +15,7 @@ import ( "os" "strconv" "strings" + "sync" "testing" "time" @@ -1727,6 +1728,114 @@ func TestClient_SendErrorReset(t *testing.T) { } } +func TestClient_DialSendConcurrent_online(t *testing.T) { + if os.Getenv("TEST_ALLOW_SEND") == "" { + t.Skipf("TEST_ALLOW_SEND is not set. Skipping mail sending test") + } + + client, err := getTestConnection(true) + if err != nil { + t.Errorf("unable to create new client: %s", err) + } + + var messages []*Msg + for i := 0; i < 10; i++ { + message := NewMsg() + if err := message.FromFormat("go-mail Test Mailer", os.Getenv("TEST_FROM")); err != nil { + t.Errorf("failed to set FROM address: %s", err) + return + } + if err := message.To(TestRcpt); err != nil { + t.Errorf("failed to set TO address: %s", err) + return + } + message.Subject(fmt.Sprintf("Test subject for mail %d", i)) + message.SetBodyString(TypeTextPlain, fmt.Sprintf("This is the test body of the mail no. %d", i)) + message.SetMessageID() + messages = append(messages, message) + } + + if err = client.DialWithContext(context.Background()); err != nil { + t.Errorf("failed to dial to test server: %s", err) + } + + wg := sync.WaitGroup{} + for id, message := range messages { + wg.Add(1) + go func(curMsg *Msg, curID int) { + defer wg.Done() + if goroutineErr := client.Send(curMsg); err != nil { + t.Errorf("failed to send message with ID %d: %s", curID, goroutineErr) + } + }(message, id) + } + wg.Wait() + + if err = client.Close(); err != nil { + t.Errorf("failed to close server connection: %s", err) + } +} + +func TestClient_DialSendConcurrent_local(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + serverPort := TestServerPortBase + 20 + featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, featureSet, false, serverPort); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 500) + + client, err := NewClient(TestServerAddr, WithPort(serverPort), + WithTLSPortPolicy(NoTLS), WithSMTPAuth(SMTPAuthPlain), + WithUsername("toni@tester.com"), + WithPassword("V3ryS3cr3t+")) + if err != nil { + t.Errorf("unable to create new client: %s", err) + } + + var messages []*Msg + for i := 0; i < 50; i++ { + message := NewMsg() + if err := message.From("valid-from@domain.tld"); err != nil { + t.Errorf("failed to set FROM address: %s", err) + return + } + if err := message.To("valid-to@domain.tld"); err != nil { + t.Errorf("failed to set TO address: %s", err) + return + } + message.Subject("Test subject") + message.SetBodyString(TypeTextPlain, "Test body") + message.SetMessageIDWithValue("this.is.a.message.id") + messages = append(messages, message) + } + + if err = client.DialWithContext(context.Background()); err != nil { + t.Errorf("failed to dial to test server: %s", err) + } + + wg := sync.WaitGroup{} + for id, message := range messages { + wg.Add(1) + go func(curMsg *Msg, curID int) { + defer wg.Done() + if goroutineErr := client.Send(curMsg); err != nil { + t.Errorf("failed to send message with ID %d: %s", curID, goroutineErr) + } + }(message, id) + } + wg.Wait() + + if err = client.Close(); err != nil { + t.Errorf("failed to close server connection: %s", err) + } +} + // getTestConnection takes environment variables to establish a connection to a real // SMTP server to test all functionality that requires a connection func getTestConnection(auth bool) (*Client, error) { @@ -2099,6 +2208,7 @@ func handleTestServerConnection(connection net.Conn, featureSet string, failRese if err != nil { break } + time.Sleep(time.Millisecond) var datastring string data = strings.TrimSpace(data) From 8791ce5a3354005905f05398c027149261a5220d Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Fri, 27 Sep 2024 17:00:07 +0200 Subject: [PATCH 28/30] Fix deferred mutex unlock in TLSConnectionState Correct the sequence of mutex unlocking in TLSConnectionState to ensure the mutex is always released properly. This prevents potential deadlocks and ensures the function behaves as expected in a concurrent context. --- smtp/smtp.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/smtp/smtp.go b/smtp/smtp.go index 379f5fe..4ea1a3d 100644 --- a/smtp/smtp.go +++ b/smtp/smtp.go @@ -209,13 +209,13 @@ func (c *Client) StartTLS(config *tls.Config) error { // not succeed. func (c *Client) TLSConnectionState() (state tls.ConnectionState, ok bool) { c.mutex.RLock() + defer c.mutex.RUnlock() tc, ok := c.conn.(*tls.Conn) if !ok { return } state, ok = tc.ConnectionState(), true - c.mutex.RUnlock() return } From 6e98d7e47d8ed22002c72482f821f8f15db7c587 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Fri, 27 Sep 2024 17:00:21 +0200 Subject: [PATCH 29/30] Reduce message loop iterations and add XOAUTH2 tests Loop iterations in `client_test.go` were reduced from 50 to 20 for efficiency. Added new tests to verify XOAUTH2 authentication support and error handling by simulating SMTP server responses. --- client_test.go | 77 ++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 75 insertions(+), 2 deletions(-) diff --git a/client_test.go b/client_test.go index 7055a10..c764263 100644 --- a/client_test.go +++ b/client_test.go @@ -1799,7 +1799,7 @@ func TestClient_DialSendConcurrent_local(t *testing.T) { } var messages []*Msg - for i := 0; i < 50; i++ { + for i := 0; i < 20; i++ { message := NewMsg() if err := message.From("valid-from@domain.tld"); err != nil { t.Errorf("failed to set FROM address: %s", err) @@ -2021,6 +2021,72 @@ func getTestConnectionWithDSN(auth bool) (*Client, error) { } func TestXOAuth2OK(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + serverPort := TestServerPortBase + 30 + featureSet := "250-AUTH XOAUTH2\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, featureSet, false, serverPort); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 500) + + c, err := NewClient("127.0.0.1", + WithPort(serverPort), + WithTLSPortPolicy(TLSOpportunistic), + WithSMTPAuth(SMTPAuthXOAUTH2), + WithUsername("user"), + WithPassword("token")) + if err != nil { + t.Fatalf("unable to create new client: %v", err) + } + if err = c.DialWithContext(context.Background()); err != nil { + t.Fatalf("unexpected dial error: %v", err) + } + if err = c.Close(); err != nil { + t.Fatalf("disconnect from test server failed: %v", err) + } +} + +func TestXOAuth2Unsupported(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + serverPort := TestServerPortBase + 31 + featureSet := "250-AUTH LOGIN PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, featureSet, false, serverPort); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 500) + + c, err := NewClient("127.0.0.1", + WithPort(serverPort), + WithTLSPolicy(TLSOpportunistic), + WithSMTPAuth(SMTPAuthXOAUTH2), + WithUsername("user"), + WithPassword("token")) + if err != nil { + t.Fatalf("unable to create new client: %v", err) + } + if err = c.DialWithContext(context.Background()); err == nil { + t.Fatal("expected dial error got nil") + } else { + if !errors.Is(err, ErrXOauth2AuthNotSupported) { + t.Fatalf("expected %v; got %v", ErrXOauth2AuthNotSupported, err) + } + } + if err = c.Close(); err != nil { + t.Fatalf("disconnect from test server failed: %v", err) + } +} + +func TestXOAuth2OK_faker(t *testing.T) { server := []string{ "220 Fake server ready ESMTP", "250-fake.server", @@ -2060,7 +2126,7 @@ func TestXOAuth2OK(t *testing.T) { } } -func TestXOAuth2Unsupported(t *testing.T) { +func TestXOAuth2Unsupported_faker(t *testing.T) { server := []string{ "220 Fake server ready ESMTP", "250-fake.server", @@ -2231,6 +2297,13 @@ func handleTestServerConnection(connection net.Conn, featureSet string, failRese break } writeOK() + case strings.HasPrefix(data, "AUTH XOAUTH2"): + auth := strings.TrimPrefix(data, "AUTH XOAUTH2 ") + if !strings.EqualFold(auth, "dXNlcj11c2VyAWF1dGg9QmVhcmVyIHRva2VuAQE=") { + _ = writeLine("535 5.7.8 Error: authentication failed") + break + } + _ = writeLine("235 2.7.0 Authentication successful") case strings.HasPrefix(data, "AUTH PLAIN"): auth := strings.TrimPrefix(data, "AUTH PLAIN ") if !strings.EqualFold(auth, "AHRvbmlAdGVzdGVyLmNvbQBWM3J5UzNjcjN0Kw==") { From c1f6ef07d46dbb70206c12b83ac72bed0e45fd8b Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Fri, 27 Sep 2024 17:09:00 +0200 Subject: [PATCH 30/30] Skip test cases when client creation fails Updated the client creation check to skip test cases if the client cannot be created, instead of marking them as errors. This ensures tests dependent on a successful client creation do not fail unnecessarily but are instead skipped. --- client_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client_test.go b/client_test.go index c764263..2d37ce0 100644 --- a/client_test.go +++ b/client_test.go @@ -1735,7 +1735,7 @@ func TestClient_DialSendConcurrent_online(t *testing.T) { client, err := getTestConnection(true) if err != nil { - t.Errorf("unable to create new client: %s", err) + t.Skipf("failed to create test client: %s. Skipping tests", err) } var messages []*Msg