diff --git a/connpool.go b/connpool.go deleted file mode 100644 index f50e73d..0000000 --- a/connpool.go +++ /dev/null @@ -1,215 +0,0 @@ -// SPDX-FileCopyrightText: 2022-2024 The go-mail Authors -// -// SPDX-License-Identifier: MIT - -package mail - -import ( - "context" - "errors" - "fmt" - "net" - "sync" -) - -// Parts of the connection pool code is forked/took inspiration from https://github.com/fatih/pool/ -// Thanks to Fatih Arslan and the project contributors for providing this great concurrency template. - -var ( - // ErrClosed is returned when an operation is attempted on a closed connection pool. - ErrClosed = errors.New("connection pool is closed") - // ErrNilConn is returned when a nil connection is passed back to the connection pool. - ErrNilConn = errors.New("connection is nil") - // ErrPoolInvalidCap is returned when the connection pool's capacity settings are - // invalid (e.g., initial capacity is negative). - ErrPoolInvalidCap = errors.New("invalid connection pool capacity settings") -) - -// Pool interface describes a connection pool implementation. A Pool is -// thread-/go-routine safe. -type Pool interface { - // Get returns a new connection from the pool. Closing the connections returns - // it back into the Pool. Closing a connection when the Pool is destroyed or - // full will be counted as an error. - Get(ctx context.Context) (net.Conn, error) - - // Close closes the pool and all its connections. After Close() the pool is - // no longer usable. - Close() - - // Size returns the current number of connections of the pool. - Size() int -} - -// connPool implements the Pool interface -type connPool struct { - // mutex is used to synchronize access to the connection pool to ensure thread-safe operations. - mutex sync.RWMutex - // conns is a channel used to manage and distribute net.Conn objects within the connection pool. - conns chan net.Conn - - // dialCtxFunc represents the actual net.Conn returned by the DialContextFunc. - dialCtxFunc DialContextFunc - // dialNetwork specifies the network type (e.g., "tcp", "udp") used to establish connections in - // the connection pool. - dialNetwork string - // dialAddress specifies the address used to establish network connections within the connection pool. - dialAddress string -} - -// PoolConn is a wrapper around net.Conn to modify the the behavior of net.Conn's Close() method. -type PoolConn struct { - net.Conn - mutex sync.RWMutex - pool *connPool - unusable bool -} - -// Close puts a given pool connection back to the pool instead of closing it. -func (c *PoolConn) Close() error { - c.mutex.RLock() - defer c.mutex.RUnlock() - - if c.unusable { - if c.Conn != nil { - return c.Conn.Close() - } - return nil - } - return c.pool.put(c.Conn) -} - -// MarkUnusable marks the connection not usable any more, to let the pool close it instead -// of returning it to pool. -func (c *PoolConn) MarkUnusable() { - c.mutex.Lock() - c.unusable = true - c.mutex.Unlock() -} - -// NewConnPool returns a new pool based on buffered channels with an initial -// capacity and maximum capacity. The DialContextFunc is used when the initial -// capacity is greater than zero to fill the pool. A zero initialCap doesn't -// fill the Pool until a new Get() is called. During a Get(), if there is no -// new connection available in the pool, a new connection will be created via -// the corresponding DialContextFunc() method. -func NewConnPool(ctx context.Context, initialCap, maxCap int, dialCtxFunc DialContextFunc, - network, address string, -) (Pool, error) { - if initialCap < 0 || maxCap <= 0 || initialCap > maxCap { - return nil, ErrPoolInvalidCap - } - - pool := &connPool{ - conns: make(chan net.Conn, maxCap), - dialCtxFunc: dialCtxFunc, - dialAddress: address, - dialNetwork: network, - } - - // Initial connections for the pool. Pool will be closed on connection error - for i := 0; i < initialCap; i++ { - conn, err := dialCtxFunc(ctx, network, address) - if err != nil { - pool.Close() - return nil, fmt.Errorf("dialContextFunc is not able to fill the connection pool: %w", err) - } - pool.conns <- conn - } - - return pool, nil -} - -// Get satisfies the Get() method of the Pool inteface. If there is no new -// connection available in the Pool, a new connection will be created via the -// DialContextFunc() method. -func (p *connPool) Get(ctx context.Context) (net.Conn, error) { - conns, dialCtxFunc := p.getConnsAndDialContext() - if conns == nil { - return nil, ErrClosed - } - - // wrap the connections into the custom net.Conn implementation that puts - // connections back to the pool - select { - case <-ctx.Done(): - return nil, fmt.Errorf("failed to get connection: %w", ctx.Err()) - case conn := <-conns: - if conn == nil { - return nil, ErrClosed - } - return p.wrapConn(conn), nil - default: - conn, err := dialCtxFunc(ctx, p.dialNetwork, p.dialAddress) - if err != nil { - return nil, fmt.Errorf("dialContextFunc failed: %w", err) - } - return p.wrapConn(conn), nil - } -} - -// Close terminates all connections in the pool and frees associated resources. Once closed, -// the pool is no longer usable. -func (p *connPool) Close() { - p.mutex.Lock() - conns := p.conns - p.conns = nil - p.dialCtxFunc = nil - p.dialAddress = "" - p.dialNetwork = "" - p.mutex.Unlock() - - if conns == nil { - return - } - - close(conns) - for conn := range conns { - _ = conn.Close() - } -} - -// Size returns the current number of connections in the connection pool. -func (p *connPool) Size() int { - conns, _ := p.getConnsAndDialContext() - return len(conns) -} - -// getConnsAndDialContext returns the connection channel and the DialContext function for the -// connection pool. -func (p *connPool) getConnsAndDialContext() (chan net.Conn, DialContextFunc) { - p.mutex.RLock() - conns := p.conns - dialCtxFunc := p.dialCtxFunc - p.mutex.RUnlock() - return conns, dialCtxFunc -} - -// put puts a passed connection back into the pool. If the pool is full or closed, -// conn is simply closed. A nil conn will be rejected with an error. -func (p *connPool) put(conn net.Conn) error { - if conn == nil { - return ErrNilConn - } - - p.mutex.RLock() - defer p.mutex.RUnlock() - - if p.conns == nil { - return conn.Close() - } - - select { - case p.conns <- conn: - return nil - default: - return conn.Close() - } -} - -// wrapConn wraps a given net.Conn with a PoolConn, modifying the net.Conn's Close() method. -func (p *connPool) wrapConn(conn net.Conn) net.Conn { - poolconn := &PoolConn{pool: p} - poolconn.Conn = conn - return poolconn -} diff --git a/connpool_test.go b/connpool_test.go deleted file mode 100644 index 42b26c8..0000000 --- a/connpool_test.go +++ /dev/null @@ -1,423 +0,0 @@ -// SPDX-FileCopyrightText: 2022-2024 The go-mail Authors -// -// SPDX-License-Identifier: MIT - -package mail - -import ( - "context" - "fmt" - "net" - "sync" - "testing" - "time" -) - -func TestNewConnPool(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - serverPort := TestServerPortBase + 10 - featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" - go func() { - if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil { - t.Errorf("failed to start test server: %s", err) - return - } - }() - time.Sleep(time.Millisecond * 300) - - pool, err := newConnPool(serverPort) - if err != nil { - t.Errorf("failed to create connection pool: %s", err) - } - defer pool.Close() - if pool == nil { - t.Errorf("connection pool is nil") - return - } - if pool.Size() != 5 { - t.Errorf("expected 5 connections, got %d", pool.Size()) - } - conn, err := pool.Get(context.Background()) - if err != nil { - t.Errorf("failed to get connection: %s", err) - } - if _, err := conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil { - t.Errorf("failed to write quit command to first connection: %s", err) - } -} - -func TestConnPool_Get_Type(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - serverPort := TestServerPortBase + 11 - featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" - go func() { - if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil { - t.Errorf("failed to start test server: %s", err) - return - } - }() - time.Sleep(time.Millisecond * 300) - - pool, err := newConnPool(serverPort) - if err != nil { - t.Errorf("failed to create connection pool: %s", err) - } - defer pool.Close() - - conn, err := pool.Get(context.Background()) - if err != nil { - t.Errorf("failed to get new connection from pool: %s", err) - return - } - - _, ok := conn.(*PoolConn) - if !ok { - t.Error("received connection from pool is not of type PoolConn") - return - } - if _, err := conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil { - t.Errorf("failed to write quit command to first connection: %s", err) - } -} - -func TestConnPool_Get(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - serverPort := TestServerPortBase + 12 - featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" - go func() { - if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil { - t.Errorf("failed to start test server: %s", err) - return - } - }() - time.Sleep(time.Millisecond * 300) - - p, _ := newConnPool(serverPort) - defer p.Close() - - conn, err := p.Get(context.Background()) - if err != nil { - t.Errorf("failed to get new connection from pool: %s", err) - return - } - if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil { - t.Errorf("failed to write quit command to first connection: %s", err) - } - - if p.Size() != 4 { - t.Errorf("getting new connection from pool failed. Expected pool size: 4, got %d", p.Size()) - } - - var wg sync.WaitGroup - for i := 0; i < 4; i++ { - wg.Add(1) - go func() { - defer wg.Done() - wgconn, err := p.Get(context.Background()) - if err != nil { - t.Errorf("failed to get new connection from pool: %s", err) - } - if _, err = wgconn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil { - t.Errorf("failed to write quit command to first connection: %s", err) - } - }() - } - wg.Wait() - - if p.Size() != 0 { - t.Errorf("Get error. Expecting 0, got %d", p.Size()) - } - - conn, err = p.Get(context.Background()) - if err != nil { - t.Errorf("failed to get new connection from pool: %s", err) - } - if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil { - t.Errorf("failed to write quit command to first connection: %s", err) - } - p.Close() -} - -func TestPoolConn_Close(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - serverPort := TestServerPortBase + 13 - featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" - go func() { - if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil { - t.Errorf("failed to start test server: %s", err) - return - } - }() - time.Sleep(time.Millisecond * 300) - - netDialer := net.Dialer{} - p, err := NewConnPool(context.Background(), 0, 30, netDialer.DialContext, "tcp", - fmt.Sprintf("127.0.0.1:%d", serverPort)) - if err != nil { - t.Errorf("failed to create connection pool: %s", err) - } - defer p.Close() - - conns := make([]net.Conn, 30) - for i := 0; i < 30; i++ { - conn, _ := p.Get(context.Background()) - if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil { - t.Errorf("failed to write quit command to first connection: %s", err) - } - conns[i] = conn - } - for _, conn := range conns { - if err = conn.Close(); err != nil { - t.Errorf("failed to close connection: %s", err) - } - } - - if p.Size() != 30 { - t.Errorf("failed to return all connections to pool. Expected pool size: 30, got %d", p.Size()) - } - - conn, err := p.Get(context.Background()) - if err != nil { - t.Errorf("failed to get new connection from pool: %s", err) - } - if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil { - t.Errorf("failed to write quit command to first connection: %s", err) - } - p.Close() - - if err = conn.Close(); err != nil { - t.Errorf("failed to close connection: %s", err) - } - if p.Size() != 0 { - t.Errorf("closed pool shouldn't allow to put connections.") - } -} - -func TestPoolConn_MarkUnusable(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - serverPort := TestServerPortBase + 14 - featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" - go func() { - if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil { - t.Errorf("failed to start test server: %s", err) - return - } - }() - time.Sleep(time.Millisecond * 300) - - pool, _ := newConnPool(serverPort) - defer pool.Close() - - conn, err := pool.Get(context.Background()) - if err != nil { - t.Errorf("failed to get new connection from pool: %s", err) - } - if err = conn.Close(); err != nil { - t.Errorf("failed to close connection: %s", err) - } - - poolSize := pool.Size() - conn, err = pool.Get(context.Background()) - if err != nil { - t.Errorf("failed to get new connection from pool: %s", err) - } - if err = conn.Close(); err != nil { - t.Errorf("failed to close connection: %s", err) - } - if pool.Size() != poolSize { - t.Errorf("pool size is expected to be equal to initial size") - } - - conn, err = pool.Get(context.Background()) - if err != nil { - t.Errorf("failed to get new connection from pool: %s", err) - } - if pc, ok := conn.(*PoolConn); !ok { - t.Errorf("this should never happen") - } else { - pc.MarkUnusable() - } - if err = conn.Close(); err != nil { - t.Errorf("failed to close connection: %s", err) - } - if pool.Size() != poolSize-1 { - t.Errorf("pool size is expected to be: %d but got: %d", poolSize-1, pool.Size()) - } -} - -func TestConnPool_Close(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - serverPort := TestServerPortBase + 15 - featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" - go func() { - if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil { - t.Errorf("failed to start test server: %s", err) - return - } - }() - time.Sleep(time.Millisecond * 300) - - pool, err := newConnPool(serverPort) - if err != nil { - t.Errorf("failed to create connection pool: %s", err) - } - pool.Close() - - castPool := pool.(*connPool) - - if castPool.conns != nil { - t.Error("closing pool failed: conns channel should be nil") - } - if castPool.dialCtxFunc != nil { - t.Error("closing pool failed: dialCtxFunc should be nil") - } - if castPool.dialAddress != "" { - t.Error("closing pool failed: dialAddress should be empty") - } - if castPool.dialNetwork != "" { - t.Error("closing pool failed: dialNetwork should be empty") - } - - conn, err := pool.Get(context.Background()) - if err == nil { - t.Errorf("closing pool failed: getting new connection should return an error") - } - if conn != nil { - t.Errorf("closing pool failed: getting new connection should return a nil-connection") - } - if pool.Size() != 0 { - t.Errorf("closing pool failed: pool size should be 0, but got: %d", pool.Size()) - } -} - -func TestConnPool_Concurrency(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - serverPort := TestServerPortBase + 16 - featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" - go func() { - if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil { - t.Errorf("failed to start test server: %s", err) - return - } - }() - time.Sleep(time.Millisecond * 300) - - pool, err := newConnPool(serverPort) - if err != nil { - t.Errorf("failed to create connection pool: %s", err) - } - defer pool.Close() - pipe := make(chan net.Conn) - - getWg := sync.WaitGroup{} - closeWg := sync.WaitGroup{} - for i := 0; i < 30; i++ { - getWg.Add(1) - closeWg.Add(1) - go func() { - conn, err := pool.Get(context.Background()) - if err != nil { - t.Errorf("failed to get new connection from pool: %s", err) - } - pipe <- conn - getWg.Done() - }() - - go func() { - conn := <-pipe - if conn == nil { - return - } - if err = conn.Close(); err != nil { - t.Errorf("failed to close connection: %s", err) - } - closeWg.Done() - }() - getWg.Wait() - closeWg.Wait() - } -} - -func TestConnPool_GetContextCancel(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - serverPort := TestServerPortBase + 17 - featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" - go func() { - if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil { - t.Errorf("failed to start test server: %s", err) - return - } - }() - time.Sleep(time.Millisecond * 300) - - p, err := newConnPool(serverPort) - if err != nil { - t.Errorf("failed to create connection pool: %s", err) - } - defer p.Close() - - connCtx, connCancel := context.WithCancel(context.Background()) - defer connCancel() - - conn, err := p.Get(connCtx) - if err != nil { - t.Errorf("failed to get new connection from pool: %s", err) - return - } - if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil { - t.Errorf("failed to write quit command to first connection: %s", err) - } - - if p.Size() != 4 { - t.Errorf("getting new connection from pool failed. Expected pool size: 4, got %d", p.Size()) - } - - var wg sync.WaitGroup - for i := 0; i < 4; i++ { - wg.Add(1) - go func() { - defer wg.Done() - wgconn, err := p.Get(connCtx) - if err != nil { - t.Errorf("failed to get new connection from pool: %s", err) - } - if _, err = wgconn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil { - t.Errorf("failed to write quit command to first connection: %s", err) - } - }() - } - wg.Wait() - - if p.Size() != 0 { - t.Errorf("Get error. Expecting 0, got %d", p.Size()) - } - - connCancel() - _, err = p.Get(connCtx) - if err == nil { - t.Errorf("getting new connection on canceled context should fail, but didn't") - } - p.Close() -} - -func newConnPool(port int) (Pool, error) { - netDialer := net.Dialer{} - return NewConnPool(context.Background(), 5, 30, netDialer.DialContext, "tcp", - fmt.Sprintf("127.0.0.1:%d", port)) -}