mirror of
https://github.com/wneessen/go-mail.git
synced 2024-11-14 18:02:55 +01:00
Lock client connections and update deadline handling
Add mutex locking for client connections to ensure thread safety. Introduce `HasConnection` method to check active connections and `UpdateDeadline` method to handle timeout updates. Refactor connection handling in `checkConn` and `tls` methods accordingly.
This commit is contained in:
parent
8683917c3d
commit
3871b2be44
3 changed files with 27 additions and 7 deletions
18
client.go
18
client.go
|
@ -12,6 +12,7 @@ import (
|
|||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/wneessen/go-mail/log"
|
||||
|
@ -87,6 +88,7 @@ type DialContextFunc func(ctx context.Context, network, address string) (net.Con
|
|||
|
||||
// Client is the SMTP client struct
|
||||
type Client struct {
|
||||
mutex sync.RWMutex
|
||||
// connection is the net.Conn that the smtp.Client is based on
|
||||
connection net.Conn
|
||||
|
||||
|
@ -589,6 +591,9 @@ func (c *Client) setDefaultHelo() error {
|
|||
|
||||
// DialWithContext establishes a connection to the SMTP server with a given context.Context
|
||||
func (c *Client) DialWithContext(dialCtx context.Context) error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
ctx, cancel := context.WithDeadline(dialCtx, time.Now().Add(c.connTimeout))
|
||||
defer cancel()
|
||||
|
||||
|
@ -602,17 +607,16 @@ func (c *Client) DialWithContext(dialCtx context.Context) error {
|
|||
c.dialContextFunc = tlsDialer.DialContext
|
||||
}
|
||||
}
|
||||
var err error
|
||||
c.connection, err = c.dialContextFunc(ctx, "tcp", c.ServerAddr())
|
||||
connection, err := c.dialContextFunc(ctx, "tcp", c.ServerAddr())
|
||||
if err != nil && c.fallbackPort != 0 {
|
||||
// TODO: should we somehow log or append the previous error?
|
||||
c.connection, err = c.dialContextFunc(ctx, "tcp", c.serverFallbackAddr())
|
||||
connection, err = c.dialContextFunc(ctx, "tcp", c.serverFallbackAddr())
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client, err := smtp.NewClient(c.connection, c.host)
|
||||
client, err := smtp.NewClient(connection, c.host)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -691,7 +695,7 @@ func (c *Client) DialAndSendWithContext(ctx context.Context, messages ...*Msg) e
|
|||
// checkConn makes sure that a required server connection is available and extends the
|
||||
// connection deadline
|
||||
func (c *Client) checkConn() error {
|
||||
if c.connection == nil {
|
||||
if !c.smtpClient.HasConnection() {
|
||||
return ErrNoActiveConnection
|
||||
}
|
||||
|
||||
|
@ -701,7 +705,7 @@ func (c *Client) checkConn() error {
|
|||
}
|
||||
}
|
||||
|
||||
if err := c.connection.SetDeadline(time.Now().Add(c.connTimeout)); err != nil {
|
||||
if err := c.smtpClient.UpdateDeadline(c.connTimeout); err != nil {
|
||||
return ErrDeadlineExtendFailed
|
||||
}
|
||||
return nil
|
||||
|
@ -715,7 +719,7 @@ func (c *Client) serverFallbackAddr() string {
|
|||
|
||||
// tls tries to make sure that the STARTTLS requirements are satisfied
|
||||
func (c *Client) tls() error {
|
||||
if c.connection == nil {
|
||||
if !c.smtpClient.HasConnection() {
|
||||
return ErrNoActiveConnection
|
||||
}
|
||||
if !c.useSSL && c.tlspolicy != NoTLS {
|
||||
|
|
|
@ -13,6 +13,8 @@ import (
|
|||
|
||||
// Send sends out the mail message
|
||||
func (c *Client) Send(messages ...*Msg) (returnErr error) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
if err := c.checkConn(); err != nil {
|
||||
returnErr = &SendError{Reason: ErrConnCheck, errlist: []error{err}, isTemp: isTempError(err)}
|
||||
return
|
||||
|
|
14
smtp/smtp.go
14
smtp/smtp.go
|
@ -30,6 +30,7 @@ import (
|
|||
"net/textproto"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/wneessen/go-mail/log"
|
||||
)
|
||||
|
@ -472,6 +473,19 @@ func (c *Client) SetDSNRcptNotifyOption(d string) {
|
|||
c.dsnrntype = d
|
||||
}
|
||||
|
||||
// HasConnection checks if the client has an active connection.
|
||||
// Returns true if the `conn` field is not nil, indicating an active connection.
|
||||
func (c *Client) HasConnection() bool {
|
||||
return c.conn != nil
|
||||
}
|
||||
|
||||
func (c *Client) UpdateDeadline(timeout time.Duration) error {
|
||||
if err := c.conn.SetDeadline(time.Now().Add(timeout)); err != nil {
|
||||
return fmt.Errorf("smtp: failed to update deadline: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// debugLog checks if the debug flag is set and if so logs the provided message to
|
||||
// the log.Logger interface
|
||||
func (c *Client) debugLog(d log.Direction, f string, a ...interface{}) {
|
||||
|
|
Loading…
Reference in a new issue