Compare commits

..

No commits in common. "07e7b17ae8a073feda2288cf38e8d865975e3a89" and "26ff177fb0f91f2afc71c6692b95726e477ec02a" have entirely different histories.

3 changed files with 35 additions and 522 deletions

View file

@ -2085,6 +2085,7 @@ func handleTestServerConnection(connection net.Conn, featureSet string, failRese
data, err := reader.ReadString('\n')
if err != nil {
fmt.Printf("unable to read from connection: %s\n", err)
return
}
if !strings.HasPrefix(data, "EHLO") && !strings.HasPrefix(data, "HELO") {
@ -2092,12 +2093,17 @@ func handleTestServerConnection(connection net.Conn, featureSet string, failRese
return
}
if err = writeLine("250-localhost.localdomain\r\n" + featureSet); err != nil {
fmt.Printf("unable to write to connection: %s\n", err)
return
}
for {
data, err = reader.ReadString('\n')
if err != nil {
if errors.Is(err, io.EOF) {
break
}
fmt.Println("Error reading data:", err)
break
}

View file

@ -5,23 +5,16 @@
package mail
import (
"context"
"errors"
"fmt"
"net"
"sync"
)
// Parts of the connection pool code is forked/took inspiration from https://github.com/fatih/pool/
// Thanks to Fatih Arslan and the project contributors for providing this great concurrency template.
// Parts of the connection pool code is forked from https://github.com/fatih/pool/
// Thanks to Fatih Arslan and the project contributors for providing this great
// concurrency template.
var (
// ErrClosed is returned when an operation is attempted on a closed connection pool.
ErrClosed = errors.New("connection pool is closed")
// ErrNilConn is returned when a nil connection is passed back to the connection pool.
ErrNilConn = errors.New("connection is nil")
// ErrPoolInvalidCap is returned when the connection pool's capacity settings are
// invalid (e.g., initial capacity is negative).
ErrPoolInvalidCap = errors.New("invalid connection pool capacity settings")
)
@ -37,56 +30,18 @@ type Pool interface {
// no longer usable.
Close()
// Size returns the current number of connections of the pool.
Size() int
// Len returns the current number of connections of the pool.
Len() int
}
// connPool implements the Pool interface
type connPool struct {
// mutex is used to synchronize access to the connection pool to ensure thread-safe operations.
// mutex is used to synchronize access to the connection pool to ensure thread-safe operations
mutex sync.RWMutex
// conns is a channel used to manage and distribute net.Conn objects within the connection pool.
// conns is a channel used to manage and distribute net.Conn objects within the connection pool
conns chan net.Conn
// dialCtxFunc represents the actual net.Conn returned by the DialContextFunc.
dialCtxFunc DialContextFunc
// dialContext is the context used for dialing new network connections within the connection pool.
dialContext context.Context
// dialNetwork specifies the network type (e.g., "tcp", "udp") used to establish connections in
// the connection pool.
dialNetwork string
// dialAddress specifies the address used to establish network connections within the connection pool.
dialAddress string
}
// PoolConn is a wrapper around net.Conn to modify the the behavior of net.Conn's Close() method.
type PoolConn struct {
net.Conn
mutex sync.RWMutex
pool *connPool
unusable bool
}
// Close puts a given pool connection back to the pool instead of closing it.
func (c *PoolConn) Close() error {
c.mutex.RLock()
defer c.mutex.RUnlock()
if c.unusable {
if c.Conn != nil {
return c.Conn.Close()
}
return nil
}
return c.pool.put(c.Conn)
}
// MarkUnusable marks the connection not usable any more, to let the pool close it instead
// of returning it to pool.
func (c *PoolConn) MarkUnusable() {
c.mutex.Lock()
c.unusable = true
c.mutex.Unlock()
// dialCtx represents the actual net.Conn returned by the DialContextFunc
dialCtx DialContextFunc
}
// NewConnPool returns a new pool based on buffered channels with an initial
@ -95,126 +50,40 @@ func (c *PoolConn) MarkUnusable() {
// fill the Pool until a new Get() is called. During a Get(), if there is no
// new connection available in the pool, a new connection will be created via
// the corresponding DialContextFunc() method.
func NewConnPool(ctx context.Context, initialCap, maxCap int, dialCtxFunc DialContextFunc,
network, address string) (Pool, error) {
func NewConnPool(initialCap, maxCap int, dialCtxFunc DialContextFunc) (Pool, error) {
if initialCap < 0 || maxCap <= 0 || initialCap > maxCap {
return nil, ErrPoolInvalidCap
}
pool := &connPool{
conns: make(chan net.Conn, maxCap),
dialCtxFunc: dialCtxFunc,
dialContext: ctx,
dialAddress: address,
dialNetwork: network,
dialCtx: dialCtxFunc,
}
// Initial connections for the pool. Pool will be closed on connection error
// create initial connections, if something goes wrong,
// just close the pool error out.
for i := 0; i < initialCap; i++ {
conn, err := dialCtxFunc(ctx, network, address)
/*
conn, err := dialCtxFunc()
if err != nil {
pool.Close()
return nil, fmt.Errorf("dialContextFunc is not able to fill the connection pool: %s", err)
return nil, fmt.Errorf("factory is not able to fill the pool: %s", err)
}
pool.conns <- conn
c.conns <- conn
*/
}
return pool, nil
}
// Get satisfies the Get() method of the Pool inteface. If there is no new
// connection available in the Pool, a new connection will be created via the
// DialContextFunc() method.
func (p *connPool) Get() (net.Conn, error) {
ctx, conns, dialCtxFunc := p.getConnsAndDialContext()
if conns == nil {
return nil, ErrClosed
}
// wrap the connections into the custom net.Conn implementation that puts
// connections back to the pool
select {
case <-ctx.Done():
return nil, ctx.Err()
case conn := <-conns:
if conn == nil {
return nil, ErrClosed
}
return p.wrapConn(conn), nil
default:
conn, err := dialCtxFunc(ctx, p.dialNetwork, p.dialAddress)
if err != nil {
return nil, err
}
return p.wrapConn(conn), nil
}
func (c *connPool) Get() (net.Conn, error) {
return nil, nil
}
// Close terminates all connections in the pool and frees associated resources. Once closed,
// the pool is no longer usable.
func (p *connPool) Close() {
p.mutex.Lock()
conns := p.conns
p.conns = nil
p.dialCtxFunc = nil
p.dialContext = nil
p.dialAddress = ""
p.dialNetwork = ""
p.mutex.Unlock()
if conns == nil {
func (c *connPool) Close() {
return
}
close(conns)
for conn := range conns {
_ = conn.Close()
}
}
// Size returns the current number of connections in the connection pool.
func (p *connPool) Size() int {
_, conns, _ := p.getConnsAndDialContext()
return len(conns)
}
// getConnsAndDialContext returns the connection channel and the DialContext function for the
// connection pool.
func (p *connPool) getConnsAndDialContext() (context.Context, chan net.Conn, DialContextFunc) {
p.mutex.RLock()
conns := p.conns
dialCtxFunc := p.dialCtxFunc
ctx := p.dialContext
p.mutex.RUnlock()
return ctx, conns, dialCtxFunc
}
// put puts a passed connection back into the pool. If the pool is full or closed,
// conn is simply closed. A nil conn will be rejected with an error.
func (p *connPool) put(conn net.Conn) error {
if conn == nil {
return ErrNilConn
}
p.mutex.RLock()
defer p.mutex.RUnlock()
if p.conns == nil {
return conn.Close()
}
select {
case p.conns <- conn:
return nil
default:
return conn.Close()
}
}
// wrapConn wraps a given net.Conn with a PoolConn, modifying the net.Conn's Close() method.
func (p *connPool) wrapConn(conn net.Conn) net.Conn {
poolconn := &PoolConn{pool: p}
poolconn.Conn = conn
return poolconn
func (c *connPool) Len() int {
return 0
}

View file

@ -1,362 +0,0 @@
// SPDX-FileCopyrightText: 2022-2024 The go-mail Authors
//
// SPDX-License-Identifier: MIT
package mail
import (
"context"
"fmt"
"net"
"sync"
"testing"
"time"
)
func TestNewConnPool(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
serverPort := TestServerPortBase + 10
featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
go func() {
if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 300)
pool, err := newConnPool(serverPort)
if err != nil {
t.Errorf("failed to create connection pool: %s", err)
}
defer pool.Close()
if pool == nil {
t.Errorf("connection pool is nil")
return
}
if pool.Size() != 5 {
t.Errorf("expected 5 connections, got %d", pool.Size())
}
conn, err := pool.Get()
if err != nil {
t.Errorf("failed to get connection: %s", err)
}
if _, err := conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil {
t.Errorf("failed to write quit command to first connection: %s", err)
}
}
func TestConnPool_Get_Type(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
serverPort := TestServerPortBase + 11
featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
go func() {
if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 300)
pool, err := newConnPool(serverPort)
if err != nil {
t.Errorf("failed to create connection pool: %s", err)
}
defer pool.Close()
conn, err := pool.Get()
if err != nil {
t.Errorf("failed to get new connection from pool: %s", err)
return
}
_, ok := conn.(*PoolConn)
if !ok {
t.Error("received connection from pool is not of type PoolConn")
return
}
if _, err := conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil {
t.Errorf("failed to write quit command to first connection: %s", err)
}
}
func TestConnPool_Get(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
serverPort := TestServerPortBase + 12
featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
go func() {
if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 300)
p, _ := newConnPool(serverPort)
defer p.Close()
conn, err := p.Get()
if err != nil {
t.Errorf("failed to get new connection from pool: %s", err)
return
}
if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil {
t.Errorf("failed to write quit command to first connection: %s", err)
}
if p.Size() != 4 {
t.Errorf("getting new connection from pool failed. Expected pool size: 4, got %d", p.Size())
}
var wg sync.WaitGroup
for i := 0; i < 4; i++ {
wg.Add(1)
go func() {
defer wg.Done()
wgconn, err := p.Get()
if err != nil {
t.Errorf("failed to get new connection from pool: %s", err)
}
if _, err = wgconn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil {
t.Errorf("failed to write quit command to first connection: %s", err)
}
}()
}
wg.Wait()
if p.Size() != 0 {
t.Errorf("Get error. Expecting 0, got %d", p.Size())
}
conn, err = p.Get()
if err != nil {
t.Errorf("failed to get new connection from pool: %s", err)
}
if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil {
t.Errorf("failed to write quit command to first connection: %s", err)
}
p.Close()
}
func TestPoolConn_Close(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
serverPort := TestServerPortBase + 13
featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
go func() {
if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 300)
netDialer := net.Dialer{}
p, err := NewConnPool(context.Background(), 0, 30, netDialer.DialContext, "tcp",
fmt.Sprintf("127.0.0.1:%d", serverPort))
if err != nil {
t.Errorf("failed to create connection pool: %s", err)
}
defer p.Close()
conns := make([]net.Conn, 30)
for i := 0; i < 30; i++ {
conn, _ := p.Get()
if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil {
t.Errorf("failed to write quit command to first connection: %s", err)
}
conns[i] = conn
}
for _, conn := range conns {
if err = conn.Close(); err != nil {
t.Errorf("failed to close connection: %s", err)
}
}
if p.Size() != 30 {
t.Errorf("failed to return all connections to pool. Expected pool size: 30, got %d", p.Size())
}
conn, err := p.Get()
if err != nil {
t.Errorf("failed to get new connection from pool: %s", err)
}
if _, err = conn.Write([]byte("EHLO test.localhost.localdomain\r\nQUIT\r\n")); err != nil {
t.Errorf("failed to write quit command to first connection: %s", err)
}
p.Close()
if err = conn.Close(); err != nil {
t.Errorf("failed to close connection: %s", err)
}
if p.Size() != 0 {
t.Errorf("closed pool shouldn't allow to put connections.")
}
}
func TestPoolConn_MarkUnusable(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
serverPort := TestServerPortBase + 14
featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
go func() {
if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 300)
pool, _ := newConnPool(serverPort)
defer pool.Close()
conn, err := pool.Get()
if err != nil {
t.Errorf("failed to get new connection from pool: %s", err)
}
if err = conn.Close(); err != nil {
t.Errorf("failed to close connection: %s", err)
}
poolSize := pool.Size()
conn, err = pool.Get()
if err != nil {
t.Errorf("failed to get new connection from pool: %s", err)
}
if err = conn.Close(); err != nil {
t.Errorf("failed to close connection: %s", err)
}
if pool.Size() != poolSize {
t.Errorf("pool size is expected to be equal to initial size")
}
conn, err = pool.Get()
if err != nil {
t.Errorf("failed to get new connection from pool: %s", err)
}
if pc, ok := conn.(*PoolConn); !ok {
t.Errorf("this should never happen")
} else {
pc.MarkUnusable()
}
if err = conn.Close(); err != nil {
t.Errorf("failed to close connection: %s", err)
}
if pool.Size() != poolSize-1 {
t.Errorf("pool size is expected to be: %d but got: %d", poolSize-1, pool.Size())
}
}
func TestConnPool_Close(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
serverPort := TestServerPortBase + 15
featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
go func() {
if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 300)
pool, err := newConnPool(serverPort)
if err != nil {
t.Errorf("failed to create connection pool: %s", err)
}
pool.Close()
castPool := pool.(*connPool)
if castPool.conns != nil {
t.Error("closing pool failed: conns channel should be nil")
}
if castPool.dialCtxFunc != nil {
t.Error("closing pool failed: dialCtxFunc should be nil")
}
if castPool.dialContext != nil {
t.Error("closing pool failed: dialContext should be nil")
}
if castPool.dialAddress != "" {
t.Error("closing pool failed: dialAddress should be empty")
}
if castPool.dialNetwork != "" {
t.Error("closing pool failed: dialNetwork should be empty")
}
conn, err := pool.Get()
if err == nil {
t.Errorf("closing pool failed: getting new connection should return an error")
}
if conn != nil {
t.Errorf("closing pool failed: getting new connection should return a nil-connection")
}
if pool.Size() != 0 {
t.Errorf("closing pool failed: pool size should be 0, but got: %d", pool.Size())
}
}
func TestConnPool_Concurrency(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
serverPort := TestServerPortBase + 16
featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
go func() {
if err := simpleSMTPServer(ctx, featureSet, true, serverPort); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 300)
pool, err := newConnPool(serverPort)
if err != nil {
t.Errorf("failed to create connection pool: %s", err)
}
defer pool.Close()
pipe := make(chan net.Conn)
getWg := sync.WaitGroup{}
closeWg := sync.WaitGroup{}
for i := 0; i < 30; i++ {
getWg.Add(1)
closeWg.Add(1)
go func() {
conn, err := pool.Get()
if err != nil {
t.Errorf("failed to get new connection from pool: %s", err)
}
pipe <- conn
getWg.Done()
}()
go func() {
conn := <-pipe
if conn == nil {
return
}
if err = conn.Close(); err != nil {
t.Errorf("failed to close connection: %s", err)
}
closeWg.Done()
}()
getWg.Wait()
closeWg.Wait()
}
}
func newConnPool(port int) (Pool, error) {
netDialer := net.Dialer{}
return NewConnPool(context.Background(), 5, 30, netDialer.DialContext, "tcp",
fmt.Sprintf("127.0.0.1:%d", port))
}