Compare commits

...

10 commits

Author SHA1 Message Date
07e7b17ae8
Add tests for ConnPool close and concurrency issues
This commit introduces two new tests: `TestConnPool_Close` and `TestConnPool_Concurrency`. The former ensures the proper closing of connection pool resources, while the latter checks for concurrency issues by creating and closing multiple connections in parallel.
2024-09-23 11:45:22 +02:00
2abdee743d
Add unit test for marking connection as unusable
Introduces `TestPoolConn_MarkUnusable` to ensure the pool maintains its integrity when a connection is marked unusable. This test validates that the connection pool size adjusts correctly after marking a connection as unusable and closing it.
2024-09-23 11:27:05 +02:00
5503be8451
Remove redundant error print statements
Removed redundant fmt.Printf error print statements for connection read and write errors. This cleans up the test output and makes error handling more streamlined.
2024-09-23 11:26:54 +02:00
774925078a
Improve error handling in connection pool tests
Add error checks to Close() calls in connpool_test.go to ensure connection closures are handled properly, with descriptive error messages. Update comment in connpool.go to improve clarity on the source of code inspiration.
2024-09-23 11:17:58 +02:00
f1188bdad7
Add test for PoolConn Close method
This commit introduces a new test, `TestPoolConn_Close`, to verify that connections are correctly closed and returned to the pool. It sets up a simple SMTP server, creates a connection pool, tests writing to and closing connections, and checks the pool size to ensure proper behavior.
2024-09-23 11:15:56 +02:00
2cbd0c4aef
Add test cases for connection pool functionality
Added new test cases `TestConnPool_Get_Type` and `TestConnPool_Get` to verify connection pool operations. These tests ensure proper connection type and handling of pool size after connection retrieval and usage.
2024-09-23 11:09:11 +02:00
9a9e0c936d
Remove redundant error handling in test code
The check for io.EOF and the associated print statement were unnecessary because the loop breaks on any error. This change simplifies the error handling logic in the `client_test.go` file and avoids redundant code.
2024-09-23 11:09:03 +02:00
33d4eb5b21
Add unit tests for connection pool and rename Len to Size
Introduced unit tests for the connection pool to ensure robust functionality. Also, renamed the Len method to Size in the Pool interface and its implementation for better clarity and consistency.
2024-09-23 10:33:06 +02:00
d6e5034bba
Add new error handling and connection management in connpool
Introduce ErrClosed and ErrNilConn errors for better error handling. Implement Close and MarkUnusable methods for improved connection lifecycle management. Add put method to return connections to the pool or close them if necessary.
2024-09-23 10:09:38 +02:00
1394f1fc20
Add context management and error handling to connection pool
Introduced context support and enhanced error handling in connpool.go. Added detailed comments for better maintainability and introduced a wrapper for net.Conn to manage connection close behavior. The changes improve the robustness and clarity of the connection pool's operation.
2024-09-23 09:56:23 +02:00
3 changed files with 522 additions and 35 deletions

View file

