mirror of
https://github.com/wneessen/go-mail.git
synced 2024-11-23 14:10:50 +01:00
Compare commits
2 commits
fd115d5173
...
3871b2be44
Author | SHA1 | Date | |
---|---|---|---|
3871b2be44 | |||
8683917c3d |
5 changed files with 27 additions and 645 deletions
18
client.go
18
client.go
|
@ -12,6 +12,7 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/wneessen/go-mail/log"
|
"github.com/wneessen/go-mail/log"
|
||||||
|
@ -87,6 +88,7 @@ type DialContextFunc func(ctx context.Context, network, address string) (net.Con
|
||||||
|
|
||||||
// Client is the SMTP client struct
|
// Client is the SMTP client struct
|
||||||
type Client struct {
|
type Client struct {
|
||||||
|
mutex sync.RWMutex
|
||||||
// connection is the net.Conn that the smtp.Client is based on
|
// connection is the net.Conn that the smtp.Client is based on
|
||||||
connection net.Conn
|
connection net.Conn
|
||||||
|
|
||||||
|
@ -589,6 +591,9 @@ func (c *Client) setDefaultHelo() error {
|
||||||
|
|
||||||
// DialWithContext establishes a connection to the SMTP server with a given context.Context
|
// DialWithContext establishes a connection to the SMTP server with a given context.Context
|
||||||
func (c *Client) DialWithContext(dialCtx context.Context) error {
|
func (c *Client) DialWithContext(dialCtx context.Context) error {
|
||||||
|
c.mutex.Lock()
|
||||||
|
defer c.mutex.Unlock()
|
||||||
|
|
||||||
ctx, cancel := context.WithDeadline(dialCtx, time.Now().Add(c.connTimeout))
|
ctx, cancel := context.WithDeadline(dialCtx, time.Now().Add(c.connTimeout))
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
@ -602,17 +607,16 @@ func (c *Client) DialWithContext(dialCtx context.Context) error {
|
||||||
c.dialContextFunc = tlsDialer.DialContext
|
c.dialContextFunc = tlsDialer.DialContext
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
var err error
|
connection, err := c.dialContextFunc(ctx, "tcp", c.ServerAddr())
|
||||||
c.connection, err = c.dialContextFunc(ctx, "tcp", c.ServerAddr())
|
|
||||||
if err != nil && c.fallbackPort != 0 {
|
if err != nil && c.fallbackPort != 0 {
|
||||||
// TODO: should we somehow log or append the previous error?
|
// TODO: should we somehow log or append the previous error?
|
||||||
c.connection, err = c.dialContextFunc(ctx, "tcp", c.serverFallbackAddr())
|
connection, err = c.dialContextFunc(ctx, "tcp", c.serverFallbackAddr())
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := smtp.NewClient(c.connection, c.host)
|
client, err := smtp.NewClient(connection, c.host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -691,7 +695,7 @@ func (c *Client) DialAndSendWithContext(ctx context.Context, messages ...*Msg) e
|
||||||
// checkConn makes sure that a required server connection is available and extends the
|
// checkConn makes sure that a required server connection is available and extends the
|
||||||
// connection deadline
|
// connection deadline
|
||||||
func (c *Client) checkConn() error {
|
func (c *Client) checkConn() error {
|
||||||
if c.connection == nil {
|
if !c.smtpClient.HasConnection() {
|
||||||
return ErrNoActiveConnection
|
return ErrNoActiveConnection
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -701,7 +705,7 @@ func (c *Client) checkConn() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.connection.SetDeadline(time.Now().Add(c.connTimeout)); err != nil {
|
if err := c.smtpClient.UpdateDeadline(c.connTimeout); err != nil {
|
||||||
return ErrDeadlineExtendFailed
|
return ErrDeadlineExtendFailed
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -715,7 +719,7 @@ func (c *Client) serverFallbackAddr() string {
|
||||||
|
|
||||||
// tls tries to make sure that the STARTTLS requirements are satisfied
|
// tls tries to make sure that the STARTTLS requirements are satisfied
|
||||||
func (c *Client) tls() error {
|
func (c *Client) tls() error {
|
||||||
if c.connection == nil {
|
if !c.smtpClient.HasConnection() {
|
||||||
return ErrNoActiveConnection
|
return ErrNoActiveConnection
|
||||||
}
|
}
|
||||||
if !c.useSSL && c.tlspolicy != NoTLS {
|
if !c.useSSL && c.tlspolicy != NoTLS {
|
||||||
|
|
|
@ -13,6 +13,8 @@ import (
|
||||||
|
|
||||||
// Send sends out the mail message
|
// Send sends out the mail message
|
||||||
func (c *Client) Send(messages ...*Msg) (returnErr error) {
|
func (c *Client) Send(messages ...*Msg) (returnErr error) {
|
||||||
|
c.mutex.Lock()
|
||||||
|
defer c.mutex.Unlock()
|
||||||
if err := c.checkConn(); err != nil {
|
if err := c.checkConn(); err != nil {
|
||||||
returnErr = &SendError{Reason: ErrConnCheck, errlist: []error{err}, isTemp: isTempError(err)}
|
returnErr = &SendError{Reason: ErrConnCheck, errlist: []error{err}, isTemp: isTempError(err)}
|
||||||
return
|
return
|
||||||
|
|
215
connpool.go
215
connpool.go
|
@ -1,215 +0,0 @@
|
||||||
// SPDX-FileCopyrightText: 2022-2024 The go-mail Authors
|
|
||||||
//
|
|
||||||
// SPDX-License-Identifier: MIT
|
|
||||||
|
|
||||||
package mail
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Parts of the connection pool code is forked/took inspiration from https://github.com/fatih/pool/
|
|
||||||
// Thanks to Fatih Arslan and the project contributors for providing this great concurrency template.
|
|
||||||
|
|
||||||
var (
|
|
||||||
// ErrClosed is returned when an operation is attempted on a closed connection pool.
|
|
||||||
ErrClosed = errors.New("connection pool is closed")
|
|
||||||
// ErrNilConn is returned when a nil connection is passed back to the connection pool.
|
|
||||||
ErrNilConn = errors.New("connection is nil")
|
|
||||||
// ErrPoolInvalidCap is returned when the connection pool's capacity settings are
|
|
||||||
// invalid (e.g., initial capacity is negative).
|
|
||||||
ErrPoolInvalidCap = errors.New("invalid connection pool capacity settings")
|
|
||||||
)
|
|
||||||
|
|
||||||
// Pool interface describes a connection pool implementation. A Pool is
|
|
||||||
// thread-/go-routine safe.
|
|
||||||
type Pool interface {
|
|
||||||
// Get returns a new connection from the pool. Closing the connections returns
|
|
||||||
// it back into the Pool. Closing a connection when the Pool is destroyed or
|
|
||||||
// full will be counted as an error.
|
|
||||||
Get(ctx context.Context) (net.Conn, error)
|
|
||||||
|
|
||||||
// Close closes the pool and all its connections. After Close() the pool is
|
|
||||||
// no longer usable.
|
|
||||||
Close()
|
|
||||||
|
|
||||||
// Size returns the current number of connections of the pool.
|
|
||||||
Size() int
|
|
||||||
}
|
|
||||||
|
|
||||||
// connPool implements the Pool interface
|
|
||||||
type connPool struct {
|
|
||||||
// mutex is used to synchronize access to the connection pool to ensure thread-safe operations.
|
|
||||||
mutex sync.RWMutex
|
|
||||||
// conns is a channel used to manage and distribute net.Conn objects within the connection pool.
|
|
||||||
conns chan net.Conn
|
|
||||||
|
|
||||||
// dialCtxFunc represents the actual net.Conn returned by the DialContextFunc.
|
|
||||||
dialCtxFunc DialContextFunc
|
|
||||||
// dialNetwork specifies the network type (e.g., "tcp", "udp") used to establish connections in
|
|
||||||
// the connection pool.
|
|
||||||
dialNetwork string
|
|
||||||
// dialAddress specifies the address used to establish network connections within the connection pool.
|
|
||||||
dialAddress string
|
|
||||||
}
|
|
||||||
|
|
||||||
// PoolConn is a wrapper around net.Conn to modify the the behavior of net.Conn's Close() method.
|
|
||||||
type PoolConn struct {
|
|
||||||
net.Conn
|
|
||||||
mutex sync.RWMutex
|
|
||||||
pool *connPool
|
|
||||||
unusable bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close puts a given pool connection back to the pool instead of closing it.
|
|
||||||
func (c *PoolConn) Close() error {
|
|
||||||
c.mutex.RLock()
|
|
||||||
defer c.mutex.RUnlock()
|
|
||||||
|
|
||||||
if c.unusable {
|
|
||||||
if c.Conn != nil {
|
|
||||||
return c.Conn.Close()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return c.pool.put(c.Conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
// MarkUnusable marks the connection not usable any more, to let the pool close it instead
|
|
||||||
// of returning it to pool.
|
|
||||||
func (c *PoolConn) MarkUnusable() {
|
|
||||||
c.mutex.Lock()
|
|
||||||
c.unusable = true
|
|
||||||
c.mutex.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewConnPool returns a new pool based on buffered channels with an initial
|
|
||||||
// capacity and maximum capacity. The DialContextFunc is used when the initial
|
|
||||||
// capacity is greater than zero to fill the pool. A zero initialCap doesn't
|
|
||||||
// fill the Pool until a new Get() is called. During a Get(), if there is no
|
|
||||||
// new connection available in the pool, a new connection will be created via
|
|
||||||
// the corresponding DialContextFunc() method.
|
|
||||||
func NewConnPool(ctx context.Context, initialCap, maxCap int, dialCtxFunc DialContextFunc,
|
|
||||||
network, address string,
|
|
||||||
) (Pool, error) {
|
|
||||||
if initialCap < 0 || maxCap <= 0 || initialCap > maxCap {
|
|
||||||
return nil, ErrPoolInvalidCap
|
|
||||||
}
|
|
||||||
|
|
||||||
pool := &connPool{
|
|
||||||
conns: make(chan net.Conn, maxCap),
|
|
||||||
dialCtxFunc: dialCtxFunc,
|
|
||||||
dialAddress: address,
|
|
||||||
dialNetwork: network,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initial connections for the pool. Pool will be closed on connection error
|
|
||||||
for i := 0; i < initialCap; i++ {
|
|
||||||
conn, err := dialCtxFunc(ctx, network, address)
|
|
||||||
if err != nil {
|
|
||||||
pool.Close()
|
|
||||||
return nil, fmt.Errorf("dialContextFunc is not able to fill the connection pool: %w", err)
|
|
||||||
}
|
|
||||||
pool.conns <- conn
|
|
||||||
}
|
|
||||||
|
|
||||||
return pool, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get satisfies the Get() method of the Pool inteface. If there is no new
|
|
||||||
// connection available in the Pool, a new connection will be created via the
|
|
||||||
// DialContextFunc() method.
|
|
||||||
func (p *connPool) Get(ctx context.Context) (net.Conn, error) {
|
|
||||||
conns, dialCtxFunc := p.getConnsAndDialContext()
|
|
||||||
if conns == nil {
|
|
||||||
return nil, ErrClosed
|
|
||||||
}
|
|
||||||
|
|
||||||
// wrap the connections into the custom net.Conn implementation that puts
|
|
||||||
// connections back to the pool
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return nil, fmt.Errorf("failed to get connection: %w", ctx.Err())
|
|
||||||
case conn := <-conns:
|
|
||||||
if conn == nil {
|
|
||||||
return nil, ErrClosed
|
|
||||||
}
|
|
||||||
return p.wrapConn(conn), nil
|
|
||||||
default:
|
|
||||||
conn, err := dialCtxFunc(ctx, p.dialNetwork, p.dialAddress)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("dialContextFunc failed: %w", err)
|
|
||||||
}
|
|
||||||
return p.wrapConn(conn), nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close terminates all connections in the pool and frees associated resources. Once closed,
|
|
||||||
// the pool is no longer usable.
|
|
||||||
func (p *connPool) Close() {
|
|
||||||
p.mutex.Lock()
|
|
||||||
conns := p.conns
|
|
||||||
p.conns = nil
|
|
||||||
p.dialCtxFunc = nil
|
|
||||||
p.dialAddress = ""
|
|
||||||
p.dialNetwork = ""
|
|
||||||
p.mutex.Unlock()
|
|
||||||
|
|
||||||
if conns == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
close(conns)
|
|
||||||
for conn := range conns {
|
|
||||||
_ = conn.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Size returns the current number of connections in the connection pool.
|
|
||||||
func (p *connPool) Size() int {
|
|
||||||
conns, _ := p.getConnsAndDialContext()
|
|
||||||
return len(conns)
|
|
||||||
}
|
|
||||||
|
|
||||||
// getConnsAndDialContext returns the connection channel and the DialContext function for the
|
|
||||||
// connection pool.
|
|
||||||
func (p *connPool) getConnsAndDialContext() (chan net.Conn, DialContextFunc) {
|
|
||||||
p.mutex.RLock()
|
|
||||||
conns := p.conns
|
|
||||||
dialCtxFunc := p.dialCtxFunc
|
|
||||||
p.mutex.RUnlock()
|
|
||||||
return conns, dialCtxFunc
|
|
||||||
}
|
|
||||||
|
|
||||||
// put puts a passed connection back into the pool. If the pool is full or closed,
|
|
||||||
// conn is simply closed. A nil conn will be rejected with an error.
|
|
||||||
func (p *connPool) put(conn net.Conn) error {
|
|
||||||
if conn == nil {
|
|
||||||
return ErrNilConn
|
|
||||||
}
|
|
||||||
|
|
||||||
p.mutex.RLock()
|
|
||||||
defer p.mutex.RUnlock()
|
|
||||||
|
|
||||||
if p.conns == nil {
|
|
||||||
return conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case p.conns <- conn:
|
|
||||||
return nil
|
|
||||||
default:
|
|
||||||
return conn.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// wrapConn wraps a given net.Conn with a PoolConn, modifying the net.Conn's Close() method.
|
|
||||||
func (p *connPool) wrapConn(conn net.Conn) net.Conn {
|
|
||||||
poolconn := &PoolConn{pool: p}
|
|
||||||
poolconn.Conn = conn
|
|
||||||
return poolconn
|
|
||||||
}
|
|
423
connpool_test.go
423
connpool_test.go
|
@ -1,423 +0,0 @@
|
||||||
// SPDX-FileCopyrightText: 2022-2024 The go-mail Authors
|
|
||||||
//
|
|
||||||
// SPDX-License-Identifier: MIT
|
|
||||||
|
|
||||||
package mail
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNewConnPool(t *testing.T) {
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
serverPort := TestServerPortBase + 10
|
|
||||||
featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
|
|
||||||
go func() {
|
|
||||||
if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil {
|
|
||||||
t.Errorf("failed to start test server: %s", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
time.Sleep(time.Millisecond * 300)
|
|
||||||
|
|
||||||
pool, err := newConnPool(serverPort)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to create connection pool: %s", err)
|
|
||||||
}
|
|
||||||
defer pool.Close()
|
|
||||||
if pool == nil {
|
|
||||||
t.Errorf("connection pool is nil")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if pool.Size() != 5 {
|
|
||||||
t.Errorf("expected 5 connections, got %d", pool.Size())
|
|
||||||
}
|
|
||||||
conn, err := pool.Get(context.Background())
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to get connection: %s", err)
|
|
||||||
}
|
|
||||||
if _, err := conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil {
|
|
||||||
t.Errorf("failed to write quit command to first connection: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConnPool_Get_Type(t *testing.T) {
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
serverPort := TestServerPortBase + 11
|
|
||||||
featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
|
|
||||||
go func() {
|
|
||||||
if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil {
|
|
||||||
t.Errorf("failed to start test server: %s", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
time.Sleep(time.Millisecond * 300)
|
|
||||||
|
|
||||||
pool, err := newConnPool(serverPort)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to create connection pool: %s", err)
|
|
||||||
}
|
|
||||||
defer pool.Close()
|
|
||||||
|
|
||||||
conn, err := pool.Get(context.Background())
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to get new connection from pool: %s", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
_, ok := conn.(*PoolConn)
|
|
||||||
if !ok {
|
|
||||||
t.Error("received connection from pool is not of type PoolConn")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if _, err := conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil {
|
|
||||||
t.Errorf("failed to write quit command to first connection: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConnPool_Get(t *testing.T) {
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
serverPort := TestServerPortBase + 12
|
|
||||||
featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
|
|
||||||
go func() {
|
|
||||||
if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil {
|
|
||||||
t.Errorf("failed to start test server: %s", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
time.Sleep(time.Millisecond * 300)
|
|
||||||
|
|
||||||
p, _ := newConnPool(serverPort)
|
|
||||||
defer p.Close()
|
|
||||||
|
|
||||||
conn, err := p.Get(context.Background())
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to get new connection from pool: %s", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil {
|
|
||||||
t.Errorf("failed to write quit command to first connection: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if p.Size() != 4 {
|
|
||||||
t.Errorf("getting new connection from pool failed. Expected pool size: 4, got %d", p.Size())
|
|
||||||
}
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
for i := 0; i < 4; i++ {
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
wgconn, err := p.Get(context.Background())
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to get new connection from pool: %s", err)
|
|
||||||
}
|
|
||||||
if _, err = wgconn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil {
|
|
||||||
t.Errorf("failed to write quit command to first connection: %s", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
if p.Size() != 0 {
|
|
||||||
t.Errorf("Get error. Expecting 0, got %d", p.Size())
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err = p.Get(context.Background())
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to get new connection from pool: %s", err)
|
|
||||||
}
|
|
||||||
if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil {
|
|
||||||
t.Errorf("failed to write quit command to first connection: %s", err)
|
|
||||||
}
|
|
||||||
p.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPoolConn_Close(t *testing.T) {
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
serverPort := TestServerPortBase + 13
|
|
||||||
featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
|
|
||||||
go func() {
|
|
||||||
if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil {
|
|
||||||
t.Errorf("failed to start test server: %s", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
time.Sleep(time.Millisecond * 300)
|
|
||||||
|
|
||||||
netDialer := net.Dialer{}
|
|
||||||
p, err := NewConnPool(context.Background(), 0, 30, netDialer.DialContext, "tcp",
|
|
||||||
fmt.Sprintf("127.0.0.1:%d", serverPort))
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to create connection pool: %s", err)
|
|
||||||
}
|
|
||||||
defer p.Close()
|
|
||||||
|
|
||||||
conns := make([]net.Conn, 30)
|
|
||||||
for i := 0; i < 30; i++ {
|
|
||||||
conn, _ := p.Get(context.Background())
|
|
||||||
if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil {
|
|
||||||
t.Errorf("failed to write quit command to first connection: %s", err)
|
|
||||||
}
|
|
||||||
conns[i] = conn
|
|
||||||
}
|
|
||||||
for _, conn := range conns {
|
|
||||||
if err = conn.Close(); err != nil {
|
|
||||||
t.Errorf("failed to close connection: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if p.Size() != 30 {
|
|
||||||
t.Errorf("failed to return all connections to pool. Expected pool size: 30, got %d", p.Size())
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := p.Get(context.Background())
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to get new connection from pool: %s", err)
|
|
||||||
}
|
|
||||||
if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil {
|
|
||||||
t.Errorf("failed to write quit command to first connection: %s", err)
|
|
||||||
}
|
|
||||||
p.Close()
|
|
||||||
|
|
||||||
if err = conn.Close(); err != nil {
|
|
||||||
t.Errorf("failed to close connection: %s", err)
|
|
||||||
}
|
|
||||||
if p.Size() != 0 {
|
|
||||||
t.Errorf("closed pool shouldn't allow to put connections.")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPoolConn_MarkUnusable(t *testing.T) {
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
serverPort := TestServerPortBase + 14
|
|
||||||
featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
|
|
||||||
go func() {
|
|
||||||
if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil {
|
|
||||||
t.Errorf("failed to start test server: %s", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
time.Sleep(time.Millisecond * 300)
|
|
||||||
|
|
||||||
pool, _ := newConnPool(serverPort)
|
|
||||||
defer pool.Close()
|
|
||||||
|
|
||||||
conn, err := pool.Get(context.Background())
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to get new connection from pool: %s", err)
|
|
||||||
}
|
|
||||||
if err = conn.Close(); err != nil {
|
|
||||||
t.Errorf("failed to close connection: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
poolSize := pool.Size()
|
|
||||||
conn, err = pool.Get(context.Background())
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to get new connection from pool: %s", err)
|
|
||||||
}
|
|
||||||
if err = conn.Close(); err != nil {
|
|
||||||
t.Errorf("failed to close connection: %s", err)
|
|
||||||
}
|
|
||||||
if pool.Size() != poolSize {
|
|
||||||
t.Errorf("pool size is expected to be equal to initial size")
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err = pool.Get(context.Background())
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to get new connection from pool: %s", err)
|
|
||||||
}
|
|
||||||
if pc, ok := conn.(*PoolConn); !ok {
|
|
||||||
t.Errorf("this should never happen")
|
|
||||||
} else {
|
|
||||||
pc.MarkUnusable()
|
|
||||||
}
|
|
||||||
if err = conn.Close(); err != nil {
|
|
||||||
t.Errorf("failed to close connection: %s", err)
|
|
||||||
}
|
|
||||||
if pool.Size() != poolSize-1 {
|
|
||||||
t.Errorf("pool size is expected to be: %d but got: %d", poolSize-1, pool.Size())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConnPool_Close(t *testing.T) {
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
serverPort := TestServerPortBase + 15
|
|
||||||
featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
|
|
||||||
go func() {
|
|
||||||
if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil {
|
|
||||||
t.Errorf("failed to start test server: %s", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
time.Sleep(time.Millisecond * 300)
|
|
||||||
|
|
||||||
pool, err := newConnPool(serverPort)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to create connection pool: %s", err)
|
|
||||||
}
|
|
||||||
pool.Close()
|
|
||||||
|
|
||||||
castPool := pool.(*connPool)
|
|
||||||
|
|
||||||
if castPool.conns != nil {
|
|
||||||
t.Error("closing pool failed: conns channel should be nil")
|
|
||||||
}
|
|
||||||
if castPool.dialCtxFunc != nil {
|
|
||||||
t.Error("closing pool failed: dialCtxFunc should be nil")
|
|
||||||
}
|
|
||||||
if castPool.dialAddress != "" {
|
|
||||||
t.Error("closing pool failed: dialAddress should be empty")
|
|
||||||
}
|
|
||||||
if castPool.dialNetwork != "" {
|
|
||||||
t.Error("closing pool failed: dialNetwork should be empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := pool.Get(context.Background())
|
|
||||||
if err == nil {
|
|
||||||
t.Errorf("closing pool failed: getting new connection should return an error")
|
|
||||||
}
|
|
||||||
if conn != nil {
|
|
||||||
t.Errorf("closing pool failed: getting new connection should return a nil-connection")
|
|
||||||
}
|
|
||||||
if pool.Size() != 0 {
|
|
||||||
t.Errorf("closing pool failed: pool size should be 0, but got: %d", pool.Size())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConnPool_Concurrency(t *testing.T) {
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
serverPort := TestServerPortBase + 16
|
|
||||||
featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
|
|
||||||
go func() {
|
|
||||||
if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil {
|
|
||||||
t.Errorf("failed to start test server: %s", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
time.Sleep(time.Millisecond * 300)
|
|
||||||
|
|
||||||
pool, err := newConnPool(serverPort)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to create connection pool: %s", err)
|
|
||||||
}
|
|
||||||
defer pool.Close()
|
|
||||||
pipe := make(chan net.Conn)
|
|
||||||
|
|
||||||
getWg := sync.WaitGroup{}
|
|
||||||
closeWg := sync.WaitGroup{}
|
|
||||||
for i := 0; i < 30; i++ {
|
|
||||||
getWg.Add(1)
|
|
||||||
closeWg.Add(1)
|
|
||||||
go func() {
|
|
||||||
conn, err := pool.Get(context.Background())
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to get new connection from pool: %s", err)
|
|
||||||
}
|
|
||||||
pipe <- conn
|
|
||||||
getWg.Done()
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
conn := <-pipe
|
|
||||||
if conn == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err = conn.Close(); err != nil {
|
|
||||||
t.Errorf("failed to close connection: %s", err)
|
|
||||||
}
|
|
||||||
closeWg.Done()
|
|
||||||
}()
|
|
||||||
getWg.Wait()
|
|
||||||
closeWg.Wait()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConnPool_GetContextCancel(t *testing.T) {
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
serverPort := TestServerPortBase + 17
|
|
||||||
featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
|
|
||||||
go func() {
|
|
||||||
if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil {
|
|
||||||
t.Errorf("failed to start test server: %s", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
time.Sleep(time.Millisecond * 300)
|
|
||||||
|
|
||||||
p, err := newConnPool(serverPort)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to create connection pool: %s", err)
|
|
||||||
}
|
|
||||||
defer p.Close()
|
|
||||||
|
|
||||||
connCtx, connCancel := context.WithCancel(context.Background())
|
|
||||||
defer connCancel()
|
|
||||||
|
|
||||||
conn, err := p.Get(connCtx)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to get new connection from pool: %s", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil {
|
|
||||||
t.Errorf("failed to write quit command to first connection: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if p.Size() != 4 {
|
|
||||||
t.Errorf("getting new connection from pool failed. Expected pool size: 4, got %d", p.Size())
|
|
||||||
}
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
for i := 0; i < 4; i++ {
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
wgconn, err := p.Get(connCtx)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to get new connection from pool: %s", err)
|
|
||||||
}
|
|
||||||
if _, err = wgconn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil {
|
|
||||||
t.Errorf("failed to write quit command to first connection: %s", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
if p.Size() != 0 {
|
|
||||||
t.Errorf("Get error. Expecting 0, got %d", p.Size())
|
|
||||||
}
|
|
||||||
|
|
||||||
connCancel()
|
|
||||||
_, err = p.Get(connCtx)
|
|
||||||
if err == nil {
|
|
||||||
t.Errorf("getting new connection on canceled context should fail, but didn't")
|
|
||||||
}
|
|
||||||
p.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func newConnPool(port int) (Pool, error) {
|
|
||||||
netDialer := net.Dialer{}
|
|
||||||
return NewConnPool(context.Background(), 5, 30, netDialer.DialContext, "tcp",
|
|
||||||
fmt.Sprintf("127.0.0.1:%d", port))
|
|
||||||
}
|
|
14
smtp/smtp.go
14
smtp/smtp.go
|
@ -30,6 +30,7 @@ import (
|
||||||
"net/textproto"
|
"net/textproto"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/wneessen/go-mail/log"
|
"github.com/wneessen/go-mail/log"
|
||||||
)
|
)
|
||||||
|
@ -472,6 +473,19 @@ func (c *Client) SetDSNRcptNotifyOption(d string) {
|
||||||
c.dsnrntype = d
|
c.dsnrntype = d
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HasConnection checks if the client has an active connection.
|
||||||
|
// Returns true if the `conn` field is not nil, indicating an active connection.
|
||||||
|
func (c *Client) HasConnection() bool {
|
||||||
|
return c.conn != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) UpdateDeadline(timeout time.Duration) error {
|
||||||
|
if err := c.conn.SetDeadline(time.Now().Add(timeout)); err != nil {
|
||||||
|
return fmt.Errorf("smtp: failed to update deadline: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// debugLog checks if the debug flag is set and if so logs the provided message to
|
// debugLog checks if the debug flag is set and if so logs the provided message to
|
||||||
// the log.Logger interface
|
// the log.Logger interface
|
||||||
func (c *Client) debugLog(d log.Direction, f string, a ...interface{}) {
|
func (c *Client) debugLog(d log.Direction, f string, a ...interface{}) {
|
||||||
|
|
Loading…
Reference in a new issue