Merge pull request #386 from wneessen/bug/385_concurrency-issue-in-dialandsendwithcontext

Fix concurrency issue in DialAndSendWithContext
This commit is contained in:
Winni Neessen 2024-11-22 16:36:10 +01:00 committed by GitHub
commit ead4067f2d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 404 additions and 160 deletions

316
client.go
View file

@ -142,9 +142,6 @@ type (
// host is the hostname of the SMTP server we are connecting to.
host string
// isEncrypted indicates wether the Client connection is encrypted or not.
isEncrypted bool
// logAuthData indicates whether authentication-related data should be logged.
logAuthData bool
@ -806,7 +803,7 @@ func (c *Client) SetSSLPort(ssl bool, fallback bool) {
}
// SetDebugLog sets or overrides whether the Client is using debug logging. The debug logger will log incoming
// and outgoing communication between the Client and the server to os.Stderr.
// and outgoing communication between the client and the server to log.Logger that is defined on the Client.
//
// Note: The SMTP communication might include unencrypted authentication data, depending on whether you are using
// SMTP authentication and the type of authentication mechanism. This could pose a data protection risk. Use
@ -815,9 +812,26 @@ func (c *Client) SetSSLPort(ssl bool, fallback bool) {
// Parameters:
// - val: A boolean value indicating whether to enable (true) or disable (false) debug logging.
func (c *Client) SetDebugLog(val bool) {
c.SetDebugLogWithSMTPClient(c.smtpClient, val)
}
// SetDebugLogWithSMTPClient sets or overrides whether the provided smtp.Client is using debug logging.
// The debug logger will log incoming and outgoing communication between the client and the server to
// log.Logger that is defined on the Client.
//
// Note: The SMTP communication might include unencrypted authentication data, depending on whether you are using
// SMTP authentication and the type of authentication mechanism. This could pose a data protection risk. Use
// debug logging with caution.
//
// Parameters:
// - client: A pointer to the smtp.Client that handles the connection to the server.
// - val: A boolean value indicating whether to enable (true) or disable (false) debug logging.
func (c *Client) SetDebugLogWithSMTPClient(client *smtp.Client, val bool) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.useDebugLog = val
if c.smtpClient != nil {
c.smtpClient.SetDebugLog(val)
if client != nil {
client.SetDebugLog(val)
}
}
@ -830,9 +844,24 @@ func (c *Client) SetDebugLog(val bool) {
// Parameters:
// - logger: A logger that satisfies the log.Logger interface to be set for the Client.
func (c *Client) SetLogger(logger log.Logger) {
c.SetLoggerWithSMTPClient(c.smtpClient, logger)
}
// SetLoggerWithSMTPClient sets or overrides the custom logger currently used by the provided smtp.Client.
// The logger must satisfy the log.Logger interface and is only utilized when debug logging is enabled on
// the provided smtp.Client.
//
// By default, log.Stdlog is used if no custom logger is provided.
//
// Parameters:
// - client: A pointer to the smtp.Client that handles the connection to the server.
// - logger: A logger that satisfies the log.Logger interface to be set for the Client.
func (c *Client) SetLoggerWithSMTPClient(client *smtp.Client, logger log.Logger) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.logger = logger
if c.smtpClient != nil {
c.smtpClient.SetLogger(logger)
if client != nil {
client.SetLogger(logger)
}
}
@ -926,67 +955,91 @@ func (c *Client) SetLogAuthData(logAuth bool) {
// SMTP server.
//
// Parameters:
// - dialCtx: The context.Context used to control the connection timeout and cancellation.
// - ctxDial: The context.Context used to control the connection timeout and cancellation.
//
// Returns:
// - An error if the connection to the SMTP server fails or any subsequent command fails.
func (c *Client) DialWithContext(dialCtx context.Context) error {
c.mutex.Lock()
defer c.mutex.Unlock()
ctx, cancel := context.WithDeadline(dialCtx, time.Now().Add(c.connTimeout))
defer cancel()
if c.dialContextFunc == nil {
netDialer := net.Dialer{}
c.dialContextFunc = netDialer.DialContext
if c.useSSL {
tlsDialer := tls.Dialer{NetDialer: &netDialer, Config: c.tlsconfig}
c.isEncrypted = true
c.dialContextFunc = tlsDialer.DialContext
}
}
connection, err := c.dialContextFunc(ctx, "tcp", c.ServerAddr())
if err != nil && c.fallbackPort != 0 {
// TODO: should we somehow log or append the previous error?
connection, err = c.dialContextFunc(ctx, "tcp", c.serverFallbackAddr())
}
func (c *Client) DialWithContext(ctxDial context.Context) error {
client, err := c.DialToSMTPClientWithContext(ctxDial)
if err != nil {
return err
}
c.mutex.Lock()
c.smtpClient = client
c.mutex.Unlock()
return nil
}
// DialToSMTPClientWithContext establishes and configures a smtp.Client connection using
// the provided context.
//
// This function uses the provided context to manage the connection deadline and cancellation.
// It dials the SMTP server using the Client's configured DialContextFunc or a default dialer.
// If SSL is enabled, it uses a TLS connection. After successfully connecting, it initializes
// an smtp.Client, sends the HELO/EHLO command, and optionally performs STARTTLS and SMTP AUTH
// based on the Client's configuration. Debug and authentication logging are enabled if
// configured.
//
// Parameters:
// - ctxDial: The context used to control the connection timeout and cancellation.
//
// Returns:
// - A pointer to the initialized smtp.Client.
// - An error if the connection fails, the smtp.Client cannot be created, or any subsequent commands fail.
func (c *Client) DialToSMTPClientWithContext(ctxDial context.Context) (*smtp.Client, error) {
c.mutex.RLock()
defer c.mutex.RUnlock()
ctx, cancel := context.WithDeadline(ctxDial, time.Now().Add(c.connTimeout))
defer cancel()
isEncrypted := false
dialContextFunc := c.dialContextFunc
if c.dialContextFunc == nil {
netDialer := net.Dialer{}
dialContextFunc = netDialer.DialContext
if c.useSSL {
tlsDialer := tls.Dialer{NetDialer: &netDialer, Config: c.tlsconfig}
isEncrypted = true
dialContextFunc = tlsDialer.DialContext
}
}
connection, err := dialContextFunc(ctx, "tcp", c.ServerAddr())
if err != nil && c.fallbackPort != 0 {
// TODO: should we somehow log or append the previous error?
connection, err = dialContextFunc(ctx, "tcp", c.serverFallbackAddr())
}
if err != nil {
return nil, err
}
client, err := smtp.NewClient(connection, c.host)
if err != nil {
return err
return nil, err
}
if client == nil {
return fmt.Errorf("SMTP client is nil")
}
c.smtpClient = client
if c.logger != nil {
c.smtpClient.SetLogger(c.logger)
client.SetLogger(c.logger)
}
if c.useDebugLog {
c.smtpClient.SetDebugLog(true)
client.SetDebugLog(true)
}
if c.logAuthData {
c.smtpClient.SetLogAuthData()
client.SetLogAuthData()
}
if err = c.smtpClient.Hello(c.helo); err != nil {
return err
if err = client.Hello(c.helo); err != nil {
return nil, err
}
if err = c.tls(); err != nil {
return err
if err = c.tls(client, &isEncrypted); err != nil {
return nil, err
}
if err = c.auth(); err != nil {
return err
if err = c.auth(client, isEncrypted); err != nil {
return nil, err
}
return nil
return client, nil
}
// Close terminates the connection to the SMTP server, returning an error if the disconnection
@ -999,10 +1052,27 @@ func (c *Client) DialWithContext(dialCtx context.Context) error {
// Returns:
// - An error if the disconnection fails; otherwise, returns nil.
func (c *Client) Close() error {
if c.smtpClient == nil || !c.smtpClient.HasConnection() {
return c.CloseWithSMTPClient(c.smtpClient)
}
// CloseWithSMTPClient terminates the connection of the provided smtp.Client to the SMTP server,
// returning an error if the disconnection fails. If the connection is already closed, this
// method is a no-op and disregards any error.
//
// This function checks if the smtp.Client connection is active. If not, it simply returns
// without any action. If the connection is active, it attempts to gracefully close the
// connection using the Quit method.
//
// Parameters:
// - client: A pointer to the smtp.Client that handles the connection to the server.
//
// Returns:
// - An error if the disconnection fails; otherwise, returns nil.
func (c *Client) CloseWithSMTPClient(client *smtp.Client) error {
if client == nil || !client.HasConnection() {
return nil
}
if err := c.smtpClient.Quit(); err != nil {
if err := client.Quit(); err != nil {
return fmt.Errorf("failed to close SMTP client: %w", err)
}
@ -1016,12 +1086,29 @@ func (c *Client) Close() error {
// the command fails, an error is returned.
//
// Returns:
// - An error if the connection check fails or if sending the RSET command fails; otherwise, returns nil.
// - An error if the connection check fails or if sending the RSET command fails;
// otherwise, returns nil.
func (c *Client) Reset() error {
if err := c.checkConn(); err != nil {
return c.ResetWithSMTPClient(c.smtpClient)
}
// ResetWithSMTPClient sends an SMTP RSET command to the provided smtp.Client, to reset
// the state of the current SMTP session.
//
// This method checks the connection to the SMTP server and, if the connection is valid,
// it sends an RSET command to reset the session state. If the connection is invalid or
// the command fails, an error is returned.
//
// Parameters:
// - client: A pointer to the smtp.Client that handles the connection to the server.
//
// Returns:
// - An error if the connection check fails or if sending the RSET command fails; otherwise, returns nil.
func (c *Client) ResetWithSMTPClient(client *smtp.Client) error {
if err := c.checkConn(client); err != nil {
return err
}
if err := c.smtpClient.Reset(); err != nil {
if err := client.Reset(); err != nil {
return fmt.Errorf("failed to send RSET to SMTP client: %w", err)
}
@ -1061,24 +1148,47 @@ func (c *Client) DialAndSend(messages ...*Msg) error {
// - An error if the connection fails, if sending the messages fails, or if closing the
// connection fails; otherwise, returns nil.
func (c *Client) DialAndSendWithContext(ctx context.Context, messages ...*Msg) error {
c.sendMutex.Lock()
defer c.sendMutex.Unlock()
if err := c.DialWithContext(ctx); err != nil {
client, err := c.DialToSMTPClientWithContext(ctx)
if err != nil {
return fmt.Errorf("dial failed: %w", err)
}
defer func() {
_ = c.Close()
_ = c.CloseWithSMTPClient(client)
}()
if err := c.Send(messages...); err != nil {
if err = c.SendWithSMTPClient(client, messages...); err != nil {
return fmt.Errorf("send failed: %w", err)
}
if err := c.Close(); err != nil {
if err = c.CloseWithSMTPClient(client); err != nil {
return fmt.Errorf("failed to close connection: %w", err)
}
return nil
}
// Send attempts to send one or more Msg using the SMTP client that is assigned to the Client.
// If the Client has no active connection to the server, Send will fail with an error. For
// each of the provided Msg, it will associate a SendError with the Msg in case of a
// transmission or delivery error.
//
// This method first checks for an active connection to the SMTP server. If the connection is
// not valid, it returns a SendError. It then iterates over the provided messages, attempting
// to send each one. If an error occurs during sending, the method records the error and
// associates it with the corresponding Msg. If multiple errors are encountered, it aggregates
// them into a single SendError to be returned.
//
// Parameters:
// - client: A pointer to the smtp.Client that holds the connection to the SMTP server
// - messages: A variadic list of pointers to Msg objects to be sent.
//
// Returns:
// - An error that represents the sending result, which may include multiple SendErrors if
// any occurred; otherwise, returns nil.
func (c *Client) Send(messages ...*Msg) (returnErr error) {
c.sendMutex.Lock()
defer c.sendMutex.Unlock()
return c.SendWithSMTPClient(c.smtpClient, messages...)
}
// auth attempts to authenticate the client using SMTP AUTH mechanisms. It checks the connection,
// determines the supported authentication methods, and applies the appropriate authentication
// type. An error is returned if authentication fails.
@ -1098,16 +1208,17 @@ func (c *Client) DialAndSendWithContext(ctx context.Context, messages ...*Msg) e
// Returns:
// - An error if the connection check fails, if no supported authentication method is found,
// or if the authentication process fails.
func (c *Client) auth() error {
func (c *Client) auth(client *smtp.Client, isEnc bool) error {
var smtpAuth smtp.Auth
if c.smtpAuth == nil && c.smtpAuthType != SMTPAuthNoAuth {
hasSMTPAuth, smtpAuthType := c.smtpClient.Extension("AUTH")
hasSMTPAuth, smtpAuthType := client.Extension("AUTH")
if !hasSMTPAuth {
return fmt.Errorf("server does not support SMTP AUTH")
}
authType := c.smtpAuthType
if c.smtpAuthType == SMTPAuthAutoDiscover {
discoveredType, err := c.authTypeAutoDiscover(smtpAuthType)
discoveredType, err := c.authTypeAutoDiscover(smtpAuthType, isEnc)
if err != nil {
return err
}
@ -1119,74 +1230,74 @@ func (c *Client) auth() error {
if !strings.Contains(smtpAuthType, string(SMTPAuthPlain)) {
return ErrPlainAuthNotSupported
}
c.smtpAuth = smtp.PlainAuth("", c.user, c.pass, c.host, false)
smtpAuth = smtp.PlainAuth("", c.user, c.pass, c.host, false)
case SMTPAuthPlainNoEnc:
if !strings.Contains(smtpAuthType, string(SMTPAuthPlain)) {
return ErrPlainAuthNotSupported
}
c.smtpAuth = smtp.PlainAuth("", c.user, c.pass, c.host, true)
smtpAuth = smtp.PlainAuth("", c.user, c.pass, c.host, true)
case SMTPAuthLogin:
if !strings.Contains(smtpAuthType, string(SMTPAuthLogin)) {
return ErrLoginAuthNotSupported
}
c.smtpAuth = smtp.LoginAuth(c.user, c.pass, c.host, false)
smtpAuth = smtp.LoginAuth(c.user, c.pass, c.host, false)
case SMTPAuthLoginNoEnc:
if !strings.Contains(smtpAuthType, string(SMTPAuthLogin)) {
return ErrLoginAuthNotSupported
}
c.smtpAuth = smtp.LoginAuth(c.user, c.pass, c.host, true)
smtpAuth = smtp.LoginAuth(c.user, c.pass, c.host, true)
case SMTPAuthCramMD5:
if !strings.Contains(smtpAuthType, string(SMTPAuthCramMD5)) {
return ErrCramMD5AuthNotSupported
}
c.smtpAuth = smtp.CRAMMD5Auth(c.user, c.pass)
smtpAuth = smtp.CRAMMD5Auth(c.user, c.pass)
case SMTPAuthXOAUTH2:
if !strings.Contains(smtpAuthType, string(SMTPAuthXOAUTH2)) {
return ErrXOauth2AuthNotSupported
}
c.smtpAuth = smtp.XOAuth2Auth(c.user, c.pass)
smtpAuth = smtp.XOAuth2Auth(c.user, c.pass)
case SMTPAuthSCRAMSHA1:
if !strings.Contains(smtpAuthType, string(SMTPAuthSCRAMSHA1)) {
return ErrSCRAMSHA1AuthNotSupported
}
c.smtpAuth = smtp.ScramSHA1Auth(c.user, c.pass)
smtpAuth = smtp.ScramSHA1Auth(c.user, c.pass)
case SMTPAuthSCRAMSHA256:
if !strings.Contains(smtpAuthType, string(SMTPAuthSCRAMSHA256)) {
return ErrSCRAMSHA256AuthNotSupported
}
c.smtpAuth = smtp.ScramSHA256Auth(c.user, c.pass)
smtpAuth = smtp.ScramSHA256Auth(c.user, c.pass)
case SMTPAuthSCRAMSHA1PLUS:
if !strings.Contains(smtpAuthType, string(SMTPAuthSCRAMSHA1PLUS)) {
return ErrSCRAMSHA1PLUSAuthNotSupported
}
tlsConnState, err := c.smtpClient.GetTLSConnectionState()
tlsConnState, err := client.GetTLSConnectionState()
if err != nil {
return err
}
c.smtpAuth = smtp.ScramSHA1PlusAuth(c.user, c.pass, tlsConnState)
smtpAuth = smtp.ScramSHA1PlusAuth(c.user, c.pass, tlsConnState)
case SMTPAuthSCRAMSHA256PLUS:
if !strings.Contains(smtpAuthType, string(SMTPAuthSCRAMSHA256PLUS)) {
return ErrSCRAMSHA256PLUSAuthNotSupported
}
tlsConnState, err := c.smtpClient.GetTLSConnectionState()
tlsConnState, err := client.GetTLSConnectionState()
if err != nil {
return err
}
c.smtpAuth = smtp.ScramSHA256PlusAuth(c.user, c.pass, tlsConnState)
smtpAuth = smtp.ScramSHA256PlusAuth(c.user, c.pass, tlsConnState)
default:
return fmt.Errorf("unsupported SMTP AUTH type %q", c.smtpAuthType)
}
}
if c.smtpAuth != nil {
if err := c.smtpClient.Auth(c.smtpAuth); err != nil {
if smtpAuth != nil {
if err := client.Auth(smtpAuth); err != nil {
return fmt.Errorf("SMTP AUTH failed: %w", err)
}
}
return nil
}
func (c *Client) authTypeAutoDiscover(supported string) (SMTPAuthType, error) {
func (c *Client) authTypeAutoDiscover(supported string, isEnc bool) (SMTPAuthType, error) {
if supported == "" {
return "", ErrNoSupportedAuthDiscovered
}
@ -1194,7 +1305,7 @@ func (c *Client) authTypeAutoDiscover(supported string) (SMTPAuthType, error) {
SMTPAuthSCRAMSHA256PLUS, SMTPAuthSCRAMSHA256, SMTPAuthSCRAMSHA1PLUS, SMTPAuthSCRAMSHA1,
SMTPAuthXOAUTH2, SMTPAuthCramMD5, SMTPAuthPlain, SMTPAuthLogin,
}
if !c.isEncrypted {
if !isEnc {
preferList = []SMTPAuthType{SMTPAuthSCRAMSHA256, SMTPAuthSCRAMSHA1, SMTPAuthXOAUTH2, SMTPAuthCramMD5}
}
mechs := strings.Split(supported, " ")
@ -1231,13 +1342,13 @@ func sliceContains(slice []string, item string) bool {
//
// Returns:
// - An error if any part of the sending process fails; otherwise, returns nil.
func (c *Client) sendSingleMsg(message *Msg) error {
c.mutex.Lock()
defer c.mutex.Unlock()
escSupport, _ := c.smtpClient.Extension("ENHANCEDSTATUSCODES")
func (c *Client) sendSingleMsg(client *smtp.Client, message *Msg) error {
c.mutex.RLock()
defer c.mutex.RUnlock()
escSupport, _ := client.Extension("ENHANCEDSTATUSCODES")
if message.encoding == NoEncoding {
if ok, _ := c.smtpClient.Extension("8BITMIME"); !ok {
if ok, _ := client.Extension("8BITMIME"); !ok {
return &SendError{Reason: ErrNoUnencoded, isTemp: false, affectedMsg: message}
}
}
@ -1260,16 +1371,16 @@ func (c *Client) sendSingleMsg(message *Msg) error {
if c.requestDSN {
if c.dsnReturnType != "" {
c.smtpClient.SetDSNMailReturnOption(string(c.dsnReturnType))
client.SetDSNMailReturnOption(string(c.dsnReturnType))
}
}
if err = c.smtpClient.Mail(from); err != nil {
if err = client.Mail(from); err != nil {
retError := &SendError{
Reason: ErrSMTPMailFrom, errlist: []error{err}, isTemp: isTempError(err),
affectedMsg: message, errcode: errorCode(err),
enhancedStatusCode: enhancedStatusCode(err, escSupport),
}
if resetSendErr := c.smtpClient.Reset(); resetSendErr != nil {
if resetSendErr := client.Reset(); resetSendErr != nil {
retError.errlist = append(retError.errlist, resetSendErr)
}
return retError
@ -1279,9 +1390,9 @@ func (c *Client) sendSingleMsg(message *Msg) error {
rcptSendErr.errlist = make([]error, 0)
rcptSendErr.rcpt = make([]string, 0)
rcptNotifyOpt := strings.Join(c.dsnRcptNotifyType, ",")
c.smtpClient.SetDSNRcptNotifyOption(rcptNotifyOpt)
client.SetDSNRcptNotifyOption(rcptNotifyOpt)
for _, rcpt := range rcpts {
if err = c.smtpClient.Rcpt(rcpt); err != nil {
if err = client.Rcpt(rcpt); err != nil {
rcptSendErr.Reason = ErrSMTPRcptTo
rcptSendErr.errlist = append(rcptSendErr.errlist, err)
rcptSendErr.rcpt = append(rcptSendErr.rcpt, rcpt)
@ -1292,12 +1403,12 @@ func (c *Client) sendSingleMsg(message *Msg) error {
}
}
if hasError {
if resetSendErr := c.smtpClient.Reset(); resetSendErr != nil {
if resetSendErr := client.Reset(); resetSendErr != nil {
rcptSendErr.errlist = append(rcptSendErr.errlist, resetSendErr)
}
return rcptSendErr
}
writer, err := c.smtpClient.Data()
writer, err := client.Data()
if err != nil {
return &SendError{
Reason: ErrSMTPData, errlist: []error{err}, isTemp: isTempError(err),
@ -1322,7 +1433,7 @@ func (c *Client) sendSingleMsg(message *Msg) error {
}
message.isDelivered = true
if err = c.Reset(); err != nil {
if err = c.ResetWithSMTPClient(client); err != nil {
return &SendError{
Reason: ErrSMTPReset, errlist: []error{err}, isTemp: isTempError(err),
affectedMsg: message, errcode: errorCode(err),
@ -1344,21 +1455,24 @@ func (c *Client) sendSingleMsg(message *Msg) error {
// Returns:
// - An error if there is no active connection, if the NOOP command fails, or if extending
// the deadline fails; otherwise, returns nil.
func (c *Client) checkConn() error {
if c.smtpClient == nil {
func (c *Client) checkConn(client *smtp.Client) error {
if client == nil {
return ErrNoActiveConnection
}
if !c.smtpClient.HasConnection() {
if !client.HasConnection() {
return ErrNoActiveConnection
}
if !c.noNoop {
if err := c.smtpClient.Noop(); err != nil {
c.mutex.RLock()
noNoop := c.noNoop
c.mutex.RUnlock()
if !noNoop {
if err := client.Noop(); err != nil {
return ErrNoActiveConnection
}
}
if err := c.smtpClient.UpdateDeadline(c.connTimeout); err != nil {
if err := client.UpdateDeadline(c.connTimeout); err != nil {
return ErrDeadlineExtendFailed
}
return nil
@ -1405,10 +1519,10 @@ func (c *Client) setDefaultHelo() error {
// Returns:
// - An error if there is no active connection, if STARTTLS is required but not supported,
// or if there are issues during the TLS handshake; otherwise, returns nil.
func (c *Client) tls() error {
func (c *Client) tls(client *smtp.Client, isEnc *bool) error {
if !c.useSSL && c.tlspolicy != NoTLS {
hasStartTLS := false
extension, _ := c.smtpClient.Extension("STARTTLS")
extension, _ := client.Extension("STARTTLS")
if c.tlspolicy == TLSMandatory {
hasStartTLS = true
if !extension {
@ -1422,21 +1536,21 @@ func (c *Client) tls() error {
}
}
if hasStartTLS {
if err := c.smtpClient.StartTLS(c.tlsconfig); err != nil {
if err := client.StartTLS(c.tlsconfig); err != nil {
return err
}
}
tlsConnState, err := c.smtpClient.GetTLSConnectionState()
tlsConnState, err := client.GetTLSConnectionState()
if err != nil {
switch {
case errors.Is(err, smtp.ErrNonTLSConnection):
c.isEncrypted = false
*isEnc = false
return nil
default:
return fmt.Errorf("failed to get TLS connection state: %w", err)
}
}
c.isEncrypted = tlsConnState.HandshakeComplete
*isEnc = tlsConnState.HandshakeComplete
}
return nil
}

View file

@ -7,12 +7,16 @@
package mail
import "errors"
import (
"errors"
// Send attempts to send one or more Msg using the Client connection to the SMTP server.
// If the Client has no active connection to the server, Send will fail with an error. For each
// of the provided Msg, it will associate a SendError with the Msg in case of a transmission
// or delivery error.
"github.com/wneessen/go-mail/smtp"
)
// SendWithSMTPClient attempts to send one or more Msg using a provided smtp.Client with an
// established connection to the SMTP server. If the smtp.Client has no active connection to
// the server, SendWithSMTPClient will fail with an error. For each of the provided Msg, it
// will associate a SendError with the Msg in case of a transmission or delivery error.
//
// This method first checks for an active connection to the SMTP server. If the connection is
// not valid, it returns a SendError. It then iterates over the provided messages, attempting
@ -21,17 +25,18 @@ import "errors"
// them into a single SendError to be returned.
//
// Parameters:
// - client: A pointer to the smtp.Client that holds the connection to the SMTP server
// - messages: A variadic list of pointers to Msg objects to be sent.
//
// Returns:
// - An error that represents the sending result, which may include multiple SendErrors if
// any occurred; otherwise, returns nil.
func (c *Client) Send(messages ...*Msg) error {
func (c *Client) SendWithSMTPClient(client *smtp.Client, messages ...*Msg) error {
escSupport := false
if c.smtpClient != nil {
escSupport, _ = c.smtpClient.Extension("ENHANCEDSTATUSCODES")
if client != nil {
escSupport, _ = client.Extension("ENHANCEDSTATUSCODES")
}
if err := c.checkConn(); err != nil {
if err := c.checkConn(client); err != nil {
return &SendError{
Reason: ErrConnCheck, errlist: []error{err}, isTemp: isTempError(err),
errcode: errorCode(err), enhancedStatusCode: enhancedStatusCode(err, escSupport),
@ -39,7 +44,7 @@ func (c *Client) Send(messages ...*Msg) error {
}
var errs []*SendError
for id, message := range messages {
if sendErr := c.sendSingleMsg(message); sendErr != nil {
if sendErr := c.sendSingleMsg(client, message); sendErr != nil {
messages[id].sendError = sendErr
var msgSendErr *SendError

View file

@ -9,29 +9,34 @@ package mail
import (
"errors"
"github.com/wneessen/go-mail/smtp"
)
// Send attempts to send one or more Msg using the Client connection to the SMTP server.
// If the Client has no active connection to the server, Send will fail with an error. For each
// of the provided Msg, it will associate a SendError with the Msg in case of a transmission
// or delivery error.
// SendWithSMTPClient attempts to send one or more Msg using a provided smtp.Client with an
// established connection to the SMTP server. If the smtp.Client has no active connection to
// the server, SendWithSMTPClient will fail with an error. For each of the provided Msg, it
// will associate a SendError with the Msg in case of a transmission or delivery error.
//
// This method first checks for an active connection to the SMTP server. If the connection is
// not valid, it returns an error wrapped in a SendError. It then iterates over the provided
// messages, attempting to send each one. If an error occurs during sending, the method records
// the error and associates it with the corresponding Msg.
// not valid, it returns a SendError. It then iterates over the provided messages, attempting
// to send each one. If an error occurs during sending, the method records the error and
// associates it with the corresponding Msg. If multiple errors are encountered, it aggregates
// them into a single SendError to be returned.
//
// Parameters:
// - client: A pointer to the smtp.Client that holds the connection to the SMTP server
// - messages: A variadic list of pointers to Msg objects to be sent.
//
// Returns:
// - An error that aggregates any SendErrors encountered during the sending process; otherwise, returns nil.
func (c *Client) Send(messages ...*Msg) (returnErr error) {
// - An error that represents the sending result, which may include multiple SendErrors if
// any occurred; otherwise, returns nil.
func (c *Client) SendWithSMTPClient(client *smtp.Client, messages ...*Msg) (returnErr error) {
escSupport := false
if c.smtpClient != nil {
escSupport, _ = c.smtpClient.Extension("ENHANCEDSTATUSCODES")
if client != nil {
escSupport, _ = client.Extension("ENHANCEDSTATUSCODES")
}
if err := c.checkConn(); err != nil {
if err := c.checkConn(client); err != nil {
returnErr = &SendError{
Reason: ErrConnCheck, errlist: []error{err}, isTemp: isTempError(err),
errcode: errorCode(err), enhancedStatusCode: enhancedStatusCode(err, escSupport),
@ -45,7 +50,7 @@ func (c *Client) Send(messages ...*Msg) (returnErr error) {
}()
for id, message := range messages {
if sendErr := c.sendSingleMsg(message); sendErr != nil {
if sendErr := c.sendSingleMsg(client, message); sendErr != nil {
messages[id].sendError = sendErr
errs = append(errs, sendErr)
}

View file

@ -1772,11 +1772,8 @@ func TestClient_DialWithContext(t *testing.T) {
t.Errorf("failed to close the client: %s", err)
}
})
if client.smtpClient == nil {
t.Errorf("client with invalid HELO should still have a smtp client, got nil")
}
if !client.smtpClient.HasConnection() {
t.Errorf("client with invalid HELO should still have a smtp client connection, got nil")
if client.smtpClient != nil {
t.Error("client with invalid HELO should not have a smtp client")
}
})
t.Run("fail on base port and fallback", func(t *testing.T) {
@ -1825,11 +1822,8 @@ func TestClient_DialWithContext(t *testing.T) {
if err = client.DialWithContext(ctxDial); err == nil {
t.Fatalf("connection was supposed to fail, but didn't")
}
if client.smtpClient == nil {
t.Fatalf("client has no smtp client")
}
if !client.smtpClient.HasConnection() {
t.Errorf("client has no connection")
if client.smtpClient != nil {
t.Fatalf("client is not supposed to have a smtp client")
}
})
t.Run("connect with failing auth", func(t *testing.T) {
@ -2297,6 +2291,7 @@ func TestClient_DialAndSendWithContext(t *testing.T) {
t.Errorf("client was supposed to fail on dial")
}
})
// https://github.com/wneessen/go-mail/issues/380
t.Run("concurrent sending via DialAndSendWithContext", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@ -2336,6 +2331,44 @@ func TestClient_DialAndSendWithContext(t *testing.T) {
}
wg.Wait()
})
// https://github.com/wneessen/go-mail/issues/385
t.Run("concurrent sending via DialAndSendWithContext on receiver func", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
PortAdder.Add(1)
serverPort := int(TestServerPortBase + PortAdder.Load())
featureSet := "250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
go func() {
if err := simpleSMTPServer(ctx, t, &serverProps{
FeatureSet: featureSet,
ListenPort: serverPort,
}); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 30)
client, err := NewClient(DefaultHost, WithPort(serverPort), WithTLSPolicy(NoTLS))
if err != nil {
t.Fatalf("failed to create new client: %s", err)
}
sender := testSender{client}
ctxDial := context.Background()
wg := sync.WaitGroup{}
for i := 0; i < 5; i++ {
wg.Add(1)
msg := testMessage(t)
go func() {
defer wg.Done()
if goroutineErr := sender.Send(ctxDial, msg); goroutineErr != nil {
t.Errorf("failed to send message: %s", goroutineErr)
}
}()
}
wg.Wait()
})
}
func TestClient_auth(t *testing.T) {
@ -2574,8 +2607,8 @@ func TestClient_authTypeAutoDiscover(t *testing.T) {
}
for _, tt := range tests {
t.Run("AutoDiscover selects the strongest auth type: "+string(tt.expect), func(t *testing.T) {
client := &Client{smtpAuthType: SMTPAuthAutoDiscover, isEncrypted: tt.tls}
authType, err := client.authTypeAutoDiscover(tt.supported)
client := &Client{smtpAuthType: SMTPAuthAutoDiscover}
authType, err := client.authTypeAutoDiscover(tt.supported, tt.tls)
if err != nil && !tt.shouldFail {
t.Fatalf("failed to auto discover auth type: %s", err)
}
@ -2709,6 +2742,82 @@ func TestClient_Send(t *testing.T) {
})
}
func TestClient_DialToSMTPClientWithContext(t *testing.T) {
t.Run("establish a new client connection", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
PortAdder.Add(1)
serverPort := int(TestServerPortBase + PortAdder.Load())
featureSet := "250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
go func() {
if err := simpleSMTPServer(ctx, t, &serverProps{
FeatureSet: featureSet,
ListenPort: serverPort,
}); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 30)
ctxDial, cancelDial := context.WithTimeout(ctx, time.Millisecond*500)
t.Cleanup(cancelDial)
client, err := NewClient(DefaultHost, WithPort(serverPort), WithTLSPolicy(NoTLS))
if err != nil {
t.Fatalf("failed to create new client: %s", err)
}
smtpClient, err := client.DialToSMTPClientWithContext(ctxDial)
if err != nil {
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
t.Skip("failed to connect to the test server due to timeout")
}
t.Fatalf("failed to connect to test server: %s", err)
}
t.Cleanup(func() {
if err := client.CloseWithSMTPClient(smtpClient); err != nil {
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
t.Skip("failed to close the test server connection due to timeout")
}
t.Errorf("failed to close client: %s", err)
}
})
if smtpClient == nil {
t.Fatal("expected SMTP client, got nil")
}
if !smtpClient.HasConnection() {
t.Fatal("expected connection on smtp client")
}
if ok, _ := smtpClient.Extension("DSN"); !ok {
t.Error("expected DSN extension but it was not found")
}
})
t.Run("dial to SMTP server fails on first client writeFile", func(t *testing.T) {
var fake faker
fake.ReadWriter = struct {
io.Reader
io.Writer
}{
failReadWriteSeekCloser{},
failReadWriteSeekCloser{},
}
ctxDial, cancelDial := context.WithTimeout(context.Background(), time.Millisecond*500)
t.Cleanup(cancelDial)
client, err := NewClient(DefaultHost, WithDialContextFunc(getFakeDialFunc(fake)))
if err != nil {
t.Fatalf("failed to create new client: %s", err)
}
_, err = client.DialToSMTPClientWithContext(ctxDial)
if err == nil {
t.Fatal("expected connection to fake to fail")
}
})
}
func TestClient_sendSingleMsg(t *testing.T) {
t.Run("connect and send email", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
@ -2748,7 +2857,7 @@ func TestClient_sendSingleMsg(t *testing.T) {
t.Errorf("failed to close client: %s", err)
}
})
if err = client.sendSingleMsg(message); err != nil {
if err = client.sendSingleMsg(client.smtpClient, message); err != nil {
t.Errorf("failed to send message: %s", err)
}
})
@ -2791,7 +2900,7 @@ func TestClient_sendSingleMsg(t *testing.T) {
t.Errorf("failed to close client: %s", err)
}
})
if err = client.sendSingleMsg(message); err == nil {
if err = client.sendSingleMsg(client.smtpClient, message); err == nil {
t.Errorf("client should have failed to send message")
}
})
@ -2836,7 +2945,7 @@ func TestClient_sendSingleMsg(t *testing.T) {
t.Errorf("failed to close client: %s", err)
}
})
if err = client.sendSingleMsg(message); err == nil {
if err = client.sendSingleMsg(client.smtpClient, message); err == nil {
t.Errorf("client should have failed to send message")
}
var sendErr *SendError
@ -2886,7 +2995,7 @@ func TestClient_sendSingleMsg(t *testing.T) {
t.Errorf("failed to close client: %s", err)
}
})
if err = client.sendSingleMsg(message); err == nil {
if err = client.sendSingleMsg(client.smtpClient, message); err == nil {
t.Errorf("client should have failed to send message")
}
var sendErr *SendError
@ -2936,7 +3045,7 @@ func TestClient_sendSingleMsg(t *testing.T) {
t.Errorf("failed to close client: %s", err)
}
})
if err = client.sendSingleMsg(message); err == nil {
if err = client.sendSingleMsg(client.smtpClient, message); err == nil {
t.Errorf("client should have failed to send message")
}
var sendErr *SendError
@ -2986,7 +3095,7 @@ func TestClient_sendSingleMsg(t *testing.T) {
t.Errorf("failed to close client: %s", err)
}
})
if err = client.sendSingleMsg(message); err != nil {
if err = client.sendSingleMsg(client.smtpClient, message); err != nil {
t.Errorf("failed to send message: %s", err)
}
})
@ -3029,7 +3138,7 @@ func TestClient_sendSingleMsg(t *testing.T) {
t.Errorf("failed to close client: %s", err)
}
})
if err = client.sendSingleMsg(message); err == nil {
if err = client.sendSingleMsg(client.smtpClient, message); err == nil {
t.Errorf("client should have failed to send message")
}
var sendErr *SendError
@ -3080,7 +3189,7 @@ func TestClient_sendSingleMsg(t *testing.T) {
t.Errorf("failed to close client: %s", err)
}
})
if err = client.sendSingleMsg(message); err == nil {
if err = client.sendSingleMsg(client.smtpClient, message); err == nil {
t.Errorf("client should have failed to send message")
}
var sendErr *SendError
@ -3131,7 +3240,7 @@ func TestClient_sendSingleMsg(t *testing.T) {
t.Errorf("failed to close client: %s", err)
}
})
if err = client.sendSingleMsg(message); err == nil {
if err = client.sendSingleMsg(client.smtpClient, message); err == nil {
t.Errorf("client should have failed to send message")
}
var sendErr *SendError
@ -3181,7 +3290,7 @@ func TestClient_sendSingleMsg(t *testing.T) {
t.Errorf("failed to close client: %s", err)
}
})
if err = client.sendSingleMsg(message); err == nil {
if err = client.sendSingleMsg(client.smtpClient, message); err == nil {
t.Errorf("client should have failed to send message")
}
var sendErr *SendError
@ -3231,7 +3340,7 @@ func TestClient_sendSingleMsg(t *testing.T) {
t.Errorf("failed to close client: %s", err)
}
})
if err = client.sendSingleMsg(message); err == nil {
if err = client.sendSingleMsg(client.smtpClient, message); err == nil {
t.Errorf("client should have failed to send message")
}
var sendErr *SendError
@ -3281,7 +3390,7 @@ func TestClient_sendSingleMsg(t *testing.T) {
t.Errorf("failed to close client: %s", err)
}
})
if err = client.sendSingleMsg(message); err == nil {
if err = client.sendSingleMsg(client.smtpClient, message); err == nil {
t.Error("expected mail delivery to fail")
}
var sendErr *SendError
@ -3334,7 +3443,7 @@ func TestClient_checkConn(t *testing.T) {
t.Errorf("failed to close client: %s", err)
}
})
if err = client.checkConn(); err != nil {
if err = client.checkConn(client.smtpClient); err != nil {
t.Errorf("failed to check connection: %s", err)
}
})
@ -3375,7 +3484,7 @@ func TestClient_checkConn(t *testing.T) {
t.Errorf("failed to close client: %s", err)
}
})
if err = client.checkConn(); err == nil {
if err = client.checkConn(client.smtpClient); err == nil {
t.Errorf("client should have failed on connection check")
}
if !errors.Is(err, ErrNoActiveConnection) {
@ -3387,7 +3496,7 @@ func TestClient_checkConn(t *testing.T) {
if err != nil {
t.Fatalf("failed to create new client: %s", err)
}
if err = client.checkConn(); err == nil {
if err = client.checkConn(client.smtpClient); err == nil {
t.Errorf("client should have failed on connection check")
}
if !errors.Is(err, ErrNoActiveConnection) {
@ -3611,24 +3720,20 @@ func TestClient_XOAuth2OnFaker(t *testing.T) {
}
if err = c.DialWithContext(context.Background()); err == nil {
t.Fatal("expected dial error got nil")
} else {
if !errors.Is(err, ErrXOauth2AuthNotSupported) {
t.Fatalf("expected %v; got %v", ErrXOauth2AuthNotSupported, err)
}
}
if !errors.Is(err, ErrXOauth2AuthNotSupported) {
t.Fatalf("expected %v; got %v", ErrXOauth2AuthNotSupported, err)
}
if err = c.Close(); err != nil {
t.Fatalf("disconnect from test server failed: %v", err)
}
client := strings.Split(wrote.String(), "\r\n")
if len(client) != 3 {
t.Fatalf("unexpected number of client requests got %d; want 3", len(client))
if len(client) != 2 {
t.Fatalf("unexpected number of client requests got %d; want 2", len(client))
}
if !strings.HasPrefix(client[0], "EHLO") {
t.Fatalf("expected EHLO, got %q", client[0])
}
if client[1] != "QUIT" {
t.Fatalf("expected QUIT, got %q", client[3])
}
})
}
@ -3652,6 +3757,17 @@ func (f faker) SetDeadline(time.Time) error { return nil }
func (f faker) SetReadDeadline(time.Time) error { return nil }
func (f faker) SetWriteDeadline(time.Time) error { return nil }
type testSender struct {
client *Client
}
func (t *testSender) Send(ctx context.Context, m *Msg) error {
if err := t.client.DialAndSendWithContext(ctx, m); err != nil {
return fmt.Errorf("failed to dial and send mail: %w", err)
}
return nil
}
// parseJSONLog parses a JSON encoded log from the provided buffer and returns a slice of logLine structs.
// In case of a decode error, it reports the error to the testing framework.
func parseJSONLog(t *testing.T, buf *bytes.Buffer) logData {

View file

@ -601,12 +601,16 @@ func (c *Client) SetLogAuthData() {
// SetDSNMailReturnOption sets the DSN mail return option for the Mail method
func (c *Client) SetDSNMailReturnOption(d string) {
c.mutex.Lock()
c.dsnmrtype = d
c.mutex.Unlock()
}
// SetDSNRcptNotifyOption sets the DSN recipient notify option for the Mail method
func (c *Client) SetDSNRcptNotifyOption(d string) {
c.mutex.Lock()
c.dsnrntype = d
c.mutex.Unlock()
}
// HasConnection checks if the client has an active connection.