Compare commits

...

2 commits

Author SHA1 Message Date
3871b2be44
Lock client connections and update deadline handling
Add mutex locking for client connections to ensure thread safety. Introduce `HasConnection` method to check active connections and `UpdateDeadline` method to handle timeout updates. Refactor connection handling in `checkConn` and `tls` methods accordingly.
2024-09-26 11:51:30 +02:00
8683917c3d
Delete connection pool implementation and tests
Remove `connpool.go` and `connpool_test.go`. This eliminates the connection pool feature from the codebase, including associated functionality and tests. The connection pool feature is much to complex and doesn't provide the benefits expected by the concurrency feature
2024-09-26 11:49:48 +02:00
5 changed files with 27 additions and 645 deletions

View file

@ -12,6 +12,7 @@ import (
"net" "net"
"os" "os"
"strings" "strings"
"sync"
"time" "time"
"github.com/wneessen/go-mail/log" "github.com/wneessen/go-mail/log"
@ -87,6 +88,7 @@ type DialContextFunc func(ctx context.Context, network, address string) (net.Con
// Client is the SMTP client struct // Client is the SMTP client struct
type Client struct { type Client struct {
mutex sync.RWMutex
// connection is the net.Conn that the smtp.Client is based on // connection is the net.Conn that the smtp.Client is based on
connection net.Conn connection net.Conn
@ -589,6 +591,9 @@ func (c *Client) setDefaultHelo() error {
// DialWithContext establishes a connection to the SMTP server with a given context.Context // DialWithContext establishes a connection to the SMTP server with a given context.Context
func (c *Client) DialWithContext(dialCtx context.Context) error { func (c *Client) DialWithContext(dialCtx context.Context) error {
c.mutex.Lock()
defer c.mutex.Unlock()
ctx, cancel := context.WithDeadline(dialCtx, time.Now().Add(c.connTimeout)) ctx, cancel := context.WithDeadline(dialCtx, time.Now().Add(c.connTimeout))
defer cancel() defer cancel()
@ -602,17 +607,16 @@ func (c *Client) DialWithContext(dialCtx context.Context) error {
c.dialContextFunc = tlsDialer.DialContext c.dialContextFunc = tlsDialer.DialContext
} }
} }
var err error connection, err := c.dialContextFunc(ctx, "tcp", c.ServerAddr())
c.connection, err = c.dialContextFunc(ctx, "tcp", c.ServerAddr())
if err != nil && c.fallbackPort != 0 { if err != nil && c.fallbackPort != 0 {
// TODO: should we somehow log or append the previous error? // TODO: should we somehow log or append the previous error?
c.connection, err = c.dialContextFunc(ctx, "tcp", c.serverFallbackAddr()) connection, err = c.dialContextFunc(ctx, "tcp", c.serverFallbackAddr())
} }
if err != nil { if err != nil {
return err return err
} }
client, err := smtp.NewClient(c.connection, c.host) client, err := smtp.NewClient(connection, c.host)
if err != nil { if err != nil {
return err return err
} }
@ -691,7 +695,7 @@ func (c *Client) DialAndSendWithContext(ctx context.Context, messages ...*Msg) e
// checkConn makes sure that a required server connection is available and extends the // checkConn makes sure that a required server connection is available and extends the
// connection deadline // connection deadline
func (c *Client) checkConn() error { func (c *Client) checkConn() error {
if c.connection == nil { if !c.smtpClient.HasConnection() {
return ErrNoActiveConnection return ErrNoActiveConnection
} }
@ -701,7 +705,7 @@ func (c *Client) checkConn() error {
} }
} }
if err := c.connection.SetDeadline(time.Now().Add(c.connTimeout)); err != nil { if err := c.smtpClient.UpdateDeadline(c.connTimeout); err != nil {
return ErrDeadlineExtendFailed return ErrDeadlineExtendFailed
} }
return nil return nil
@ -715,7 +719,7 @@ func (c *Client) serverFallbackAddr() string {
// tls tries to make sure that the STARTTLS requirements are satisfied // tls tries to make sure that the STARTTLS requirements are satisfied
func (c *Client) tls() error { func (c *Client) tls() error {
if c.connection == nil { if !c.smtpClient.HasConnection() {
return ErrNoActiveConnection return ErrNoActiveConnection
} }
if !c.useSSL && c.tlspolicy != NoTLS { if !c.useSSL && c.tlspolicy != NoTLS {

View file

@ -13,6 +13,8 @@ import (
// Send sends out the mail message // Send sends out the mail message
func (c *Client) Send(messages ...*Msg) (returnErr error) { func (c *Client) Send(messages ...*Msg) (returnErr error) {
c.mutex.Lock()
defer c.mutex.Unlock()
if err := c.checkConn(); err != nil { if err := c.checkConn(); err != nil {
returnErr = &SendError{Reason: ErrConnCheck, errlist: []error{err}, isTemp: isTempError(err)} returnErr = &SendError{Reason: ErrConnCheck, errlist: []error{err}, isTemp: isTempError(err)}
return return

View file

@ -1,215 +0,0 @@
// SPDX-FileCopyrightText: 2022-2024 The go-mail Authors
//
// SPDX-License-Identifier: MIT
package mail
import (
"context"
"errors"
"fmt"
"net"
"sync"
)
// Parts of the connection pool code is forked/took inspiration from https://github.com/fatih/pool/
// Thanks to Fatih Arslan and the project contributors for providing this great concurrency template.
var (
// ErrClosed is returned when an operation is attempted on a closed connection pool.
ErrClosed = errors.New("connection pool is closed")
// ErrNilConn is returned when a nil connection is passed back to the connection pool.
ErrNilConn = errors.New("connection is nil")
// ErrPoolInvalidCap is returned when the connection pool's capacity settings are
// invalid (e.g., initial capacity is negative).
ErrPoolInvalidCap = errors.New("invalid connection pool capacity settings")
)
// Pool interface describes a connection pool implementation. A Pool is
// thread-/go-routine safe.
type Pool interface {
// Get returns a new connection from the pool. Closing the connections returns
// it back into the Pool. Closing a connection when the Pool is destroyed or
// full will be counted as an error.
Get(ctx context.Context) (net.Conn, error)
// Close closes the pool and all its connections. After Close() the pool is
// no longer usable.
Close()
// Size returns the current number of connections of the pool.
Size() int
}
// connPool implements the Pool interface
type connPool struct {
// mutex is used to synchronize access to the connection pool to ensure thread-safe operations.
mutex sync.RWMutex
// conns is a channel used to manage and distribute net.Conn objects within the connection pool.
conns chan net.Conn
// dialCtxFunc represents the actual net.Conn returned by the DialContextFunc.
dialCtxFunc DialContextFunc
// dialNetwork specifies the network type (e.g., "tcp", "udp") used to establish connections in
// the connection pool.
dialNetwork string
// dialAddress specifies the address used to establish network connections within the connection pool.
dialAddress string
}
// PoolConn is a wrapper around net.Conn to modify the the behavior of net.Conn's Close() method.
type PoolConn struct {
net.Conn
mutex sync.RWMutex
pool *connPool
unusable bool
}
// Close puts a given pool connection back to the pool instead of closing it.
func (c *PoolConn) Close() error {
c.mutex.RLock()
defer c.mutex.RUnlock()
if c.unusable {
if c.Conn != nil {
return c.Conn.Close()
}
return nil
}
return c.pool.put(c.Conn)
}
// MarkUnusable marks the connection not usable any more, to let the pool close it instead
// of returning it to pool.
func (c *PoolConn) MarkUnusable() {
c.mutex.Lock()
c.unusable = true
c.mutex.Unlock()
}
// NewConnPool returns a new pool based on buffered channels with an initial
// capacity and maximum capacity. The DialContextFunc is used when the initial
// capacity is greater than zero to fill the pool. A zero initialCap doesn't
// fill the Pool until a new Get() is called. During a Get(), if there is no
// new connection available in the pool, a new connection will be created via
// the corresponding DialContextFunc() method.
func NewConnPool(ctx context.Context, initialCap, maxCap int, dialCtxFunc DialContextFunc,
network, address string,
) (Pool, error) {
if initialCap < 0 || maxCap <= 0 || initialCap > maxCap {
return nil, ErrPoolInvalidCap
}
pool := &connPool{
conns: make(chan net.Conn, maxCap),
dialCtxFunc: dialCtxFunc,
dialAddress: address,
dialNetwork: network,
}
// Initial connections for the pool. Pool will be closed on connection error
for i := 0; i < initialCap; i++ {
conn, err := dialCtxFunc(ctx, network, address)
if err != nil {
pool.Close()
return nil, fmt.Errorf("dialContextFunc is not able to fill the connection pool: %w", err)
}
pool.conns <- conn
}
return pool, nil
}
// Get satisfies the Get() method of the Pool inteface. If there is no new
// connection available in the Pool, a new connection will be created via the
// DialContextFunc() method.
func (p *connPool) Get(ctx context.Context) (net.Conn, error) {
conns, dialCtxFunc := p.getConnsAndDialContext()
if conns == nil {
return nil, ErrClosed
}
// wrap the connections into the custom net.Conn implementation that puts
// connections back to the pool
select {
case <-ctx.Done():
return nil, fmt.Errorf("failed to get connection: %w", ctx.Err())
case conn := <-conns:
if conn == nil {
return nil, ErrClosed
}
return p.wrapConn(conn), nil
default:
conn, err := dialCtxFunc(ctx, p.dialNetwork, p.dialAddress)
if err != nil {
return nil, fmt.Errorf("dialContextFunc failed: %w", err)
}
return p.wrapConn(conn), nil
}
}
// Close terminates all connections in the pool and frees associated resources. Once closed,
// the pool is no longer usable.
func (p *connPool) Close() {
p.mutex.Lock()
conns := p.conns
p.conns = nil
p.dialCtxFunc = nil
p.dialAddress = ""
p.dialNetwork = ""
p.mutex.Unlock()
if conns == nil {
return
}
close(conns)
for conn := range conns {
_ = conn.Close()
}
}
// Size returns the current number of connections in the connection pool.
func (p *connPool) Size() int {
conns, _ := p.getConnsAndDialContext()
return len(conns)
}
// getConnsAndDialContext returns the connection channel and the DialContext function for the
// connection pool.
func (p *connPool) getConnsAndDialContext() (chan net.Conn, DialContextFunc) {
p.mutex.RLock()
conns := p.conns
dialCtxFunc := p.dialCtxFunc
p.mutex.RUnlock()
return conns, dialCtxFunc
}
// put puts a passed connection back into the pool. If the pool is full or closed,
// conn is simply closed. A nil conn will be rejected with an error.
func (p *connPool) put(conn net.Conn) error {
if conn == nil {
return ErrNilConn
}
p.mutex.RLock()
defer p.mutex.RUnlock()
if p.conns == nil {
return conn.Close()
}
select {
case p.conns <- conn:
return nil
default:
return conn.Close()
}
}
// wrapConn wraps a given net.Conn with a PoolConn, modifying the net.Conn's Close() method.
func (p *connPool) wrapConn(conn net.Conn) net.Conn {
poolconn := &PoolConn{pool: p}
poolconn.Conn = conn
return poolconn
}

View file

@ -1,423 +0,0 @@
// SPDX-FileCopyrightText: 2022-2024 The go-mail Authors
//
// SPDX-License-Identifier: MIT
package mail
import (
"context"
"fmt"
"net"
"sync"
"testing"
"time"
)
func TestNewConnPool(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
serverPort := TestServerPortBase + 10
featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
go func() {
if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 300)
pool, err := newConnPool(serverPort)
if err != nil {
t.Errorf("failed to create connection pool: %s", err)
}
defer pool.Close()
if pool == nil {
t.Errorf("connection pool is nil")
return
}
if pool.Size() != 5 {
t.Errorf("expected 5 connections, got %d", pool.Size())
}
conn, err := pool.Get(context.Background())
if err != nil {
t.Errorf("failed to get connection: %s", err)
}
if _, err := conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil {
t.Errorf("failed to write quit command to first connection: %s", err)
}
}
func TestConnPool_Get_Type(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
serverPort := TestServerPortBase + 11
featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
go func() {
if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 300)
pool, err := newConnPool(serverPort)
if err != nil {
t.Errorf("failed to create connection pool: %s", err)
}
defer pool.Close()
conn, err := pool.Get(context.Background())
if err != nil {
t.Errorf("failed to get new connection from pool: %s", err)
return
}
_, ok := conn.(*PoolConn)
if !ok {
t.Error("received connection from pool is not of type PoolConn")
return
}
if _, err := conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil {
t.Errorf("failed to write quit command to first connection: %s", err)
}
}
func TestConnPool_Get(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
serverPort := TestServerPortBase + 12
featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
go func() {
if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 300)
p, _ := newConnPool(serverPort)
defer p.Close()
conn, err := p.Get(context.Background())
if err != nil {
t.Errorf("failed to get new connection from pool: %s", err)
return
}
if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil {
t.Errorf("failed to write quit command to first connection: %s", err)
}
if p.Size() != 4 {
t.Errorf("getting new connection from pool failed. Expected pool size: 4, got %d", p.Size())
}
var wg sync.WaitGroup
for i := 0; i < 4; i++ {
wg.Add(1)
go func() {
defer wg.Done()
wgconn, err := p.Get(context.Background())
if err != nil {
t.Errorf("failed to get new connection from pool: %s", err)
}
if _, err = wgconn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil {
t.Errorf("failed to write quit command to first connection: %s", err)
}
}()
}
wg.Wait()
if p.Size() != 0 {
t.Errorf("Get error. Expecting 0, got %d", p.Size())
}
conn, err = p.Get(context.Background())
if err != nil {
t.Errorf("failed to get new connection from pool: %s", err)
}
if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil {
t.Errorf("failed to write quit command to first connection: %s", err)
}
p.Close()
}
func TestPoolConn_Close(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
serverPort := TestServerPortBase + 13
featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
go func() {
if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 300)
netDialer := net.Dialer{}
p, err := NewConnPool(context.Background(), 0, 30, netDialer.DialContext, "tcp",
fmt.Sprintf("127.0.0.1:%d", serverPort))
if err != nil {
t.Errorf("failed to create connection pool: %s", err)
}
defer p.Close()
conns := make([]net.Conn, 30)
for i := 0; i < 30; i++ {
conn, _ := p.Get(context.Background())
if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil {
t.Errorf("failed to write quit command to first connection: %s", err)
}
conns[i] = conn
}
for _, conn := range conns {
if err = conn.Close(); err != nil {
t.Errorf("failed to close connection: %s", err)
}
}
if p.Size() != 30 {
t.Errorf("failed to return all connections to pool. Expected pool size: 30, got %d", p.Size())
}
conn, err := p.Get(context.Background())
if err != nil {
t.Errorf("failed to get new connection from pool: %s", err)
}
if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil {
t.Errorf("failed to write quit command to first connection: %s", err)
}
p.Close()
if err = conn.Close(); err != nil {
t.Errorf("failed to close connection: %s", err)
}
if p.Size() != 0 {
t.Errorf("closed pool shouldn't allow to put connections.")
}
}
func TestPoolConn_MarkUnusable(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
serverPort := TestServerPortBase + 14
featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
go func() {
if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 300)
pool, _ := newConnPool(serverPort)
defer pool.Close()
conn, err := pool.Get(context.Background())
if err != nil {
t.Errorf("failed to get new connection from pool: %s", err)
}
if err = conn.Close(); err != nil {
t.Errorf("failed to close connection: %s", err)
}
poolSize := pool.Size()
conn, err = pool.Get(context.Background())
if err != nil {
t.Errorf("failed to get new connection from pool: %s", err)
}
if err = conn.Close(); err != nil {
t.Errorf("failed to close connection: %s", err)
}
if pool.Size() != poolSize {
t.Errorf("pool size is expected to be equal to initial size")
}
conn, err = pool.Get(context.Background())
if err != nil {
t.Errorf("failed to get new connection from pool: %s", err)
}
if pc, ok := conn.(*PoolConn); !ok {
t.Errorf("this should never happen")
} else {
pc.MarkUnusable()
}
if err = conn.Close(); err != nil {
t.Errorf("failed to close connection: %s", err)
}
if pool.Size() != poolSize-1 {
t.Errorf("pool size is expected to be: %d but got: %d", poolSize-1, pool.Size())
}
}
func TestConnPool_Close(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
serverPort := TestServerPortBase + 15
featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
go func() {
if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 300)
pool, err := newConnPool(serverPort)
if err != nil {
t.Errorf("failed to create connection pool: %s", err)
}
pool.Close()
castPool := pool.(*connPool)
if castPool.conns != nil {
t.Error("closing pool failed: conns channel should be nil")
}
if castPool.dialCtxFunc != nil {
t.Error("closing pool failed: dialCtxFunc should be nil")
}
if castPool.dialAddress != "" {
t.Error("closing pool failed: dialAddress should be empty")
}
if castPool.dialNetwork != "" {
t.Error("closing pool failed: dialNetwork should be empty")
}
conn, err := pool.Get(context.Background())
if err == nil {
t.Errorf("closing pool failed: getting new connection should return an error")
}
if conn != nil {
t.Errorf("closing pool failed: getting new connection should return a nil-connection")
}
if pool.Size() != 0 {
t.Errorf("closing pool failed: pool size should be 0, but got: %d", pool.Size())
}
}
func TestConnPool_Concurrency(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
serverPort := TestServerPortBase + 16
featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
go func() {
if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 300)
pool, err := newConnPool(serverPort)
if err != nil {
t.Errorf("failed to create connection pool: %s", err)
}
defer pool.Close()
pipe := make(chan net.Conn)
getWg := sync.WaitGroup{}
closeWg := sync.WaitGroup{}
for i := 0; i < 30; i++ {
getWg.Add(1)
closeWg.Add(1)
go func() {
conn, err := pool.Get(context.Background())
if err != nil {
t.Errorf("failed to get new connection from pool: %s", err)
}
pipe <- conn
getWg.Done()
}()
go func() {
conn := <-pipe
if conn == nil {
return
}
if err = conn.Close(); err != nil {
t.Errorf("failed to close connection: %s", err)
}
closeWg.Done()
}()
getWg.Wait()
closeWg.Wait()
}
}
func TestConnPool_GetContextCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
serverPort := TestServerPortBase + 17
featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
go func() {
if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 300)
p, err := newConnPool(serverPort)
if err != nil {
t.Errorf("failed to create connection pool: %s", err)
}
defer p.Close()
connCtx, connCancel := context.WithCancel(context.Background())
defer connCancel()
conn, err := p.Get(connCtx)
if err != nil {
t.Errorf("failed to get new connection from pool: %s", err)
return
}
if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil {
t.Errorf("failed to write quit command to first connection: %s", err)
}
if p.Size() != 4 {
t.Errorf("getting new connection from pool failed. Expected pool size: 4, got %d", p.Size())
}
var wg sync.WaitGroup
for i := 0; i < 4; i++ {
wg.Add(1)
go func() {
defer wg.Done()
wgconn, err := p.Get(connCtx)
if err != nil {
t.Errorf("failed to get new connection from pool: %s", err)
}
if _, err = wgconn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil {
t.Errorf("failed to write quit command to first connection: %s", err)
}
}()
}
wg.Wait()
if p.Size() != 0 {
t.Errorf("Get error. Expecting 0, got %d", p.Size())
}
connCancel()
_, err = p.Get(connCtx)
if err == nil {
t.Errorf("getting new connection on canceled context should fail, but didn't")
}
p.Close()
}
func newConnPool(port int) (Pool, error) {
netDialer := net.Dialer{}
return NewConnPool(context.Background(), 5, 30, netDialer.DialContext, "tcp",
fmt.Sprintf("127.0.0.1:%d", port))
}

View file

@ -30,6 +30,7 @@ import (
"net/textproto" "net/textproto"
"os" "os"
"strings" "strings"
"time"
"github.com/wneessen/go-mail/log" "github.com/wneessen/go-mail/log"
) )
@ -472,6 +473,19 @@ func (c *Client) SetDSNRcptNotifyOption(d string) {
c.dsnrntype = d c.dsnrntype = d
} }
// HasConnection checks if the client has an active connection.
// Returns true if the `conn` field is not nil, indicating an active connection.
func (c *Client) HasConnection() bool {
return c.conn != nil
}
func (c *Client) UpdateDeadline(timeout time.Duration) error {
if err := c.conn.SetDeadline(time.Now().Add(timeout)); err != nil {
return fmt.Errorf("smtp: failed to update deadline: %w", err)
}
return nil
}
// debugLog checks if the debug flag is set and if so logs the provided message to // debugLog checks if the debug flag is set and if so logs the provided message to
// the log.Logger interface // the log.Logger interface
func (c *Client) debugLog(d log.Direction, f string, a ...interface{}) { func (c *Client) debugLog(d log.Direction, f string, a ...interface{}) {