mirror of
https://github.com/wneessen/go-mail.git
synced 2024-11-23 14:10:50 +01:00
Compare commits
3 commits
07e7b17ae8
...
fd115d5173
Author | SHA1 | Date | |
---|---|---|---|
fd115d5173 | |||
4f6224131e | |||
c8684886ed |
3 changed files with 88 additions and 32 deletions
27
connpool.go
27
connpool.go
|
@ -31,7 +31,7 @@ type Pool interface {
|
||||||
// Get returns a new connection from the pool. Closing the connections returns
|
// Get returns a new connection from the pool. Closing the connections returns
|
||||||
// it back into the Pool. Closing a connection when the Pool is destroyed or
|
// it back into the Pool. Closing a connection when the Pool is destroyed or
|
||||||
// full will be counted as an error.
|
// full will be counted as an error.
|
||||||
Get() (net.Conn, error)
|
Get(ctx context.Context) (net.Conn, error)
|
||||||
|
|
||||||
// Close closes the pool and all its connections. After Close() the pool is
|
// Close closes the pool and all its connections. After Close() the pool is
|
||||||
// no longer usable.
|
// no longer usable.
|
||||||
|
@ -50,8 +50,6 @@ type connPool struct {
|
||||||
|
|
||||||
// dialCtxFunc represents the actual net.Conn returned by the DialContextFunc.
|
// dialCtxFunc represents the actual net.Conn returned by the DialContextFunc.
|
||||||
dialCtxFunc DialContextFunc
|
dialCtxFunc DialContextFunc
|
||||||
// dialContext is the context used for dialing new network connections within the connection pool.
|
|
||||||
dialContext context.Context
|
|
||||||
// dialNetwork specifies the network type (e.g., "tcp", "udp") used to establish connections in
|
// dialNetwork specifies the network type (e.g., "tcp", "udp") used to establish connections in
|
||||||
// the connection pool.
|
// the connection pool.
|
||||||
dialNetwork string
|
dialNetwork string
|
||||||
|
@ -96,7 +94,8 @@ func (c *PoolConn) MarkUnusable() {
|
||||||
// new connection available in the pool, a new connection will be created via
|
// new connection available in the pool, a new connection will be created via
|
||||||
// the corresponding DialContextFunc() method.
|
// the corresponding DialContextFunc() method.
|
||||||
func NewConnPool(ctx context.Context, initialCap, maxCap int, dialCtxFunc DialContextFunc,
|
func NewConnPool(ctx context.Context, initialCap, maxCap int, dialCtxFunc DialContextFunc,
|
||||||
network, address string) (Pool, error) {
|
network, address string,
|
||||||
|
) (Pool, error) {
|
||||||
if initialCap < 0 || maxCap <= 0 || initialCap > maxCap {
|
if initialCap < 0 || maxCap <= 0 || initialCap > maxCap {
|
||||||
return nil, ErrPoolInvalidCap
|
return nil, ErrPoolInvalidCap
|
||||||
}
|
}
|
||||||
|
@ -104,7 +103,6 @@ func NewConnPool(ctx context.Context, initialCap, maxCap int, dialCtxFunc DialCo
|
||||||
pool := &connPool{
|
pool := &connPool{
|
||||||
conns: make(chan net.Conn, maxCap),
|
conns: make(chan net.Conn, maxCap),
|
||||||
dialCtxFunc: dialCtxFunc,
|
dialCtxFunc: dialCtxFunc,
|
||||||
dialContext: ctx,
|
|
||||||
dialAddress: address,
|
dialAddress: address,
|
||||||
dialNetwork: network,
|
dialNetwork: network,
|
||||||
}
|
}
|
||||||
|
@ -114,10 +112,9 @@ func NewConnPool(ctx context.Context, initialCap, maxCap int, dialCtxFunc DialCo
|
||||||
conn, err := dialCtxFunc(ctx, network, address)
|
conn, err := dialCtxFunc(ctx, network, address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pool.Close()
|
pool.Close()
|
||||||
return nil, fmt.Errorf("dialContextFunc is not able to fill the connection pool: %s", err)
|
return nil, fmt.Errorf("dialContextFunc is not able to fill the connection pool: %w", err)
|
||||||
}
|
}
|
||||||
pool.conns <- conn
|
pool.conns <- conn
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return pool, nil
|
return pool, nil
|
||||||
|
@ -126,8 +123,8 @@ func NewConnPool(ctx context.Context, initialCap, maxCap int, dialCtxFunc DialCo
|
||||||
// Get satisfies the Get() method of the Pool inteface. If there is no new
|
// Get satisfies the Get() method of the Pool inteface. If there is no new
|
||||||
// connection available in the Pool, a new connection will be created via the
|
// connection available in the Pool, a new connection will be created via the
|
||||||
// DialContextFunc() method.
|
// DialContextFunc() method.
|
||||||
func (p *connPool) Get() (net.Conn, error) {
|
func (p *connPool) Get(ctx context.Context) (net.Conn, error) {
|
||||||
ctx, conns, dialCtxFunc := p.getConnsAndDialContext()
|
conns, dialCtxFunc := p.getConnsAndDialContext()
|
||||||
if conns == nil {
|
if conns == nil {
|
||||||
return nil, ErrClosed
|
return nil, ErrClosed
|
||||||
}
|
}
|
||||||
|
@ -136,7 +133,7 @@ func (p *connPool) Get() (net.Conn, error) {
|
||||||
// connections back to the pool
|
// connections back to the pool
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return nil, ctx.Err()
|
return nil, fmt.Errorf("failed to get connection: %w", ctx.Err())
|
||||||
case conn := <-conns:
|
case conn := <-conns:
|
||||||
if conn == nil {
|
if conn == nil {
|
||||||
return nil, ErrClosed
|
return nil, ErrClosed
|
||||||
|
@ -145,7 +142,7 @@ func (p *connPool) Get() (net.Conn, error) {
|
||||||
default:
|
default:
|
||||||
conn, err := dialCtxFunc(ctx, p.dialNetwork, p.dialAddress)
|
conn, err := dialCtxFunc(ctx, p.dialNetwork, p.dialAddress)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("dialContextFunc failed: %w", err)
|
||||||
}
|
}
|
||||||
return p.wrapConn(conn), nil
|
return p.wrapConn(conn), nil
|
||||||
}
|
}
|
||||||
|
@ -158,7 +155,6 @@ func (p *connPool) Close() {
|
||||||
conns := p.conns
|
conns := p.conns
|
||||||
p.conns = nil
|
p.conns = nil
|
||||||
p.dialCtxFunc = nil
|
p.dialCtxFunc = nil
|
||||||
p.dialContext = nil
|
|
||||||
p.dialAddress = ""
|
p.dialAddress = ""
|
||||||
p.dialNetwork = ""
|
p.dialNetwork = ""
|
||||||
p.mutex.Unlock()
|
p.mutex.Unlock()
|
||||||
|
@ -175,19 +171,18 @@ func (p *connPool) Close() {
|
||||||
|
|
||||||
// Size returns the current number of connections in the connection pool.
|
// Size returns the current number of connections in the connection pool.
|
||||||
func (p *connPool) Size() int {
|
func (p *connPool) Size() int {
|
||||||
_, conns, _ := p.getConnsAndDialContext()
|
conns, _ := p.getConnsAndDialContext()
|
||||||
return len(conns)
|
return len(conns)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getConnsAndDialContext returns the connection channel and the DialContext function for the
|
// getConnsAndDialContext returns the connection channel and the DialContext function for the
|
||||||
// connection pool.
|
// connection pool.
|
||||||
func (p *connPool) getConnsAndDialContext() (context.Context, chan net.Conn, DialContextFunc) {
|
func (p *connPool) getConnsAndDialContext() (chan net.Conn, DialContextFunc) {
|
||||||
p.mutex.RLock()
|
p.mutex.RLock()
|
||||||
conns := p.conns
|
conns := p.conns
|
||||||
dialCtxFunc := p.dialCtxFunc
|
dialCtxFunc := p.dialCtxFunc
|
||||||
ctx := p.dialContext
|
|
||||||
p.mutex.RUnlock()
|
p.mutex.RUnlock()
|
||||||
return ctx, conns, dialCtxFunc
|
return conns, dialCtxFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
// put puts a passed connection back into the pool. If the pool is full or closed,
|
// put puts a passed connection back into the pool. If the pool is full or closed,
|
||||||
|
|
|
@ -39,7 +39,7 @@ func TestNewConnPool(t *testing.T) {
|
||||||
if pool.Size() != 5 {
|
if pool.Size() != 5 {
|
||||||
t.Errorf("expected 5 connections, got %d", pool.Size())
|
t.Errorf("expected 5 connections, got %d", pool.Size())
|
||||||
}
|
}
|
||||||
conn, err := pool.Get()
|
conn, err := pool.Get(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to get connection: %s", err)
|
t.Errorf("failed to get connection: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -68,7 +68,7 @@ func TestConnPool_Get_Type(t *testing.T) {
|
||||||
}
|
}
|
||||||
defer pool.Close()
|
defer pool.Close()
|
||||||
|
|
||||||
conn, err := pool.Get()
|
conn, err := pool.Get(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to get new connection from pool: %s", err)
|
t.Errorf("failed to get new connection from pool: %s", err)
|
||||||
return
|
return
|
||||||
|
@ -101,7 +101,7 @@ func TestConnPool_Get(t *testing.T) {
|
||||||
p, _ := newConnPool(serverPort)
|
p, _ := newConnPool(serverPort)
|
||||||
defer p.Close()
|
defer p.Close()
|
||||||
|
|
||||||
conn, err := p.Get()
|
conn, err := p.Get(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to get new connection from pool: %s", err)
|
t.Errorf("failed to get new connection from pool: %s", err)
|
||||||
return
|
return
|
||||||
|
@ -119,7 +119,7 @@ func TestConnPool_Get(t *testing.T) {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
wgconn, err := p.Get()
|
wgconn, err := p.Get(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to get new connection from pool: %s", err)
|
t.Errorf("failed to get new connection from pool: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -134,7 +134,7 @@ func TestConnPool_Get(t *testing.T) {
|
||||||
t.Errorf("Get error. Expecting 0, got %d", p.Size())
|
t.Errorf("Get error. Expecting 0, got %d", p.Size())
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err = p.Get()
|
conn, err = p.Get(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to get new connection from pool: %s", err)
|
t.Errorf("failed to get new connection from pool: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -168,7 +168,7 @@ func TestPoolConn_Close(t *testing.T) {
|
||||||
|
|
||||||
conns := make([]net.Conn, 30)
|
conns := make([]net.Conn, 30)
|
||||||
for i := 0; i < 30; i++ {
|
for i := 0; i < 30; i++ {
|
||||||
conn, _ := p.Get()
|
conn, _ := p.Get(context.Background())
|
||||||
if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil {
|
if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil {
|
||||||
t.Errorf("failed to write quit command to first connection: %s", err)
|
t.Errorf("failed to write quit command to first connection: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -184,7 +184,7 @@ func TestPoolConn_Close(t *testing.T) {
|
||||||
t.Errorf("failed to return all connections to pool. Expected pool size: 30, got %d", p.Size())
|
t.Errorf("failed to return all connections to pool. Expected pool size: 30, got %d", p.Size())
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := p.Get()
|
conn, err := p.Get(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to get new connection from pool: %s", err)
|
t.Errorf("failed to get new connection from pool: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -218,7 +218,7 @@ func TestPoolConn_MarkUnusable(t *testing.T) {
|
||||||
pool, _ := newConnPool(serverPort)
|
pool, _ := newConnPool(serverPort)
|
||||||
defer pool.Close()
|
defer pool.Close()
|
||||||
|
|
||||||
conn, err := pool.Get()
|
conn, err := pool.Get(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to get new connection from pool: %s", err)
|
t.Errorf("failed to get new connection from pool: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -227,7 +227,7 @@ func TestPoolConn_MarkUnusable(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
poolSize := pool.Size()
|
poolSize := pool.Size()
|
||||||
conn, err = pool.Get()
|
conn, err = pool.Get(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to get new connection from pool: %s", err)
|
t.Errorf("failed to get new connection from pool: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -238,7 +238,7 @@ func TestPoolConn_MarkUnusable(t *testing.T) {
|
||||||
t.Errorf("pool size is expected to be equal to initial size")
|
t.Errorf("pool size is expected to be equal to initial size")
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err = pool.Get()
|
conn, err = pool.Get(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to get new connection from pool: %s", err)
|
t.Errorf("failed to get new connection from pool: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -283,9 +283,6 @@ func TestConnPool_Close(t *testing.T) {
|
||||||
if castPool.dialCtxFunc != nil {
|
if castPool.dialCtxFunc != nil {
|
||||||
t.Error("closing pool failed: dialCtxFunc should be nil")
|
t.Error("closing pool failed: dialCtxFunc should be nil")
|
||||||
}
|
}
|
||||||
if castPool.dialContext != nil {
|
|
||||||
t.Error("closing pool failed: dialContext should be nil")
|
|
||||||
}
|
|
||||||
if castPool.dialAddress != "" {
|
if castPool.dialAddress != "" {
|
||||||
t.Error("closing pool failed: dialAddress should be empty")
|
t.Error("closing pool failed: dialAddress should be empty")
|
||||||
}
|
}
|
||||||
|
@ -293,7 +290,7 @@ func TestConnPool_Close(t *testing.T) {
|
||||||
t.Error("closing pool failed: dialNetwork should be empty")
|
t.Error("closing pool failed: dialNetwork should be empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := pool.Get()
|
conn, err := pool.Get(context.Background())
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("closing pool failed: getting new connection should return an error")
|
t.Errorf("closing pool failed: getting new connection should return an error")
|
||||||
}
|
}
|
||||||
|
@ -332,7 +329,7 @@ func TestConnPool_Concurrency(t *testing.T) {
|
||||||
getWg.Add(1)
|
getWg.Add(1)
|
||||||
closeWg.Add(1)
|
closeWg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
conn, err := pool.Get()
|
conn, err := pool.Get(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to get new connection from pool: %s", err)
|
t.Errorf("failed to get new connection from pool: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -355,6 +352,70 @@ func TestConnPool_Concurrency(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnPool_GetContextCancel(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
serverPort := TestServerPortBase + 17
|
||||||
|
featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
|
||||||
|
go func() {
|
||||||
|
if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil {
|
||||||
|
t.Errorf("failed to start test server: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
time.Sleep(time.Millisecond * 300)
|
||||||
|
|
||||||
|
p, err := newConnPool(serverPort)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to create connection pool: %s", err)
|
||||||
|
}
|
||||||
|
defer p.Close()
|
||||||
|
|
||||||
|
connCtx, connCancel := context.WithCancel(context.Background())
|
||||||
|
defer connCancel()
|
||||||
|
|
||||||
|
conn, err := p.Get(connCtx)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to get new connection from pool: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil {
|
||||||
|
t.Errorf("failed to write quit command to first connection: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.Size() != 4 {
|
||||||
|
t.Errorf("getting new connection from pool failed. Expected pool size: 4, got %d", p.Size())
|
||||||
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < 4; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
wgconn, err := p.Get(connCtx)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to get new connection from pool: %s", err)
|
||||||
|
}
|
||||||
|
if _, err = wgconn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil {
|
||||||
|
t.Errorf("failed to write quit command to first connection: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
if p.Size() != 0 {
|
||||||
|
t.Errorf("Get error. Expecting 0, got %d", p.Size())
|
||||||
|
}
|
||||||
|
|
||||||
|
connCancel()
|
||||||
|
_, err = p.Get(connCtx)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("getting new connection on canceled context should fail, but didn't")
|
||||||
|
}
|
||||||
|
p.Close()
|
||||||
|
}
|
||||||
|
|
||||||
func newConnPool(port int) (Pool, error) {
|
func newConnPool(port int) (Pool, error) {
|
||||||
netDialer := net.Dialer{}
|
netDialer := net.Dialer{}
|
||||||
return NewConnPool(context.Background(), 5, 30, netDialer.DialContext, "tcp",
|
return NewConnPool(context.Background(), 5, 30, netDialer.DialContext, "tcp",
|
||||||
|
|
|
@ -22,7 +22,7 @@ import "strings"
|
||||||
// should be the preferred greeting for servers that support it.
|
// should be the preferred greeting for servers that support it.
|
||||||
//
|
//
|
||||||
// Backport of: https://github.com/golang/go/commit/4d8db00641cc9ff4f44de7df9b8c4f4a4f9416ee#diff-4f6f6bdb9891d4dd271f9f31430420a2e44018fe4ee539576faf458bebb3cee4
|
// Backport of: https://github.com/golang/go/commit/4d8db00641cc9ff4f44de7df9b8c4f4a4f9416ee#diff-4f6f6bdb9891d4dd271f9f31430420a2e44018fe4ee539576faf458bebb3cee4
|
||||||
// to guarantee backwards compatibility with Go 1.16/1.17:w
|
// to guarantee backwards compatibility with Go 1.16/1.17
|
||||||
func (c *Client) ehlo() error {
|
func (c *Client) ehlo() error {
|
||||||
_, msg, err := c.cmd(250, "EHLO %s", c.localName)
|
_, msg, err := c.cmd(250, "EHLO %s", c.localName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
Loading…
Reference in a new issue