From c8684886ed23d2a42b1ab178400c7a7633650e47 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Mon, 23 Sep 2024 13:44:03 +0200 Subject: [PATCH] Refactor Get method to include context argument Updated the Get method in connpool.go and its usage in tests to include a context argument for better cancellation and timeout handling. Removed the redundant dialContext field from the connection pool struct and added a new test to validate context timeout behavior. --- connpool.go | 27 ++++++-------- connpool_test.go | 91 ++++++++++++++++++++++++++++++++++++++++-------- 2 files changed, 87 insertions(+), 31 deletions(-) diff --git a/connpool.go b/connpool.go index d875793..f50e73d 100644 --- a/connpool.go +++ b/connpool.go @@ -31,7 +31,7 @@ type Pool interface { // Get returns a new connection from the pool. Closing the connections returns // it back into the Pool. Closing a connection when the Pool is destroyed or // full will be counted as an error. - Get() (net.Conn, error) + Get(ctx context.Context) (net.Conn, error) // Close closes the pool and all its connections. After Close() the pool is // no longer usable. @@ -50,8 +50,6 @@ type connPool struct { // dialCtxFunc represents the actual net.Conn returned by the DialContextFunc. dialCtxFunc DialContextFunc - // dialContext is the context used for dialing new network connections within the connection pool. - dialContext context.Context // dialNetwork specifies the network type (e.g., "tcp", "udp") used to establish connections in // the connection pool. dialNetwork string @@ -96,7 +94,8 @@ func (c *PoolConn) MarkUnusable() { // new connection available in the pool, a new connection will be created via // the corresponding DialContextFunc() method. func NewConnPool(ctx context.Context, initialCap, maxCap int, dialCtxFunc DialContextFunc, - network, address string) (Pool, error) { + network, address string, +) (Pool, error) { if initialCap < 0 || maxCap <= 0 || initialCap > maxCap { return nil, ErrPoolInvalidCap } @@ -104,7 +103,6 @@ func NewConnPool(ctx context.Context, initialCap, maxCap int, dialCtxFunc DialCo pool := &connPool{ conns: make(chan net.Conn, maxCap), dialCtxFunc: dialCtxFunc, - dialContext: ctx, dialAddress: address, dialNetwork: network, } @@ -114,10 +112,9 @@ func NewConnPool(ctx context.Context, initialCap, maxCap int, dialCtxFunc DialCo conn, err := dialCtxFunc(ctx, network, address) if err != nil { pool.Close() - return nil, fmt.Errorf("dialContextFunc is not able to fill the connection pool: %s", err) + return nil, fmt.Errorf("dialContextFunc is not able to fill the connection pool: %w", err) } pool.conns <- conn - } return pool, nil @@ -126,8 +123,8 @@ func NewConnPool(ctx context.Context, initialCap, maxCap int, dialCtxFunc DialCo // Get satisfies the Get() method of the Pool inteface. If there is no new // connection available in the Pool, a new connection will be created via the // DialContextFunc() method. -func (p *connPool) Get() (net.Conn, error) { - ctx, conns, dialCtxFunc := p.getConnsAndDialContext() +func (p *connPool) Get(ctx context.Context) (net.Conn, error) { + conns, dialCtxFunc := p.getConnsAndDialContext() if conns == nil { return nil, ErrClosed } @@ -136,7 +133,7 @@ func (p *connPool) Get() (net.Conn, error) { // connections back to the pool select { case <-ctx.Done(): - return nil, ctx.Err() + return nil, fmt.Errorf("failed to get connection: %w", ctx.Err()) case conn := <-conns: if conn == nil { return nil, ErrClosed @@ -145,7 +142,7 @@ func (p *connPool) Get() (net.Conn, error) { default: conn, err := dialCtxFunc(ctx, p.dialNetwork, p.dialAddress) if err != nil { - return nil, err + return nil, fmt.Errorf("dialContextFunc failed: %w", err) } return p.wrapConn(conn), nil } @@ -158,7 +155,6 @@ func (p *connPool) Close() { conns := p.conns p.conns = nil p.dialCtxFunc = nil - p.dialContext = nil p.dialAddress = "" p.dialNetwork = "" p.mutex.Unlock() @@ -175,19 +171,18 @@ func (p *connPool) Close() { // Size returns the current number of connections in the connection pool. func (p *connPool) Size() int { - _, conns, _ := p.getConnsAndDialContext() + conns, _ := p.getConnsAndDialContext() return len(conns) } // getConnsAndDialContext returns the connection channel and the DialContext function for the // connection pool. -func (p *connPool) getConnsAndDialContext() (context.Context, chan net.Conn, DialContextFunc) { +func (p *connPool) getConnsAndDialContext() (chan net.Conn, DialContextFunc) { p.mutex.RLock() conns := p.conns dialCtxFunc := p.dialCtxFunc - ctx := p.dialContext p.mutex.RUnlock() - return ctx, conns, dialCtxFunc + return conns, dialCtxFunc } // put puts a passed connection back into the pool. If the pool is full or closed, diff --git a/connpool_test.go b/connpool_test.go index 4450e8c..1f42ff6 100644 --- a/connpool_test.go +++ b/connpool_test.go @@ -39,7 +39,7 @@ func TestNewConnPool(t *testing.T) { if pool.Size() != 5 { t.Errorf("expected 5 connections, got %d", pool.Size()) } - conn, err := pool.Get() + conn, err := pool.Get(context.Background()) if err != nil { t.Errorf("failed to get connection: %s", err) } @@ -68,7 +68,7 @@ func TestConnPool_Get_Type(t *testing.T) { } defer pool.Close() - conn, err := pool.Get() + conn, err := pool.Get(context.Background()) if err != nil { t.Errorf("failed to get new connection from pool: %s", err) return @@ -101,7 +101,7 @@ func TestConnPool_Get(t *testing.T) { p, _ := newConnPool(serverPort) defer p.Close() - conn, err := p.Get() + conn, err := p.Get(context.Background()) if err != nil { t.Errorf("failed to get new connection from pool: %s", err) return @@ -119,7 +119,7 @@ func TestConnPool_Get(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - wgconn, err := p.Get() + wgconn, err := p.Get(context.Background()) if err != nil { t.Errorf("failed to get new connection from pool: %s", err) } @@ -134,7 +134,7 @@ func TestConnPool_Get(t *testing.T) { t.Errorf("Get error. Expecting 0, got %d", p.Size()) } - conn, err = p.Get() + conn, err = p.Get(context.Background()) if err != nil { t.Errorf("failed to get new connection from pool: %s", err) } @@ -168,7 +168,7 @@ func TestPoolConn_Close(t *testing.T) { conns := make([]net.Conn, 30) for i := 0; i < 30; i++ { - conn, _ := p.Get() + conn, _ := p.Get(context.Background()) if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil { t.Errorf("failed to write quit command to first connection: %s", err) } @@ -184,7 +184,7 @@ func TestPoolConn_Close(t *testing.T) { t.Errorf("failed to return all connections to pool. Expected pool size: 30, got %d", p.Size()) } - conn, err := p.Get() + conn, err := p.Get(context.Background()) if err != nil { t.Errorf("failed to get new connection from pool: %s", err) } @@ -218,7 +218,7 @@ func TestPoolConn_MarkUnusable(t *testing.T) { pool, _ := newConnPool(serverPort) defer pool.Close() - conn, err := pool.Get() + conn, err := pool.Get(context.Background()) if err != nil { t.Errorf("failed to get new connection from pool: %s", err) } @@ -227,7 +227,7 @@ func TestPoolConn_MarkUnusable(t *testing.T) { } poolSize := pool.Size() - conn, err = pool.Get() + conn, err = pool.Get(context.Background()) if err != nil { t.Errorf("failed to get new connection from pool: %s", err) } @@ -238,7 +238,7 @@ func TestPoolConn_MarkUnusable(t *testing.T) { t.Errorf("pool size is expected to be equal to initial size") } - conn, err = pool.Get() + conn, err = pool.Get(context.Background()) if err != nil { t.Errorf("failed to get new connection from pool: %s", err) } @@ -283,9 +283,6 @@ func TestConnPool_Close(t *testing.T) { if castPool.dialCtxFunc != nil { t.Error("closing pool failed: dialCtxFunc should be nil") } - if castPool.dialContext != nil { - t.Error("closing pool failed: dialContext should be nil") - } if castPool.dialAddress != "" { t.Error("closing pool failed: dialAddress should be empty") } @@ -293,7 +290,7 @@ func TestConnPool_Close(t *testing.T) { t.Error("closing pool failed: dialNetwork should be empty") } - conn, err := pool.Get() + conn, err := pool.Get(context.Background()) if err == nil { t.Errorf("closing pool failed: getting new connection should return an error") } @@ -332,7 +329,7 @@ func TestConnPool_Concurrency(t *testing.T) { getWg.Add(1) closeWg.Add(1) go func() { - conn, err := pool.Get() + conn, err := pool.Get(context.Background()) if err != nil { t.Errorf("failed to get new connection from pool: %s", err) } @@ -355,6 +352,70 @@ func TestConnPool_Concurrency(t *testing.T) { } } +func TestConnPool_GetContextTimeout(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + serverPort := TestServerPortBase + 17 + featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 300) + + p, err := newConnPool(serverPort) + if err != nil { + t.Errorf("failed to create connection pool: %s", err) + } + defer p.Close() + + connCtx, connCancel := context.WithCancel(context.Background()) + defer connCancel() + + conn, err := p.Get(connCtx) + if err != nil { + t.Errorf("failed to get new connection from pool: %s", err) + return + } + if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil { + t.Errorf("failed to write quit command to first connection: %s", err) + } + + if p.Size() != 4 { + t.Errorf("getting new connection from pool failed. Expected pool size: 4, got %d", p.Size()) + } + + var wg sync.WaitGroup + for i := 0; i < 4; i++ { + wg.Add(1) + go func() { + defer wg.Done() + wgconn, err := p.Get(connCtx) + if err != nil { + t.Errorf("failed to get new connection from pool: %s", err) + } + if _, err = wgconn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil { + t.Errorf("failed to write quit command to first connection: %s", err) + } + }() + } + wg.Wait() + + if p.Size() != 0 { + t.Errorf("Get error. Expecting 0, got %d", p.Size()) + } + + connCancel() + _, err = p.Get(connCtx) + if err == nil { + t.Errorf("getting new connection on canceled context should fail, but didn't") + } + p.Close() +} + func newConnPool(port int) (Pool, error) { netDialer := net.Dialer{} return NewConnPool(context.Background(), 5, 30, netDialer.DialContext, "tcp",