@ -2085,7 +2085,6 @@ func handleTestServerConnection(connection net.Conn, featureSet string, failRese
data, err := reader.ReadString('\n') data, err := reader.ReadString('\n')
if err != nil { if err != nil {
fmt.Printf("unable to read from connection: %s\n", err)
return return
} }
if !strings.HasPrefix(data, "EHLO") && !strings.HasPrefix(data, "HELO") { if !strings.HasPrefix(data, "EHLO") && !strings.HasPrefix(data, "HELO") {
@ -2093,17 +2092,12 @@ func handleTestServerConnection(connection net.Conn, featureSet string, failRese
return return
} }
if err = writeLine("250-localhost.localdomain\r\n" + featureSet); err != nil { if err = writeLine("250-localhost.localdomain\r\n" + featureSet); err != nil {
fmt.Printf("unable to write to connection: %s\n", err)
return return
} }
for { for {
data, err = reader.ReadString('\n') data, err = reader.ReadString('\n')
if err != nil { if err != nil {
if errors.Is(err, io.EOF) {
break
}
fmt.Println("Error reading data:", err)
break break
} }

View file

@ -5,16 +5,23 @@
package mail package mail
import ( import (
"context"
"errors" "errors"
"fmt"
"net" "net"
"sync" "sync"
) )
// Parts of the connection pool code is forked from https://github.com/fatih/pool/ // Parts of the connection pool code is forked/took inspiration from https://github.com/fatih/pool/
// Thanks to Fatih Arslan and the project contributors for providing this great // Thanks to Fatih Arslan and the project contributors for providing this great concurrency template.
// concurrency template.
var ( var (
// ErrClosed is returned when an operation is attempted on a closed connection pool.
ErrClosed = errors.New("connection pool is closed")
// ErrNilConn is returned when a nil connection is passed back to the connection pool.
ErrNilConn = errors.New("connection is nil")
// ErrPoolInvalidCap is returned when the connection pool's capacity settings are
// invalid (e.g., initial capacity is negative).
ErrPoolInvalidCap = errors.New("invalid connection pool capacity settings") ErrPoolInvalidCap = errors.New("invalid connection pool capacity settings")
) )
@ -30,18 +37,56 @@ type Pool interface {
// no longer usable. // no longer usable.
Close() Close()
// Len returns the current number of connections of the pool. // Size returns the current number of connections of the pool.
Len() int Size() int
} }
// connPool implements the Pool interface // connPool implements the Pool interface
type connPool struct { type connPool struct {
// mutex is used to synchronize access to the connection pool to ensure thread-safe operations // mutex is used to synchronize access to the connection pool to ensure thread-safe operations.
mutex sync.RWMutex mutex sync.RWMutex
// conns is a channel used to manage and distribute net.Conn objects within the connection pool // conns is a channel used to manage and distribute net.Conn objects within the connection pool.
conns chan net.Conn conns chan net.Conn
// dialCtx represents the actual net.Conn returned by the DialContextFunc
dialCtx DialContextFunc // dialCtxFunc represents the actual net.Conn returned by the DialContextFunc.
dialCtxFunc DialContextFunc
// dialContext is the context used for dialing new network connections within the connection pool.
dialContext context.Context
// dialNetwork specifies the network type (e.g., "tcp", "udp") used to establish connections in
// the connection pool.
dialNetwork string
// dialAddress specifies the address used to establish network connections within the connection pool.
dialAddress string
}
// PoolConn is a wrapper around net.Conn to modify the the behavior of net.Conn's Close() method.
type PoolConn struct {
net.Conn
mutex sync.RWMutex
pool *connPool
unusable bool
}
// Close puts a given pool connection back to the pool instead of closing it.
func (c *PoolConn) Close() error {
c.mutex.RLock()
defer c.mutex.RUnlock()
if c.unusable {
if c.Conn != nil {
return c.Conn.Close()
}
return nil
}
return c.pool.put(c.Conn)
}
// MarkUnusable marks the connection not usable any more, to let the pool close it instead
// of returning it to pool.
func (c *PoolConn) MarkUnusable() {
c.mutex.Lock()
c.unusable = true
c.mutex.Unlock()
} }
// NewConnPool returns a new pool based on buffered channels with an initial // NewConnPool returns a new pool based on buffered channels with an initial
@ -50,40 +95,126 @@ type connPool struct {
// fill the Pool until a new Get() is called. During a Get(), if there is no // fill the Pool until a new Get() is called. During a Get(), if there is no
// new connection available in the pool, a new connection will be created via // new connection available in the pool, a new connection will be created via
// the corresponding DialContextFunc() method. // the corresponding DialContextFunc() method.
func NewConnPool(initialCap, maxCap int, dialCtxFunc DialContextFunc) (Pool, error) { func NewConnPool(ctx context.Context, initialCap, maxCap int, dialCtxFunc DialContextFunc,
network, address string) (Pool, error) {
if initialCap < 0 || maxCap <= 0 || initialCap > maxCap { if initialCap < 0 || maxCap <= 0 || initialCap > maxCap {
return nil, ErrPoolInvalidCap return nil, ErrPoolInvalidCap
} }
pool := &connPool{ pool := &connPool{
conns: make(chan net.Conn, maxCap), conns: make(chan net.Conn, maxCap),
dialCtx: dialCtxFunc, dialCtxFunc: dialCtxFunc,
dialContext: ctx,
dialAddress: address,
dialNetwork: network,
} }
// create initial connections, if something goes wrong, // Initial connections for the pool. Pool will be closed on connection error
// just close the pool error out.
for i := 0; i < initialCap; i++ { for i := 0; i < initialCap; i++ {
/* conn, err := dialCtxFunc(ctx, network, address)
conn, err := dialCtxFunc()
if err != nil { if err != nil {
pool.Close() pool.Close()
return nil, fmt.Errorf("factory is not able to fill the pool: %s", err) return nil, fmt.Errorf("dialContextFunc is not able to fill the connection pool: %s", err)
} }
c.conns <- conn pool.conns <- conn
*/
} }
return pool, nil return pool, nil
} }
func (c *connPool) Get() (net.Conn, error) { // Get satisfies the Get() method of the Pool inteface. If there is no new
return nil, nil // connection available in the Pool, a new connection will be created via the
// DialContextFunc() method.
func (p *connPool) Get() (net.Conn, error) {
ctx, conns, dialCtxFunc := p.getConnsAndDialContext()
if conns == nil {
return nil, ErrClosed
} }
func (c *connPool) Close() {
// wrap the connections into the custom net.Conn implementation that puts
// connections back to the pool
select {
case <-ctx.Done():
return nil, ctx.Err()
case conn := <-conns:
if conn == nil {
return nil, ErrClosed
}
return p.wrapConn(conn), nil
default:
conn, err := dialCtxFunc(ctx, p.dialNetwork, p.dialAddress)
if err != nil {
return nil, err
}
return p.wrapConn(conn), nil
}
}
// Close terminates all connections in the pool and frees associated resources. Once closed,
// the pool is no longer usable.
func (p *connPool) Close() {
p.mutex.Lock()
conns := p.conns
p.conns = nil
p.dialCtxFunc = nil
p.dialContext = nil
p.dialAddress = ""
p.dialNetwork = ""
p.mutex.Unlock()
if conns == nil {
return return
} }
func (c *connPool) Len() int { close(conns)
return 0 for conn := range conns {
_ = conn.Close()
}
}
// Size returns the current number of connections in the connection pool.
func (p *connPool) Size() int {
_, conns, _ := p.getConnsAndDialContext()
return len(conns)
}
// getConnsAndDialContext returns the connection channel and the DialContext function for the
// connection pool.
func (p *connPool) getConnsAndDialContext() (context.Context, chan net.Conn, DialContextFunc) {
p.mutex.RLock()
conns := p.conns
dialCtxFunc := p.dialCtxFunc
ctx := p.dialContext
p.mutex.RUnlock()
return ctx, conns, dialCtxFunc
}
// put puts a passed connection back into the pool. If the pool is full or closed,
// conn is simply closed. A nil conn will be rejected with an error.
func (p *connPool) put(conn net.Conn) error {
if conn == nil {
return ErrNilConn
}
p.mutex.RLock()
defer p.mutex.RUnlock()
if p.conns == nil {
return conn.Close()
}
select {
case p.conns <- conn:
return nil
default:
return conn.Close()
}
}
// wrapConn wraps a given net.Conn with a PoolConn, modifying the net.Conn's Close() method.
func (p *connPool) wrapConn(conn net.Conn) net.Conn {
poolconn := &PoolConn{pool: p}
poolconn.Conn = conn
return poolconn
} }

362
connpool_test.go Normal file
View file

@ -0,0 +1,362 @@
// SPDX-FileCopyrightText: 2022-2024 The go-mail Authors
//
// SPDX-License-Identifier: MIT
package mail
import (
"context"
"fmt"
"net"
"sync"
"testing"
"time"
)
func TestNewConnPool(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
serverPort := TestServerPortBase + 10
featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
go func() {
if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 300)
pool, err := newConnPool(serverPort)
if err != nil {
t.Errorf("failed to create connection pool: %s", err)
}
defer pool.Close()
if pool == nil {
t.Errorf("connection pool is nil")
return
}
if pool.Size() != 5 {
t.Errorf("expected 5 connections, got %d", pool.Size())
}
conn, err := pool.Get()
if err != nil {
t.Errorf("failed to get connection: %s", err)
}
if _, err := conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil {
t.Errorf("failed to write quit command to first connection: %s", err)
}
}
func TestConnPool_Get_Type(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
serverPort := TestServerPortBase + 11
featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
go func() {
if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 300)
pool, err := newConnPool(serverPort)
if err != nil {
t.Errorf("failed to create connection pool: %s", err)
}
defer pool.Close()
conn, err := pool.Get()
if err != nil {
t.Errorf("failed to get new connection from pool: %s", err)
return
}
_, ok := conn.(*PoolConn)
if !ok {
t.Error("received connection from pool is not of type PoolConn")
return
}
if _, err := conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil {
t.Errorf("failed to write quit command to first connection: %s", err)
}
}
func TestConnPool_Get(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
serverPort := TestServerPortBase + 12
featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
go func() {
if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 300)
p, _ := newConnPool(serverPort)
defer p.Close()
conn, err := p.Get()
if err != nil {
t.Errorf("failed to get new connection from pool: %s", err)
return
}
if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil {
t.Errorf("failed to write quit command to first connection: %s", err)
}
if p.Size() != 4 {
t.Errorf("getting new connection from pool failed. Expected pool size: 4, got %d", p.Size())
}
var wg sync.WaitGroup
for i := 0; i < 4; i++ {
wg.Add(1)
go func() {
defer wg.Done()
wgconn, err := p.Get()
if err != nil {
t.Errorf("failed to get new connection from pool: %s", err)
}
if _, err = wgconn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil {
t.Errorf("failed to write quit command to first connection: %s", err)
}
}()
}
wg.Wait()
if p.Size() != 0 {
t.Errorf("Get error. Expecting 0, got %d", p.Size())
}
conn, err = p.Get()
if err != nil {
t.Errorf("failed to get new connection from pool: %s", err)
}
if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil {
t.Errorf("failed to write quit command to first connection: %s", err)
}
p.Close()
}
func TestPoolConn_Close(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
serverPort := TestServerPortBase + 13
featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
go func() {
if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 300)
netDialer := net.Dialer{}
p, err := NewConnPool(context.Background(), 0, 30, netDialer.DialContext, "tcp",
fmt.Sprintf("127.0.0.1:%d", serverPort))
if err != nil {
t.Errorf("failed to create connection pool: %s", err)
}
defer p.Close()
conns := make([]net.Conn, 30)
for i := 0; i < 30; i++ {
conn, _ := p.Get()
if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil {
t.Errorf("failed to write quit command to first connection: %s", err)
}
conns[i] = conn
}
for _, conn := range conns {
if err = conn.Close(); err != nil {
t.Errorf("failed to close connection: %s", err)
}
}
if p.Size() != 30 {
t.Errorf("failed to return all connections to pool. Expected pool size: 30, got %d", p.Size())
}
conn, err := p.Get()
if err != nil {
t.Errorf("failed to get new connection from pool: %s", err)
}
if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil {
t.Errorf("failed to write quit command to first connection: %s", err)
}
p.Close()
if err = conn.Close(); err != nil {
t.Errorf("failed to close connection: %s", err)
}
if p.Size() != 0 {
t.Errorf("closed pool shouldn't allow to put connections.")
}
}
func TestPoolConn_MarkUnusable(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
serverPort := TestServerPortBase + 14
featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
go func() {
if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 300)
pool, _ := newConnPool(serverPort)
defer pool.Close()
conn, err := pool.Get()
if err != nil {
t.Errorf("failed to get new connection from pool: %s", err)
}
if err = conn.Close(); err != nil {
t.Errorf("failed to close connection: %s", err)
}
poolSize := pool.Size()
conn, err = pool.Get()
if err != nil {
t.Errorf("failed to get new connection from pool: %s", err)
}
if err = conn.Close(); err != nil {
t.Errorf("failed to close connection: %s", err)
}
if pool.Size() != poolSize {
t.Errorf("pool size is expected to be equal to initial size")
}
conn, err = pool.Get()
if err != nil {
t.Errorf("failed to get new connection from pool: %s", err)
}
if pc, ok := conn.(*PoolConn); !ok {
t.Errorf("this should never happen")
} else {
pc.MarkUnusable()
}
if err = conn.Close(); err != nil {
t.Errorf("failed to close connection: %s", err)
}
if pool.Size() != poolSize-1 {
t.Errorf("pool size is expected to be: %d but got: %d", poolSize-1, pool.Size())
}
}
func TestConnPool_Close(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
serverPort := TestServerPortBase + 15
featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
go func() {
if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 300)
pool, err := newConnPool(serverPort)
if err != nil {
t.Errorf("failed to create connection pool: %s", err)
}
pool.Close()
castPool := pool.(*connPool)
if castPool.conns != nil {
t.Error("closing pool failed: conns channel should be nil")
}
if castPool.dialCtxFunc != nil {
t.Error("closing pool failed: dialCtxFunc should be nil")
}
if castPool.dialContext != nil {
t.Error("closing pool failed: dialContext should be nil")
}
if castPool.dialAddress != "" {
t.Error("closing pool failed: dialAddress should be empty")
}
if castPool.dialNetwork != "" {
t.Error("closing pool failed: dialNetwork should be empty")
}
conn, err := pool.Get()
if err == nil {
t.Errorf("closing pool failed: getting new connection should return an error")
}
if conn != nil {
t.Errorf("closing pool failed: getting new connection should return a nil-connection")
}
if pool.Size() != 0 {
t.Errorf("closing pool failed: pool size should be 0, but got: %d", pool.Size())
}
}
func TestConnPool_Concurrency(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
serverPort := TestServerPortBase + 16
featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
go func() {
if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 300)
pool, err := newConnPool(serverPort)
if err != nil {
t.Errorf("failed to create connection pool: %s", err)
}
defer pool.Close()
pipe := make(chan net.Conn)
getWg := sync.WaitGroup{}
closeWg := sync.WaitGroup{}
for i := 0; i < 30; i++ {
getWg.Add(1)
closeWg.Add(1)
go func() {
conn, err := pool.Get()
if err != nil {
t.Errorf("failed to get new connection from pool: %s", err)
}
pipe <- conn
getWg.Done()
}()
go func() {
conn := <-pipe
if conn == nil {
return
}
if err = conn.Close(); err != nil {
t.Errorf("failed to close connection: %s", err)
}
closeWg.Done()
}()
getWg.Wait()
closeWg.Wait()
}
}
func newConnPool(port int) (Pool, error) {
netDialer := net.Dialer{}
return NewConnPool(context.Background(), 5, 30, netDialer.DialContext, "tcp",
fmt.Sprintf("127.0.0.1:%d", port))
}