Refactor SMTP client functions to improve modularity

Refactor `DialWithContext` to delegate client creation to `SMTPClientFromDialWithContext`. Add new methods `SendWithSMTPClient`, `CloseWithSMTPClient`, and `ResetWithSMTPClient` to handle specific actions on `smtp.Client`. Simplify `auth`, `sendSingleMsg`, and `tls` methods by passing client as parameters.
This commit is contained in:
Winni Neessen 2024-11-22 14:59:19 +01:00
parent f9e869061e
commit 55884786be
Signed by: wneessen
GPG key ID: 385AC9889632126E
2 changed files with 180 additions and 95 deletions

203
client.go
View file

@ -142,9 +142,6 @@ type (
// host is the hostname of the SMTP server we are connecting to.
host string
// isEncrypted indicates wether the Client connection is encrypted or not.
isEncrypted bool
// logAuthData indicates whether authentication-related data should be logged.
logAuthData bool
@ -931,19 +928,25 @@ func (c *Client) SetLogAuthData(logAuth bool) {
// Returns:
// - An error if the connection to the SMTP server fails or any subsequent command fails.
func (c *Client) DialWithContext(dialCtx context.Context) error {
client, err := c.SMTPClientFromDialWithContext(dialCtx)
if err != nil {
return err
}
c.mutex.Lock()
defer c.mutex.Unlock()
c.smtpClient = client
c.mutex.Unlock()
/*
ctx, cancel := context.WithDeadline(dialCtx, time.Now().Add(c.connTimeout))
defer cancel()
isEncrypted := false
if c.dialContextFunc == nil {
netDialer := net.Dialer{}
c.dialContextFunc = netDialer.DialContext
if c.useSSL {
tlsDialer := tls.Dialer{NetDialer: &netDialer, Config: c.tlsconfig}
c.isEncrypted = true
isEncrypted = true
c.dialContextFunc = tlsDialer.DialContext
}
}
@ -978,17 +981,80 @@ func (c *Client) DialWithContext(dialCtx context.Context) error {
return err
}
if err = c.tls(); err != nil {
if err = c.tls(c.smtpClient, &isEncrypted); err != nil {
return err
}
if err = c.auth(); err != nil {
if err = c.auth(c.smtpClient, isEncrypted); err != nil {
return err
}
*/
return nil
}
// SMTPClientFromDialWithContext is similar to DialWithContext but instead of storing the smtp.Client
// on the Client it will return the smtp.Client instead.
func (c *Client) SMTPClientFromDialWithContext(ctxDial context.Context) (*smtp.Client, error) {
c.mutex.RLock()
defer c.mutex.RUnlock()
ctx, cancel := context.WithDeadline(ctxDial, time.Now().Add(c.connTimeout))
defer cancel()
isEncrypted := false
dialContextFunc := c.dialContextFunc
if c.dialContextFunc == nil {
netDialer := net.Dialer{}
dialContextFunc = netDialer.DialContext
if c.useSSL {
tlsDialer := tls.Dialer{NetDialer: &netDialer, Config: c.tlsconfig}
isEncrypted = true
dialContextFunc = tlsDialer.DialContext
}
}
connection, err := dialContextFunc(ctx, "tcp", c.ServerAddr())
if err != nil && c.fallbackPort != 0 {
// TODO: should we somehow log or append the previous error?
connection, err = dialContextFunc(ctx, "tcp", c.serverFallbackAddr())
}
if err != nil {
return nil, err
}
client, err := smtp.NewClient(connection, c.host)
if err != nil {
return nil, err
}
if client == nil {
return nil, fmt.Errorf("SMTP client is nil")
}
if c.logger != nil {
client.SetLogger(c.logger)
}
if c.useDebugLog {
client.SetDebugLog(true)
}
if c.logAuthData {
client.SetLogAuthData()
}
if err = client.Hello(c.helo); err != nil {
return nil, err
}
if err = c.tls(client, &isEncrypted); err != nil {
return nil, err
}
if err = c.auth(client, isEncrypted); err != nil {
return nil, err
}
return client, nil
}
// Close terminates the connection to the SMTP server, returning an error if the disconnection
// fails. If the connection is already closed, this method is a no-op and disregards any error.
//
@ -999,10 +1065,14 @@ func (c *Client) DialWithContext(dialCtx context.Context) error {
// Returns:
// - An error if the disconnection fails; otherwise, returns nil.
func (c *Client) Close() error {
if c.smtpClient == nil || !c.smtpClient.HasConnection() {
return c.CloseWithSMTPClient(c.smtpClient)
}
func (c *Client) CloseWithSMTPClient(client *smtp.Client) error {
if client == nil || !client.HasConnection() {
return nil
}
if err := c.smtpClient.Quit(); err != nil {
if err := client.Quit(); err != nil {
return fmt.Errorf("failed to close SMTP client: %w", err)
}
@ -1018,10 +1088,14 @@ func (c *Client) Close() error {
// Returns:
// - An error if the connection check fails or if sending the RSET command fails; otherwise, returns nil.
func (c *Client) Reset() error {
if err := c.checkConn(); err != nil {
return c.ResetWithSMTPClient(c.smtpClient)
}
func (c *Client) ResetWithSMTPClient(client *smtp.Client) error {
if err := c.checkConn(client); err != nil {
return err
}
if err := c.smtpClient.Reset(); err != nil {
if err := client.Reset(); err != nil {
return fmt.Errorf("failed to send RSET to SMTP client: %w", err)
}
@ -1061,19 +1135,20 @@ func (c *Client) DialAndSend(messages ...*Msg) error {
// - An error if the connection fails, if sending the messages fails, or if closing the
// connection fails; otherwise, returns nil.
func (c *Client) DialAndSendWithContext(ctx context.Context, messages ...*Msg) error {
c.sendMutex.Lock()
defer c.sendMutex.Unlock()
if err := c.DialWithContext(ctx); err != nil {
//c.sendMutex.Lock()
//defer c.sendMutex.Unlock()
client, err := c.SMTPClientFromDialWithContext(ctx)
if err != nil {
return fmt.Errorf("dial failed: %w", err)
}
defer func() {
_ = c.Close()
_ = c.CloseWithSMTPClient(client)
}()
if err := c.Send(messages...); err != nil {
if err := c.SendWithSMTPClient(client, messages...); err != nil {
return fmt.Errorf("send failed: %w", err)
}
if err := c.Close(); err != nil {
if err := c.CloseWithSMTPClient(client); err != nil {
return fmt.Errorf("failed to close connection: %w", err)
}
return nil
@ -1098,16 +1173,17 @@ func (c *Client) DialAndSendWithContext(ctx context.Context, messages ...*Msg) e
// Returns:
// - An error if the connection check fails, if no supported authentication method is found,
// or if the authentication process fails.
func (c *Client) auth() error {
func (c *Client) auth(client *smtp.Client, isEnc bool) error {
var smtpAuth smtp.Auth
if c.smtpAuth == nil && c.smtpAuthType != SMTPAuthNoAuth {
hasSMTPAuth, smtpAuthType := c.smtpClient.Extension("AUTH")
hasSMTPAuth, smtpAuthType := client.Extension("AUTH")
if !hasSMTPAuth {
return fmt.Errorf("server does not support SMTP AUTH")
}
authType := c.smtpAuthType
if c.smtpAuthType == SMTPAuthAutoDiscover {
discoveredType, err := c.authTypeAutoDiscover(smtpAuthType)
discoveredType, err := c.authTypeAutoDiscover(smtpAuthType, isEnc)
if err != nil {
return err
}
@ -1119,74 +1195,74 @@ func (c *Client) auth() error {
if !strings.Contains(smtpAuthType, string(SMTPAuthPlain)) {
return ErrPlainAuthNotSupported
}
c.smtpAuth = smtp.PlainAuth("", c.user, c.pass, c.host, false)
smtpAuth = smtp.PlainAuth("", c.user, c.pass, c.host, false)
case SMTPAuthPlainNoEnc:
if !strings.Contains(smtpAuthType, string(SMTPAuthPlain)) {
return ErrPlainAuthNotSupported
}
c.smtpAuth = smtp.PlainAuth("", c.user, c.pass, c.host, true)
smtpAuth = smtp.PlainAuth("", c.user, c.pass, c.host, true)
case SMTPAuthLogin:
if !strings.Contains(smtpAuthType, string(SMTPAuthLogin)) {
return ErrLoginAuthNotSupported
}
c.smtpAuth = smtp.LoginAuth(c.user, c.pass, c.host, false)
smtpAuth = smtp.LoginAuth(c.user, c.pass, c.host, false)
case SMTPAuthLoginNoEnc:
if !strings.Contains(smtpAuthType, string(SMTPAuthLogin)) {
return ErrLoginAuthNotSupported
}
c.smtpAuth = smtp.LoginAuth(c.user, c.pass, c.host, true)
smtpAuth = smtp.LoginAuth(c.user, c.pass, c.host, true)
case SMTPAuthCramMD5:
if !strings.Contains(smtpAuthType, string(SMTPAuthCramMD5)) {
return ErrCramMD5AuthNotSupported
}
c.smtpAuth = smtp.CRAMMD5Auth(c.user, c.pass)
smtpAuth = smtp.CRAMMD5Auth(c.user, c.pass)
case SMTPAuthXOAUTH2:
if !strings.Contains(smtpAuthType, string(SMTPAuthXOAUTH2)) {
return ErrXOauth2AuthNotSupported
}
c.smtpAuth = smtp.XOAuth2Auth(c.user, c.pass)
smtpAuth = smtp.XOAuth2Auth(c.user, c.pass)
case SMTPAuthSCRAMSHA1:
if !strings.Contains(smtpAuthType, string(SMTPAuthSCRAMSHA1)) {
return ErrSCRAMSHA1AuthNotSupported
}
c.smtpAuth = smtp.ScramSHA1Auth(c.user, c.pass)
smtpAuth = smtp.ScramSHA1Auth(c.user, c.pass)
case SMTPAuthSCRAMSHA256:
if !strings.Contains(smtpAuthType, string(SMTPAuthSCRAMSHA256)) {
return ErrSCRAMSHA256AuthNotSupported
}
c.smtpAuth = smtp.ScramSHA256Auth(c.user, c.pass)
smtpAuth = smtp.ScramSHA256Auth(c.user, c.pass)
case SMTPAuthSCRAMSHA1PLUS:
if !strings.Contains(smtpAuthType, string(SMTPAuthSCRAMSHA1PLUS)) {
return ErrSCRAMSHA1PLUSAuthNotSupported
}
tlsConnState, err := c.smtpClient.GetTLSConnectionState()
tlsConnState, err := client.GetTLSConnectionState()
if err != nil {
return err
}
c.smtpAuth = smtp.ScramSHA1PlusAuth(c.user, c.pass, tlsConnState)
smtpAuth = smtp.ScramSHA1PlusAuth(c.user, c.pass, tlsConnState)
case SMTPAuthSCRAMSHA256PLUS:
if !strings.Contains(smtpAuthType, string(SMTPAuthSCRAMSHA256PLUS)) {
return ErrSCRAMSHA256PLUSAuthNotSupported
}
tlsConnState, err := c.smtpClient.GetTLSConnectionState()
tlsConnState, err := client.GetTLSConnectionState()
if err != nil {
return err
}
c.smtpAuth = smtp.ScramSHA256PlusAuth(c.user, c.pass, tlsConnState)
smtpAuth = smtp.ScramSHA256PlusAuth(c.user, c.pass, tlsConnState)
default:
return fmt.Errorf("unsupported SMTP AUTH type %q", c.smtpAuthType)
}
}
if c.smtpAuth != nil {
if err := c.smtpClient.Auth(c.smtpAuth); err != nil {
if smtpAuth != nil {
if err := client.Auth(smtpAuth); err != nil {
return fmt.Errorf("SMTP AUTH failed: %w", err)
}
}
return nil
}
func (c *Client) authTypeAutoDiscover(supported string) (SMTPAuthType, error) {
func (c *Client) authTypeAutoDiscover(supported string, isEnc bool) (SMTPAuthType, error) {
if supported == "" {
return "", ErrNoSupportedAuthDiscovered
}
@ -1194,7 +1270,7 @@ func (c *Client) authTypeAutoDiscover(supported string) (SMTPAuthType, error) {
SMTPAuthSCRAMSHA256PLUS, SMTPAuthSCRAMSHA256, SMTPAuthSCRAMSHA1PLUS, SMTPAuthSCRAMSHA1,
SMTPAuthXOAUTH2, SMTPAuthCramMD5, SMTPAuthPlain, SMTPAuthLogin,
}
if !c.isEncrypted {
if !isEnc {
preferList = []SMTPAuthType{SMTPAuthSCRAMSHA256, SMTPAuthSCRAMSHA1, SMTPAuthXOAUTH2, SMTPAuthCramMD5}
}
mechs := strings.Split(supported, " ")
@ -1231,13 +1307,13 @@ func sliceContains(slice []string, item string) bool {
//
// Returns:
// - An error if any part of the sending process fails; otherwise, returns nil.
func (c *Client) sendSingleMsg(message *Msg) error {
c.mutex.Lock()
defer c.mutex.Unlock()
escSupport, _ := c.smtpClient.Extension("ENHANCEDSTATUSCODES")
func (c *Client) sendSingleMsg(client *smtp.Client, message *Msg) error {
c.mutex.RLock()
defer c.mutex.RUnlock()
escSupport, _ := client.Extension("ENHANCEDSTATUSCODES")
if message.encoding == NoEncoding {
if ok, _ := c.smtpClient.Extension("8BITMIME"); !ok {
if ok, _ := client.Extension("8BITMIME"); !ok {
return &SendError{Reason: ErrNoUnencoded, isTemp: false, affectedMsg: message}
}
}
@ -1260,16 +1336,16 @@ func (c *Client) sendSingleMsg(message *Msg) error {
if c.requestDSN {
if c.dsnReturnType != "" {
c.smtpClient.SetDSNMailReturnOption(string(c.dsnReturnType))
client.SetDSNMailReturnOption(string(c.dsnReturnType))
}
}
if err = c.smtpClient.Mail(from); err != nil {
if err = client.Mail(from); err != nil {
retError := &SendError{
Reason: ErrSMTPMailFrom, errlist: []error{err}, isTemp: isTempError(err),
affectedMsg: message, errcode: errorCode(err),
enhancedStatusCode: enhancedStatusCode(err, escSupport),
}
if resetSendErr := c.smtpClient.Reset(); resetSendErr != nil {
if resetSendErr := client.Reset(); resetSendErr != nil {
retError.errlist = append(retError.errlist, resetSendErr)
}
return retError
@ -1279,9 +1355,9 @@ func (c *Client) sendSingleMsg(message *Msg) error {
rcptSendErr.errlist = make([]error, 0)
rcptSendErr.rcpt = make([]string, 0)
rcptNotifyOpt := strings.Join(c.dsnRcptNotifyType, ",")
c.smtpClient.SetDSNRcptNotifyOption(rcptNotifyOpt)
client.SetDSNRcptNotifyOption(rcptNotifyOpt)
for _, rcpt := range rcpts {
if err = c.smtpClient.Rcpt(rcpt); err != nil {
if err = client.Rcpt(rcpt); err != nil {
rcptSendErr.Reason = ErrSMTPRcptTo
rcptSendErr.errlist = append(rcptSendErr.errlist, err)
rcptSendErr.rcpt = append(rcptSendErr.rcpt, rcpt)
@ -1292,12 +1368,12 @@ func (c *Client) sendSingleMsg(message *Msg) error {
}
}
if hasError {
if resetSendErr := c.smtpClient.Reset(); resetSendErr != nil {
if resetSendErr := client.Reset(); resetSendErr != nil {
rcptSendErr.errlist = append(rcptSendErr.errlist, resetSendErr)
}
return rcptSendErr
}
writer, err := c.smtpClient.Data()
writer, err := client.Data()
if err != nil {
return &SendError{
Reason: ErrSMTPData, errlist: []error{err}, isTemp: isTempError(err),
@ -1322,7 +1398,7 @@ func (c *Client) sendSingleMsg(message *Msg) error {
}
message.isDelivered = true
if err = c.Reset(); err != nil {
if err = c.ResetWithSMTPClient(client); err != nil {
return &SendError{
Reason: ErrSMTPReset, errlist: []error{err}, isTemp: isTempError(err),
affectedMsg: message, errcode: errorCode(err),
@ -1344,21 +1420,24 @@ func (c *Client) sendSingleMsg(message *Msg) error {
// Returns:
// - An error if there is no active connection, if the NOOP command fails, or if extending
// the deadline fails; otherwise, returns nil.
func (c *Client) checkConn() error {
if c.smtpClient == nil {
func (c *Client) checkConn(client *smtp.Client) error {
if client == nil {
return ErrNoActiveConnection
}
if !c.smtpClient.HasConnection() {
if !client.HasConnection() {
return ErrNoActiveConnection
}
if !c.noNoop {
if err := c.smtpClient.Noop(); err != nil {
c.mutex.RLock()
noNoop := c.noNoop
c.mutex.RUnlock()
if !noNoop {
if err := client.Noop(); err != nil {
return ErrNoActiveConnection
}
}
if err := c.smtpClient.UpdateDeadline(c.connTimeout); err != nil {
if err := client.UpdateDeadline(c.connTimeout); err != nil {
return ErrDeadlineExtendFailed
}
return nil
@ -1405,10 +1484,10 @@ func (c *Client) setDefaultHelo() error {
// Returns:
// - An error if there is no active connection, if STARTTLS is required but not supported,
// or if there are issues during the TLS handshake; otherwise, returns nil.
func (c *Client) tls() error {
func (c *Client) tls(client *smtp.Client, isEnc *bool) error {
if !c.useSSL && c.tlspolicy != NoTLS {
hasStartTLS := false
extension, _ := c.smtpClient.Extension("STARTTLS")
extension, _ := client.Extension("STARTTLS")
if c.tlspolicy == TLSMandatory {
hasStartTLS = true
if !extension {
@ -1422,21 +1501,21 @@ func (c *Client) tls() error {
}
}
if hasStartTLS {
if err := c.smtpClient.StartTLS(c.tlsconfig); err != nil {
if err := client.StartTLS(c.tlsconfig); err != nil {
return err
}
}
tlsConnState, err := c.smtpClient.GetTLSConnectionState()
tlsConnState, err := client.GetTLSConnectionState()
if err != nil {
switch {
case errors.Is(err, smtp.ErrNonTLSConnection):
c.isEncrypted = false
*isEnc = false
return nil
default:
return fmt.Errorf("failed to get TLS connection state: %w", err)
}
}
c.isEncrypted = tlsConnState.HandshakeComplete
*isEnc = tlsConnState.HandshakeComplete
}
return nil
}

View file

@ -9,6 +9,8 @@ package mail
import (
"errors"
"github.com/wneessen/go-mail/smtp"
)
// Send attempts to send one or more Msg using the Client connection to the SMTP server.
@ -27,11 +29,15 @@ import (
// Returns:
// - An error that aggregates any SendErrors encountered during the sending process; otherwise, returns nil.
func (c *Client) Send(messages ...*Msg) (returnErr error) {
return c.SendWithSMTPClient(c.smtpClient, messages...)
}
func (c *Client) SendWithSMTPClient(client *smtp.Client, messages ...*Msg) (returnErr error) {
escSupport := false
if c.smtpClient != nil {
escSupport, _ = c.smtpClient.Extension("ENHANCEDSTATUSCODES")
if client != nil {
escSupport, _ = client.Extension("ENHANCEDSTATUSCODES")
}
if err := c.checkConn(); err != nil {
if err := c.checkConn(client); err != nil {
returnErr = &SendError{
Reason: ErrConnCheck, errlist: []error{err}, isTemp: isTempError(err),
errcode: errorCode(err), enhancedStatusCode: enhancedStatusCode(err, escSupport),
@ -45,7 +51,7 @@ func (c *Client) Send(messages ...*Msg) (returnErr error) {
}()
for id, message := range messages {
if sendErr := c.sendSingleMsg(message); sendErr != nil {
if sendErr := c.sendSingleMsg(client, message); sendErr != nil {
messages[id].sendError = sendErr
errs = append(errs, sendErr)
}