Refactor Get method to include context argument

Updated the Get method in connpool.go and its usage in tests to include a context argument for better cancellation and timeout handling. Removed the redundant dialContext field from the connection pool struct and added a new test to validate context timeout behavior.
This commit is contained in:
Winni Neessen 2024-09-23 13:44:03 +02:00
parent 07e7b17ae8
commit c8684886ed
Signed by: wneessen
GPG key ID: 385AC9889632126E
2 changed files with 87 additions and 31 deletions

View file

@ -31,7 +31,7 @@ type Pool interface {
// Get returns a new connection from the pool. Closing the connections returns // Get returns a new connection from the pool. Closing the connections returns
// it back into the Pool. Closing a connection when the Pool is destroyed or // it back into the Pool. Closing a connection when the Pool is destroyed or
// full will be counted as an error. // full will be counted as an error.
Get() (net.Conn, error) Get(ctx context.Context) (net.Conn, error)
// Close closes the pool and all its connections. After Close() the pool is // Close closes the pool and all its connections. After Close() the pool is
// no longer usable. // no longer usable.
@ -50,8 +50,6 @@ type connPool struct {
// dialCtxFunc represents the actual net.Conn returned by the DialContextFunc. // dialCtxFunc represents the actual net.Conn returned by the DialContextFunc.
dialCtxFunc DialContextFunc dialCtxFunc DialContextFunc
// dialContext is the context used for dialing new network connections within the connection pool.
dialContext context.Context
// dialNetwork specifies the network type (e.g., "tcp", "udp") used to establish connections in // dialNetwork specifies the network type (e.g., "tcp", "udp") used to establish connections in
// the connection pool. // the connection pool.
dialNetwork string dialNetwork string
@ -96,7 +94,8 @@ func (c *PoolConn) MarkUnusable() {
// new connection available in the pool, a new connection will be created via // new connection available in the pool, a new connection will be created via
// the corresponding DialContextFunc() method. // the corresponding DialContextFunc() method.
func NewConnPool(ctx context.Context, initialCap, maxCap int, dialCtxFunc DialContextFunc, func NewConnPool(ctx context.Context, initialCap, maxCap int, dialCtxFunc DialContextFunc,
network, address string) (Pool, error) { network, address string,
) (Pool, error) {
if initialCap < 0 || maxCap <= 0 || initialCap > maxCap { if initialCap < 0 || maxCap <= 0 || initialCap > maxCap {
return nil, ErrPoolInvalidCap return nil, ErrPoolInvalidCap
} }
@ -104,7 +103,6 @@ func NewConnPool(ctx context.Context, initialCap, maxCap int, dialCtxFunc DialCo
pool := &connPool{ pool := &connPool{
conns: make(chan net.Conn, maxCap), conns: make(chan net.Conn, maxCap),
dialCtxFunc: dialCtxFunc, dialCtxFunc: dialCtxFunc,
dialContext: ctx,
dialAddress: address, dialAddress: address,
dialNetwork: network, dialNetwork: network,
} }
@ -114,10 +112,9 @@ func NewConnPool(ctx context.Context, initialCap, maxCap int, dialCtxFunc DialCo
conn, err := dialCtxFunc(ctx, network, address) conn, err := dialCtxFunc(ctx, network, address)
if err != nil { if err != nil {
pool.Close() pool.Close()
return nil, fmt.Errorf("dialContextFunc is not able to fill the connection pool: %s", err) return nil, fmt.Errorf("dialContextFunc is not able to fill the connection pool: %w", err)
} }
pool.conns <- conn pool.conns <- conn
} }
return pool, nil return pool, nil
@ -126,8 +123,8 @@ func NewConnPool(ctx context.Context, initialCap, maxCap int, dialCtxFunc DialCo
// Get satisfies the Get() method of the Pool inteface. If there is no new // Get satisfies the Get() method of the Pool inteface. If there is no new
// connection available in the Pool, a new connection will be created via the // connection available in the Pool, a new connection will be created via the
// DialContextFunc() method. // DialContextFunc() method.
func (p *connPool) Get() (net.Conn, error) { func (p *connPool) Get(ctx context.Context) (net.Conn, error) {
ctx, conns, dialCtxFunc := p.getConnsAndDialContext() conns, dialCtxFunc := p.getConnsAndDialContext()
if conns == nil { if conns == nil {
return nil, ErrClosed return nil, ErrClosed
} }
@ -136,7 +133,7 @@ func (p *connPool) Get() (net.Conn, error) {
// connections back to the pool // connections back to the pool
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil, ctx.Err() return nil, fmt.Errorf("failed to get connection: %w", ctx.Err())
case conn := <-conns: case conn := <-conns:
if conn == nil { if conn == nil {
return nil, ErrClosed return nil, ErrClosed
@ -145,7 +142,7 @@ func (p *connPool) Get() (net.Conn, error) {
default: default:
conn, err := dialCtxFunc(ctx, p.dialNetwork, p.dialAddress) conn, err := dialCtxFunc(ctx, p.dialNetwork, p.dialAddress)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("dialContextFunc failed: %w", err)
} }
return p.wrapConn(conn), nil return p.wrapConn(conn), nil
} }
@ -158,7 +155,6 @@ func (p *connPool) Close() {
conns := p.conns conns := p.conns
p.conns = nil p.conns = nil
p.dialCtxFunc = nil p.dialCtxFunc = nil
p.dialContext = nil
p.dialAddress = "" p.dialAddress = ""
p.dialNetwork = "" p.dialNetwork = ""
p.mutex.Unlock() p.mutex.Unlock()
@ -175,19 +171,18 @@ func (p *connPool) Close() {
// Size returns the current number of connections in the connection pool. // Size returns the current number of connections in the connection pool.
func (p *connPool) Size() int { func (p *connPool) Size() int {
_, conns, _ := p.getConnsAndDialContext() conns, _ := p.getConnsAndDialContext()
return len(conns) return len(conns)
} }
// getConnsAndDialContext returns the connection channel and the DialContext function for the // getConnsAndDialContext returns the connection channel and the DialContext function for the
// connection pool. // connection pool.
func (p *connPool) getConnsAndDialContext() (context.Context, chan net.Conn, DialContextFunc) { func (p *connPool) getConnsAndDialContext() (chan net.Conn, DialContextFunc) {
p.mutex.RLock() p.mutex.RLock()
conns := p.conns conns := p.conns
dialCtxFunc := p.dialCtxFunc dialCtxFunc := p.dialCtxFunc
ctx := p.dialContext
p.mutex.RUnlock() p.mutex.RUnlock()
return ctx, conns, dialCtxFunc return conns, dialCtxFunc
} }
// put puts a passed connection back into the pool. If the pool is full or closed, // put puts a passed connection back into the pool. If the pool is full or closed,

View file

@ -39,7 +39,7 @@ func TestNewConnPool(t *testing.T) {
if pool.Size() != 5 { if pool.Size() != 5 {
t.Errorf("expected 5 connections, got %d", pool.Size()) t.Errorf("expected 5 connections, got %d", pool.Size())
} }
conn, err := pool.Get() conn, err := pool.Get(context.Background())
if err != nil { if err != nil {
t.Errorf("failed to get connection: %s", err) t.Errorf("failed to get connection: %s", err)
} }
@ -68,7 +68,7 @@ func TestConnPool_Get_Type(t *testing.T) {
} }
defer pool.Close() defer pool.Close()
conn, err := pool.Get() conn, err := pool.Get(context.Background())
if err != nil { if err != nil {
t.Errorf("failed to get new connection from pool: %s", err) t.Errorf("failed to get new connection from pool: %s", err)
return return
@ -101,7 +101,7 @@ func TestConnPool_Get(t *testing.T) {
p, _ := newConnPool(serverPort) p, _ := newConnPool(serverPort)
defer p.Close() defer p.Close()
conn, err := p.Get() conn, err := p.Get(context.Background())
if err != nil { if err != nil {
t.Errorf("failed to get new connection from pool: %s", err) t.Errorf("failed to get new connection from pool: %s", err)
return return
@ -119,7 +119,7 @@ func TestConnPool_Get(t *testing.T) {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
wgconn, err := p.Get() wgconn, err := p.Get(context.Background())
if err != nil { if err != nil {
t.Errorf("failed to get new connection from pool: %s", err) t.Errorf("failed to get new connection from pool: %s", err)
} }
@ -134,7 +134,7 @@ func TestConnPool_Get(t *testing.T) {
t.Errorf("Get error. Expecting 0, got %d", p.Size()) t.Errorf("Get error. Expecting 0, got %d", p.Size())
} }
conn, err = p.Get() conn, err = p.Get(context.Background())
if err != nil { if err != nil {
t.Errorf("failed to get new connection from pool: %s", err) t.Errorf("failed to get new connection from pool: %s", err)
} }
@ -168,7 +168,7 @@ func TestPoolConn_Close(t *testing.T) {
conns := make([]net.Conn, 30) conns := make([]net.Conn, 30)
for i := 0; i < 30; i++ { for i := 0; i < 30; i++ {
conn, _ := p.Get() conn, _ := p.Get(context.Background())
if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil { if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil {
t.Errorf("failed to write quit command to first connection: %s", err) t.Errorf("failed to write quit command to first connection: %s", err)
} }
@ -184,7 +184,7 @@ func TestPoolConn_Close(t *testing.T) {
t.Errorf("failed to return all connections to pool. Expected pool size: 30, got %d", p.Size()) t.Errorf("failed to return all connections to pool. Expected pool size: 30, got %d", p.Size())
} }
conn, err := p.Get() conn, err := p.Get(context.Background())
if err != nil { if err != nil {
t.Errorf("failed to get new connection from pool: %s", err) t.Errorf("failed to get new connection from pool: %s", err)
} }
@ -218,7 +218,7 @@ func TestPoolConn_MarkUnusable(t *testing.T) {
pool, _ := newConnPool(serverPort) pool, _ := newConnPool(serverPort)
defer pool.Close() defer pool.Close()
conn, err := pool.Get() conn, err := pool.Get(context.Background())
if err != nil { if err != nil {
t.Errorf("failed to get new connection from pool: %s", err) t.Errorf("failed to get new connection from pool: %s", err)
} }
@ -227,7 +227,7 @@ func TestPoolConn_MarkUnusable(t *testing.T) {
} }
poolSize := pool.Size() poolSize := pool.Size()
conn, err = pool.Get() conn, err = pool.Get(context.Background())
if err != nil { if err != nil {
t.Errorf("failed to get new connection from pool: %s", err) t.Errorf("failed to get new connection from pool: %s", err)
} }
@ -238,7 +238,7 @@ func TestPoolConn_MarkUnusable(t *testing.T) {
t.Errorf("pool size is expected to be equal to initial size") t.Errorf("pool size is expected to be equal to initial size")
} }
conn, err = pool.Get() conn, err = pool.Get(context.Background())
if err != nil { if err != nil {
t.Errorf("failed to get new connection from pool: %s", err) t.Errorf("failed to get new connection from pool: %s", err)
} }
@ -283,9 +283,6 @@ func TestConnPool_Close(t *testing.T) {
if castPool.dialCtxFunc != nil { if castPool.dialCtxFunc != nil {
t.Error("closing pool failed: dialCtxFunc should be nil") t.Error("closing pool failed: dialCtxFunc should be nil")
} }
if castPool.dialContext != nil {
t.Error("closing pool failed: dialContext should be nil")
}
if castPool.dialAddress != "" { if castPool.dialAddress != "" {
t.Error("closing pool failed: dialAddress should be empty") t.Error("closing pool failed: dialAddress should be empty")
} }
@ -293,7 +290,7 @@ func TestConnPool_Close(t *testing.T) {
t.Error("closing pool failed: dialNetwork should be empty") t.Error("closing pool failed: dialNetwork should be empty")
} }
conn, err := pool.Get() conn, err := pool.Get(context.Background())
if err == nil { if err == nil {
t.Errorf("closing pool failed: getting new connection should return an error") t.Errorf("closing pool failed: getting new connection should return an error")
} }
@ -332,7 +329,7 @@ func TestConnPool_Concurrency(t *testing.T) {
getWg.Add(1) getWg.Add(1)
closeWg.Add(1) closeWg.Add(1)
go func() { go func() {
conn, err := pool.Get() conn, err := pool.Get(context.Background())
if err != nil { if err != nil {
t.Errorf("failed to get new connection from pool: %s", err) t.Errorf("failed to get new connection from pool: %s", err)
} }
@ -355,6 +352,70 @@ func TestConnPool_Concurrency(t *testing.T) {
} }
} }
func TestConnPool_GetContextTimeout(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
serverPort := TestServerPortBase + 17
featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
go func() {
if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 300)
p, err := newConnPool(serverPort)
if err != nil {
t.Errorf("failed to create connection pool: %s", err)
}
defer p.Close()
connCtx, connCancel := context.WithCancel(context.Background())
defer connCancel()
conn, err := p.Get(connCtx)
if err != nil {
t.Errorf("failed to get new connection from pool: %s", err)
return
}
if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil {
t.Errorf("failed to write quit command to first connection: %s", err)
}
if p.Size() != 4 {
t.Errorf("getting new connection from pool failed. Expected pool size: 4, got %d", p.Size())
}
var wg sync.WaitGroup
for i := 0; i < 4; i++ {
wg.Add(1)
go func() {
defer wg.Done()
wgconn, err := p.Get(connCtx)
if err != nil {
t.Errorf("failed to get new connection from pool: %s", err)
}
if _, err = wgconn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil {
t.Errorf("failed to write quit command to first connection: %s", err)
}
}()
}
wg.Wait()
if p.Size() != 0 {
t.Errorf("Get error. Expecting 0, got %d", p.Size())
}
connCancel()
_, err = p.Get(connCtx)
if err == nil {
t.Errorf("getting new connection on canceled context should fail, but didn't")
}
p.Close()
}
func newConnPool(port int) (Pool, error) { func newConnPool(port int) (Pool, error) {
netDialer := net.Dialer{} netDialer := net.Dialer{}
return NewConnPool(context.Background(), 5, 30, netDialer.DialContext, "tcp", return NewConnPool(context.Background(), 5, 30, netDialer.DialContext, "tcp",