Refactor SMTP client functions to improve modularity

Refactor `DialWithContext` to delegate client creation to `SMTPClientFromDialWithContext`. Add new methods `SendWithSMTPClient`, `CloseWithSMTPClient`, and `ResetWithSMTPClient` to handle specific actions on `smtp.Client`. Simplify `auth`, `sendSingleMsg`, and `tls` methods by passing client as parameters.
This commit is contained in:
Winni Neessen 2024-11-22 14:59:19 +01:00
parent f9e869061e
commit 55884786be
Signed by: wneessen
GPG key ID: 385AC9889632126E
2 changed files with 180 additions and 95 deletions

261
client.go
View file

@ -142,9 +142,6 @@ type (
// host is the hostname of the SMTP server we are connecting to. // host is the hostname of the SMTP server we are connecting to.
host string host string
// isEncrypted indicates wether the Client connection is encrypted or not.
isEncrypted bool
// logAuthData indicates whether authentication-related data should be logged. // logAuthData indicates whether authentication-related data should be logged.
logAuthData bool logAuthData bool
@ -931,62 +928,131 @@ func (c *Client) SetLogAuthData(logAuth bool) {
// Returns: // Returns:
// - An error if the connection to the SMTP server fails or any subsequent command fails. // - An error if the connection to the SMTP server fails or any subsequent command fails.
func (c *Client) DialWithContext(dialCtx context.Context) error { func (c *Client) DialWithContext(dialCtx context.Context) error {
c.mutex.Lock() client, err := c.SMTPClientFromDialWithContext(dialCtx)
defer c.mutex.Unlock()
ctx, cancel := context.WithDeadline(dialCtx, time.Now().Add(c.connTimeout))
defer cancel()
if c.dialContextFunc == nil {
netDialer := net.Dialer{}
c.dialContextFunc = netDialer.DialContext
if c.useSSL {
tlsDialer := tls.Dialer{NetDialer: &netDialer, Config: c.tlsconfig}
c.isEncrypted = true
c.dialContextFunc = tlsDialer.DialContext
}
}
connection, err := c.dialContextFunc(ctx, "tcp", c.ServerAddr())
if err != nil && c.fallbackPort != 0 {
// TODO: should we somehow log or append the previous error?
connection, err = c.dialContextFunc(ctx, "tcp", c.serverFallbackAddr())
}
if err != nil { if err != nil {
return err return err
} }
c.mutex.Lock()
c.smtpClient = client
c.mutex.Unlock()
/*
ctx, cancel := context.WithDeadline(dialCtx, time.Now().Add(c.connTimeout))
defer cancel()
isEncrypted := false
if c.dialContextFunc == nil {
netDialer := net.Dialer{}
c.dialContextFunc = netDialer.DialContext
if c.useSSL {
tlsDialer := tls.Dialer{NetDialer: &netDialer, Config: c.tlsconfig}
isEncrypted = true
c.dialContextFunc = tlsDialer.DialContext
}
}
connection, err := c.dialContextFunc(ctx, "tcp", c.ServerAddr())
if err != nil && c.fallbackPort != 0 {
// TODO: should we somehow log or append the previous error?
connection, err = c.dialContextFunc(ctx, "tcp", c.serverFallbackAddr())
}
if err != nil {
return err
}
client, err := smtp.NewClient(connection, c.host)
if err != nil {
return err
}
if client == nil {
return fmt.Errorf("SMTP client is nil")
}
c.smtpClient = client
if c.logger != nil {
c.smtpClient.SetLogger(c.logger)
}
if c.useDebugLog {
c.smtpClient.SetDebugLog(true)
}
if c.logAuthData {
c.smtpClient.SetLogAuthData()
}
if err = c.smtpClient.Hello(c.helo); err != nil {
return err
}
if err = c.tls(c.smtpClient, &isEncrypted); err != nil {
return err
}
if err = c.auth(c.smtpClient, isEncrypted); err != nil {
return err
}
*/
return nil
}
// SMTPClientFromDialWithContext is similar to DialWithContext but instead of storing the smtp.Client
// on the Client it will return the smtp.Client instead.
func (c *Client) SMTPClientFromDialWithContext(ctxDial context.Context) (*smtp.Client, error) {
c.mutex.RLock()
defer c.mutex.RUnlock()
ctx, cancel := context.WithDeadline(ctxDial, time.Now().Add(c.connTimeout))
defer cancel()
isEncrypted := false
dialContextFunc := c.dialContextFunc
if c.dialContextFunc == nil {
netDialer := net.Dialer{}
dialContextFunc = netDialer.DialContext
if c.useSSL {
tlsDialer := tls.Dialer{NetDialer: &netDialer, Config: c.tlsconfig}
isEncrypted = true
dialContextFunc = tlsDialer.DialContext
}
}
connection, err := dialContextFunc(ctx, "tcp", c.ServerAddr())
if err != nil && c.fallbackPort != 0 {
// TODO: should we somehow log or append the previous error?
connection, err = dialContextFunc(ctx, "tcp", c.serverFallbackAddr())
}
if err != nil {
return nil, err
}
client, err := smtp.NewClient(connection, c.host) client, err := smtp.NewClient(connection, c.host)
if err != nil { if err != nil {
return err return nil, err
} }
if client == nil { if client == nil {
return fmt.Errorf("SMTP client is nil") return nil, fmt.Errorf("SMTP client is nil")
} }
c.smtpClient = client
if c.logger != nil { if c.logger != nil {
c.smtpClient.SetLogger(c.logger) client.SetLogger(c.logger)
} }
if c.useDebugLog { if c.useDebugLog {
c.smtpClient.SetDebugLog(true) client.SetDebugLog(true)
} }
if c.logAuthData { if c.logAuthData {
c.smtpClient.SetLogAuthData() client.SetLogAuthData()
} }
if err = c.smtpClient.Hello(c.helo); err != nil { if err = client.Hello(c.helo); err != nil {
return err return nil, err
} }
if err = c.tls(); err != nil { if err = c.tls(client, &isEncrypted); err != nil {
return err return nil, err
} }
if err = c.auth(); err != nil { if err = c.auth(client, isEncrypted); err != nil {
return err return nil, err
} }
return nil return client, nil
} }
// Close terminates the connection to the SMTP server, returning an error if the disconnection // Close terminates the connection to the SMTP server, returning an error if the disconnection
@ -999,10 +1065,14 @@ func (c *Client) DialWithContext(dialCtx context.Context) error {
// Returns: // Returns:
// - An error if the disconnection fails; otherwise, returns nil. // - An error if the disconnection fails; otherwise, returns nil.
func (c *Client) Close() error { func (c *Client) Close() error {
if c.smtpClient == nil || !c.smtpClient.HasConnection() { return c.CloseWithSMTPClient(c.smtpClient)
}
func (c *Client) CloseWithSMTPClient(client *smtp.Client) error {
if client == nil || !client.HasConnection() {
return nil return nil
} }
if err := c.smtpClient.Quit(); err != nil { if err := client.Quit(); err != nil {
return fmt.Errorf("failed to close SMTP client: %w", err) return fmt.Errorf("failed to close SMTP client: %w", err)
} }
@ -1018,10 +1088,14 @@ func (c *Client) Close() error {
// Returns: // Returns:
// - An error if the connection check fails or if sending the RSET command fails; otherwise, returns nil. // - An error if the connection check fails or if sending the RSET command fails; otherwise, returns nil.
func (c *Client) Reset() error { func (c *Client) Reset() error {
if err := c.checkConn(); err != nil { return c.ResetWithSMTPClient(c.smtpClient)
}
func (c *Client) ResetWithSMTPClient(client *smtp.Client) error {
if err := c.checkConn(client); err != nil {
return err return err
} }
if err := c.smtpClient.Reset(); err != nil { if err := client.Reset(); err != nil {
return fmt.Errorf("failed to send RSET to SMTP client: %w", err) return fmt.Errorf("failed to send RSET to SMTP client: %w", err)
} }
@ -1061,19 +1135,20 @@ func (c *Client) DialAndSend(messages ...*Msg) error {
// - An error if the connection fails, if sending the messages fails, or if closing the // - An error if the connection fails, if sending the messages fails, or if closing the
// connection fails; otherwise, returns nil. // connection fails; otherwise, returns nil.
func (c *Client) DialAndSendWithContext(ctx context.Context, messages ...*Msg) error { func (c *Client) DialAndSendWithContext(ctx context.Context, messages ...*Msg) error {
c.sendMutex.Lock() //c.sendMutex.Lock()
defer c.sendMutex.Unlock() //defer c.sendMutex.Unlock()
if err := c.DialWithContext(ctx); err != nil { client, err := c.SMTPClientFromDialWithContext(ctx)
if err != nil {
return fmt.Errorf("dial failed: %w", err) return fmt.Errorf("dial failed: %w", err)
} }
defer func() { defer func() {
_ = c.Close() _ = c.CloseWithSMTPClient(client)
}() }()
if err := c.Send(messages...); err != nil { if err := c.SendWithSMTPClient(client, messages...); err != nil {
return fmt.Errorf("send failed: %w", err) return fmt.Errorf("send failed: %w", err)
} }
if err := c.Close(); err != nil { if err := c.CloseWithSMTPClient(client); err != nil {
return fmt.Errorf("failed to close connection: %w", err) return fmt.Errorf("failed to close connection: %w", err)
} }
return nil return nil
@ -1098,16 +1173,17 @@ func (c *Client) DialAndSendWithContext(ctx context.Context, messages ...*Msg) e
// Returns: // Returns:
// - An error if the connection check fails, if no supported authentication method is found, // - An error if the connection check fails, if no supported authentication method is found,
// or if the authentication process fails. // or if the authentication process fails.
func (c *Client) auth() error { func (c *Client) auth(client *smtp.Client, isEnc bool) error {
var smtpAuth smtp.Auth
if c.smtpAuth == nil && c.smtpAuthType != SMTPAuthNoAuth { if c.smtpAuth == nil && c.smtpAuthType != SMTPAuthNoAuth {
hasSMTPAuth, smtpAuthType := c.smtpClient.Extension("AUTH") hasSMTPAuth, smtpAuthType := client.Extension("AUTH")
if !hasSMTPAuth { if !hasSMTPAuth {
return fmt.Errorf("server does not support SMTP AUTH") return fmt.Errorf("server does not support SMTP AUTH")
} }
authType := c.smtpAuthType authType := c.smtpAuthType
if c.smtpAuthType == SMTPAuthAutoDiscover { if c.smtpAuthType == SMTPAuthAutoDiscover {
discoveredType, err := c.authTypeAutoDiscover(smtpAuthType) discoveredType, err := c.authTypeAutoDiscover(smtpAuthType, isEnc)
if err != nil { if err != nil {
return err return err
} }
@ -1119,74 +1195,74 @@ func (c *Client) auth() error {
if !strings.Contains(smtpAuthType, string(SMTPAuthPlain)) { if !strings.Contains(smtpAuthType, string(SMTPAuthPlain)) {
return ErrPlainAuthNotSupported return ErrPlainAuthNotSupported
} }
c.smtpAuth = smtp.PlainAuth("", c.user, c.pass, c.host, false) smtpAuth = smtp.PlainAuth("", c.user, c.pass, c.host, false)
case SMTPAuthPlainNoEnc: case SMTPAuthPlainNoEnc:
if !strings.Contains(smtpAuthType, string(SMTPAuthPlain)) { if !strings.Contains(smtpAuthType, string(SMTPAuthPlain)) {
return ErrPlainAuthNotSupported return ErrPlainAuthNotSupported
} }
c.smtpAuth = smtp.PlainAuth("", c.user, c.pass, c.host, true) smtpAuth = smtp.PlainAuth("", c.user, c.pass, c.host, true)
case SMTPAuthLogin: case SMTPAuthLogin:
if !strings.Contains(smtpAuthType, string(SMTPAuthLogin)) { if !strings.Contains(smtpAuthType, string(SMTPAuthLogin)) {
return ErrLoginAuthNotSupported return ErrLoginAuthNotSupported
} }
c.smtpAuth = smtp.LoginAuth(c.user, c.pass, c.host, false) smtpAuth = smtp.LoginAuth(c.user, c.pass, c.host, false)
case SMTPAuthLoginNoEnc: case SMTPAuthLoginNoEnc:
if !strings.Contains(smtpAuthType, string(SMTPAuthLogin)) { if !strings.Contains(smtpAuthType, string(SMTPAuthLogin)) {
return ErrLoginAuthNotSupported return ErrLoginAuthNotSupported
} }
c.smtpAuth = smtp.LoginAuth(c.user, c.pass, c.host, true) smtpAuth = smtp.LoginAuth(c.user, c.pass, c.host, true)
case SMTPAuthCramMD5: case SMTPAuthCramMD5:
if !strings.Contains(smtpAuthType, string(SMTPAuthCramMD5)) { if !strings.Contains(smtpAuthType, string(SMTPAuthCramMD5)) {
return ErrCramMD5AuthNotSupported return ErrCramMD5AuthNotSupported
} }
c.smtpAuth = smtp.CRAMMD5Auth(c.user, c.pass) smtpAuth = smtp.CRAMMD5Auth(c.user, c.pass)
case SMTPAuthXOAUTH2: case SMTPAuthXOAUTH2:
if !strings.Contains(smtpAuthType, string(SMTPAuthXOAUTH2)) { if !strings.Contains(smtpAuthType, string(SMTPAuthXOAUTH2)) {
return ErrXOauth2AuthNotSupported return ErrXOauth2AuthNotSupported
} }
c.smtpAuth = smtp.XOAuth2Auth(c.user, c.pass) smtpAuth = smtp.XOAuth2Auth(c.user, c.pass)
case SMTPAuthSCRAMSHA1: case SMTPAuthSCRAMSHA1:
if !strings.Contains(smtpAuthType, string(SMTPAuthSCRAMSHA1)) { if !strings.Contains(smtpAuthType, string(SMTPAuthSCRAMSHA1)) {
return ErrSCRAMSHA1AuthNotSupported return ErrSCRAMSHA1AuthNotSupported
} }
c.smtpAuth = smtp.ScramSHA1Auth(c.user, c.pass) smtpAuth = smtp.ScramSHA1Auth(c.user, c.pass)
case SMTPAuthSCRAMSHA256: case SMTPAuthSCRAMSHA256:
if !strings.Contains(smtpAuthType, string(SMTPAuthSCRAMSHA256)) { if !strings.Contains(smtpAuthType, string(SMTPAuthSCRAMSHA256)) {
return ErrSCRAMSHA256AuthNotSupported return ErrSCRAMSHA256AuthNotSupported
} }
c.smtpAuth = smtp.ScramSHA256Auth(c.user, c.pass) smtpAuth = smtp.ScramSHA256Auth(c.user, c.pass)
case SMTPAuthSCRAMSHA1PLUS: case SMTPAuthSCRAMSHA1PLUS:
if !strings.Contains(smtpAuthType, string(SMTPAuthSCRAMSHA1PLUS)) { if !strings.Contains(smtpAuthType, string(SMTPAuthSCRAMSHA1PLUS)) {
return ErrSCRAMSHA1PLUSAuthNotSupported return ErrSCRAMSHA1PLUSAuthNotSupported
} }
tlsConnState, err := c.smtpClient.GetTLSConnectionState() tlsConnState, err := client.GetTLSConnectionState()
if err != nil { if err != nil {
return err return err
} }
c.smtpAuth = smtp.ScramSHA1PlusAuth(c.user, c.pass, tlsConnState) smtpAuth = smtp.ScramSHA1PlusAuth(c.user, c.pass, tlsConnState)
case SMTPAuthSCRAMSHA256PLUS: case SMTPAuthSCRAMSHA256PLUS:
if !strings.Contains(smtpAuthType, string(SMTPAuthSCRAMSHA256PLUS)) { if !strings.Contains(smtpAuthType, string(SMTPAuthSCRAMSHA256PLUS)) {
return ErrSCRAMSHA256PLUSAuthNotSupported return ErrSCRAMSHA256PLUSAuthNotSupported
} }
tlsConnState, err := c.smtpClient.GetTLSConnectionState() tlsConnState, err := client.GetTLSConnectionState()
if err != nil { if err != nil {
return err return err
} }
c.smtpAuth = smtp.ScramSHA256PlusAuth(c.user, c.pass, tlsConnState) smtpAuth = smtp.ScramSHA256PlusAuth(c.user, c.pass, tlsConnState)
default: default:
return fmt.Errorf("unsupported SMTP AUTH type %q", c.smtpAuthType) return fmt.Errorf("unsupported SMTP AUTH type %q", c.smtpAuthType)
} }
} }
if c.smtpAuth != nil { if smtpAuth != nil {
if err := c.smtpClient.Auth(c.smtpAuth); err != nil { if err := client.Auth(smtpAuth); err != nil {
return fmt.Errorf("SMTP AUTH failed: %w", err) return fmt.Errorf("SMTP AUTH failed: %w", err)
} }
} }
return nil return nil
} }
func (c *Client) authTypeAutoDiscover(supported string) (SMTPAuthType, error) { func (c *Client) authTypeAutoDiscover(supported string, isEnc bool) (SMTPAuthType, error) {
if supported == "" { if supported == "" {
return "", ErrNoSupportedAuthDiscovered return "", ErrNoSupportedAuthDiscovered
} }
@ -1194,7 +1270,7 @@ func (c *Client) authTypeAutoDiscover(supported string) (SMTPAuthType, error) {
SMTPAuthSCRAMSHA256PLUS, SMTPAuthSCRAMSHA256, SMTPAuthSCRAMSHA1PLUS, SMTPAuthSCRAMSHA1, SMTPAuthSCRAMSHA256PLUS, SMTPAuthSCRAMSHA256, SMTPAuthSCRAMSHA1PLUS, SMTPAuthSCRAMSHA1,
SMTPAuthXOAUTH2, SMTPAuthCramMD5, SMTPAuthPlain, SMTPAuthLogin, SMTPAuthXOAUTH2, SMTPAuthCramMD5, SMTPAuthPlain, SMTPAuthLogin,
} }
if !c.isEncrypted { if !isEnc {
preferList = []SMTPAuthType{SMTPAuthSCRAMSHA256, SMTPAuthSCRAMSHA1, SMTPAuthXOAUTH2, SMTPAuthCramMD5} preferList = []SMTPAuthType{SMTPAuthSCRAMSHA256, SMTPAuthSCRAMSHA1, SMTPAuthXOAUTH2, SMTPAuthCramMD5}
} }
mechs := strings.Split(supported, " ") mechs := strings.Split(supported, " ")
@ -1231,13 +1307,13 @@ func sliceContains(slice []string, item string) bool {
// //
// Returns: // Returns:
// - An error if any part of the sending process fails; otherwise, returns nil. // - An error if any part of the sending process fails; otherwise, returns nil.
func (c *Client) sendSingleMsg(message *Msg) error { func (c *Client) sendSingleMsg(client *smtp.Client, message *Msg) error {
c.mutex.Lock() c.mutex.RLock()
defer c.mutex.Unlock() defer c.mutex.RUnlock()
escSupport, _ := c.smtpClient.Extension("ENHANCEDSTATUSCODES") escSupport, _ := client.Extension("ENHANCEDSTATUSCODES")
if message.encoding == NoEncoding { if message.encoding == NoEncoding {
if ok, _ := c.smtpClient.Extension("8BITMIME"); !ok { if ok, _ := client.Extension("8BITMIME"); !ok {
return &SendError{Reason: ErrNoUnencoded, isTemp: false, affectedMsg: message} return &SendError{Reason: ErrNoUnencoded, isTemp: false, affectedMsg: message}
} }
} }
@ -1260,16 +1336,16 @@ func (c *Client) sendSingleMsg(message *Msg) error {
if c.requestDSN { if c.requestDSN {
if c.dsnReturnType != "" { if c.dsnReturnType != "" {
c.smtpClient.SetDSNMailReturnOption(string(c.dsnReturnType)) client.SetDSNMailReturnOption(string(c.dsnReturnType))
} }
} }
if err = c.smtpClient.Mail(from); err != nil { if err = client.Mail(from); err != nil {
retError := &SendError{ retError := &SendError{
Reason: ErrSMTPMailFrom, errlist: []error{err}, isTemp: isTempError(err), Reason: ErrSMTPMailFrom, errlist: []error{err}, isTemp: isTempError(err),
affectedMsg: message, errcode: errorCode(err), affectedMsg: message, errcode: errorCode(err),
enhancedStatusCode: enhancedStatusCode(err, escSupport), enhancedStatusCode: enhancedStatusCode(err, escSupport),
} }
if resetSendErr := c.smtpClient.Reset(); resetSendErr != nil { if resetSendErr := client.Reset(); resetSendErr != nil {
retError.errlist = append(retError.errlist, resetSendErr) retError.errlist = append(retError.errlist, resetSendErr)
} }
return retError return retError
@ -1279,9 +1355,9 @@ func (c *Client) sendSingleMsg(message *Msg) error {
rcptSendErr.errlist = make([]error, 0) rcptSendErr.errlist = make([]error, 0)
rcptSendErr.rcpt = make([]string, 0) rcptSendErr.rcpt = make([]string, 0)
rcptNotifyOpt := strings.Join(c.dsnRcptNotifyType, ",") rcptNotifyOpt := strings.Join(c.dsnRcptNotifyType, ",")
c.smtpClient.SetDSNRcptNotifyOption(rcptNotifyOpt) client.SetDSNRcptNotifyOption(rcptNotifyOpt)
for _, rcpt := range rcpts { for _, rcpt := range rcpts {
if err = c.smtpClient.Rcpt(rcpt); err != nil { if err = client.Rcpt(rcpt); err != nil {
rcptSendErr.Reason = ErrSMTPRcptTo rcptSendErr.Reason = ErrSMTPRcptTo
rcptSendErr.errlist = append(rcptSendErr.errlist, err) rcptSendErr.errlist = append(rcptSendErr.errlist, err)
rcptSendErr.rcpt = append(rcptSendErr.rcpt, rcpt) rcptSendErr.rcpt = append(rcptSendErr.rcpt, rcpt)
@ -1292,12 +1368,12 @@ func (c *Client) sendSingleMsg(message *Msg) error {
} }
} }
if hasError { if hasError {
if resetSendErr := c.smtpClient.Reset(); resetSendErr != nil { if resetSendErr := client.Reset(); resetSendErr != nil {
rcptSendErr.errlist = append(rcptSendErr.errlist, resetSendErr) rcptSendErr.errlist = append(rcptSendErr.errlist, resetSendErr)
} }
return rcptSendErr return rcptSendErr
} }
writer, err := c.smtpClient.Data() writer, err := client.Data()
if err != nil { if err != nil {
return &SendError{ return &SendError{
Reason: ErrSMTPData, errlist: []error{err}, isTemp: isTempError(err), Reason: ErrSMTPData, errlist: []error{err}, isTemp: isTempError(err),
@ -1322,7 +1398,7 @@ func (c *Client) sendSingleMsg(message *Msg) error {
} }
message.isDelivered = true message.isDelivered = true
if err = c.Reset(); err != nil { if err = c.ResetWithSMTPClient(client); err != nil {
return &SendError{ return &SendError{
Reason: ErrSMTPReset, errlist: []error{err}, isTemp: isTempError(err), Reason: ErrSMTPReset, errlist: []error{err}, isTemp: isTempError(err),
affectedMsg: message, errcode: errorCode(err), affectedMsg: message, errcode: errorCode(err),
@ -1344,21 +1420,24 @@ func (c *Client) sendSingleMsg(message *Msg) error {
// Returns: // Returns:
// - An error if there is no active connection, if the NOOP command fails, or if extending // - An error if there is no active connection, if the NOOP command fails, or if extending
// the deadline fails; otherwise, returns nil. // the deadline fails; otherwise, returns nil.
func (c *Client) checkConn() error { func (c *Client) checkConn(client *smtp.Client) error {
if c.smtpClient == nil { if client == nil {
return ErrNoActiveConnection return ErrNoActiveConnection
} }
if !c.smtpClient.HasConnection() { if !client.HasConnection() {
return ErrNoActiveConnection return ErrNoActiveConnection
} }
if !c.noNoop { c.mutex.RLock()
if err := c.smtpClient.Noop(); err != nil { noNoop := c.noNoop
c.mutex.RUnlock()
if !noNoop {
if err := client.Noop(); err != nil {
return ErrNoActiveConnection return ErrNoActiveConnection
} }
} }
if err := c.smtpClient.UpdateDeadline(c.connTimeout); err != nil { if err := client.UpdateDeadline(c.connTimeout); err != nil {
return ErrDeadlineExtendFailed return ErrDeadlineExtendFailed
} }
return nil return nil
@ -1405,10 +1484,10 @@ func (c *Client) setDefaultHelo() error {
// Returns: // Returns:
// - An error if there is no active connection, if STARTTLS is required but not supported, // - An error if there is no active connection, if STARTTLS is required but not supported,
// or if there are issues during the TLS handshake; otherwise, returns nil. // or if there are issues during the TLS handshake; otherwise, returns nil.
func (c *Client) tls() error { func (c *Client) tls(client *smtp.Client, isEnc *bool) error {
if !c.useSSL && c.tlspolicy != NoTLS { if !c.useSSL && c.tlspolicy != NoTLS {
hasStartTLS := false hasStartTLS := false
extension, _ := c.smtpClient.Extension("STARTTLS") extension, _ := client.Extension("STARTTLS")
if c.tlspolicy == TLSMandatory { if c.tlspolicy == TLSMandatory {
hasStartTLS = true hasStartTLS = true
if !extension { if !extension {
@ -1422,21 +1501,21 @@ func (c *Client) tls() error {
} }
} }
if hasStartTLS { if hasStartTLS {
if err := c.smtpClient.StartTLS(c.tlsconfig); err != nil { if err := client.StartTLS(c.tlsconfig); err != nil {
return err return err
} }
} }
tlsConnState, err := c.smtpClient.GetTLSConnectionState() tlsConnState, err := client.GetTLSConnectionState()
if err != nil { if err != nil {
switch { switch {
case errors.Is(err, smtp.ErrNonTLSConnection): case errors.Is(err, smtp.ErrNonTLSConnection):
c.isEncrypted = false *isEnc = false
return nil return nil
default: default:
return fmt.Errorf("failed to get TLS connection state: %w", err) return fmt.Errorf("failed to get TLS connection state: %w", err)
} }
} }
c.isEncrypted = tlsConnState.HandshakeComplete *isEnc = tlsConnState.HandshakeComplete
} }
return nil return nil
} }

View file

@ -9,6 +9,8 @@ package mail
import ( import (
"errors" "errors"
"github.com/wneessen/go-mail/smtp"
) )
// Send attempts to send one or more Msg using the Client connection to the SMTP server. // Send attempts to send one or more Msg using the Client connection to the SMTP server.
@ -27,11 +29,15 @@ import (
// Returns: // Returns:
// - An error that aggregates any SendErrors encountered during the sending process; otherwise, returns nil. // - An error that aggregates any SendErrors encountered during the sending process; otherwise, returns nil.
func (c *Client) Send(messages ...*Msg) (returnErr error) { func (c *Client) Send(messages ...*Msg) (returnErr error) {
return c.SendWithSMTPClient(c.smtpClient, messages...)
}
func (c *Client) SendWithSMTPClient(client *smtp.Client, messages ...*Msg) (returnErr error) {
escSupport := false escSupport := false
if c.smtpClient != nil { if client != nil {
escSupport, _ = c.smtpClient.Extension("ENHANCEDSTATUSCODES") escSupport, _ = client.Extension("ENHANCEDSTATUSCODES")
} }
if err := c.checkConn(); err != nil { if err := c.checkConn(client); err != nil {
returnErr = &SendError{ returnErr = &SendError{
Reason: ErrConnCheck, errlist: []error{err}, isTemp: isTempError(err), Reason: ErrConnCheck, errlist: []error{err}, isTemp: isTempError(err),
errcode: errorCode(err), enhancedStatusCode: enhancedStatusCode(err, escSupport), errcode: errorCode(err), enhancedStatusCode: enhancedStatusCode(err, escSupport),
@ -45,7 +51,7 @@ func (c *Client) Send(messages ...*Msg) (returnErr error) {
}() }()
for id, message := range messages { for id, message := range messages {
if sendErr := c.sendSingleMsg(message); sendErr != nil { if sendErr := c.sendSingleMsg(client, message); sendErr != nil {
messages[id].sendError = sendErr messages[id].sendError = sendErr
errs = append(errs, sendErr) errs = append(errs, sendErr)
} }