diff --git a/client.go b/client.go index a7dead8..de6479e 100644 --- a/client.go +++ b/client.go @@ -142,9 +142,6 @@ type ( // host is the hostname of the SMTP server we are connecting to. host string - // isEncrypted indicates wether the Client connection is encrypted or not. - isEncrypted bool - // logAuthData indicates whether authentication-related data should be logged. logAuthData bool @@ -931,62 +928,131 @@ func (c *Client) SetLogAuthData(logAuth bool) { // Returns: // - An error if the connection to the SMTP server fails or any subsequent command fails. func (c *Client) DialWithContext(dialCtx context.Context) error { - c.mutex.Lock() - defer c.mutex.Unlock() - - ctx, cancel := context.WithDeadline(dialCtx, time.Now().Add(c.connTimeout)) - defer cancel() - - if c.dialContextFunc == nil { - netDialer := net.Dialer{} - c.dialContextFunc = netDialer.DialContext - - if c.useSSL { - tlsDialer := tls.Dialer{NetDialer: &netDialer, Config: c.tlsconfig} - c.isEncrypted = true - c.dialContextFunc = tlsDialer.DialContext - } - } - connection, err := c.dialContextFunc(ctx, "tcp", c.ServerAddr()) - if err != nil && c.fallbackPort != 0 { - // TODO: should we somehow log or append the previous error? - connection, err = c.dialContextFunc(ctx, "tcp", c.serverFallbackAddr()) - } + client, err := c.SMTPClientFromDialWithContext(dialCtx) if err != nil { return err } + c.mutex.Lock() + c.smtpClient = client + c.mutex.Unlock() + /* + ctx, cancel := context.WithDeadline(dialCtx, time.Now().Add(c.connTimeout)) + defer cancel() + + isEncrypted := false + if c.dialContextFunc == nil { + netDialer := net.Dialer{} + c.dialContextFunc = netDialer.DialContext + + if c.useSSL { + tlsDialer := tls.Dialer{NetDialer: &netDialer, Config: c.tlsconfig} + isEncrypted = true + c.dialContextFunc = tlsDialer.DialContext + } + } + connection, err := c.dialContextFunc(ctx, "tcp", c.ServerAddr()) + if err != nil && c.fallbackPort != 0 { + // TODO: should we somehow log or append the previous error? + connection, err = c.dialContextFunc(ctx, "tcp", c.serverFallbackAddr()) + } + if err != nil { + return err + } + + client, err := smtp.NewClient(connection, c.host) + if err != nil { + return err + } + if client == nil { + return fmt.Errorf("SMTP client is nil") + } + c.smtpClient = client + + if c.logger != nil { + c.smtpClient.SetLogger(c.logger) + } + if c.useDebugLog { + c.smtpClient.SetDebugLog(true) + } + if c.logAuthData { + c.smtpClient.SetLogAuthData() + } + if err = c.smtpClient.Hello(c.helo); err != nil { + return err + } + + if err = c.tls(c.smtpClient, &isEncrypted); err != nil { + return err + } + + if err = c.auth(c.smtpClient, isEncrypted); err != nil { + return err + } + + */ + + return nil +} + +// SMTPClientFromDialWithContext is similar to DialWithContext but instead of storing the smtp.Client +// on the Client it will return the smtp.Client instead. +func (c *Client) SMTPClientFromDialWithContext(ctxDial context.Context) (*smtp.Client, error) { + c.mutex.RLock() + defer c.mutex.RUnlock() + + ctx, cancel := context.WithDeadline(ctxDial, time.Now().Add(c.connTimeout)) + defer cancel() + + isEncrypted := false + dialContextFunc := c.dialContextFunc + if c.dialContextFunc == nil { + netDialer := net.Dialer{} + dialContextFunc = netDialer.DialContext + if c.useSSL { + tlsDialer := tls.Dialer{NetDialer: &netDialer, Config: c.tlsconfig} + isEncrypted = true + dialContextFunc = tlsDialer.DialContext + } + } + connection, err := dialContextFunc(ctx, "tcp", c.ServerAddr()) + if err != nil && c.fallbackPort != 0 { + // TODO: should we somehow log or append the previous error? + connection, err = dialContextFunc(ctx, "tcp", c.serverFallbackAddr()) + } + if err != nil { + return nil, err + } client, err := smtp.NewClient(connection, c.host) if err != nil { - return err + return nil, err } if client == nil { - return fmt.Errorf("SMTP client is nil") + return nil, fmt.Errorf("SMTP client is nil") } - c.smtpClient = client if c.logger != nil { - c.smtpClient.SetLogger(c.logger) + client.SetLogger(c.logger) } if c.useDebugLog { - c.smtpClient.SetDebugLog(true) + client.SetDebugLog(true) } if c.logAuthData { - c.smtpClient.SetLogAuthData() + client.SetLogAuthData() } - if err = c.smtpClient.Hello(c.helo); err != nil { - return err + if err = client.Hello(c.helo); err != nil { + return nil, err } - if err = c.tls(); err != nil { - return err + if err = c.tls(client, &isEncrypted); err != nil { + return nil, err } - if err = c.auth(); err != nil { - return err + if err = c.auth(client, isEncrypted); err != nil { + return nil, err } - return nil + return client, nil } // Close terminates the connection to the SMTP server, returning an error if the disconnection @@ -999,10 +1065,14 @@ func (c *Client) DialWithContext(dialCtx context.Context) error { // Returns: // - An error if the disconnection fails; otherwise, returns nil. func (c *Client) Close() error { - if c.smtpClient == nil || !c.smtpClient.HasConnection() { + return c.CloseWithSMTPClient(c.smtpClient) +} + +func (c *Client) CloseWithSMTPClient(client *smtp.Client) error { + if client == nil || !client.HasConnection() { return nil } - if err := c.smtpClient.Quit(); err != nil { + if err := client.Quit(); err != nil { return fmt.Errorf("failed to close SMTP client: %w", err) } @@ -1018,10 +1088,14 @@ func (c *Client) Close() error { // Returns: // - An error if the connection check fails or if sending the RSET command fails; otherwise, returns nil. func (c *Client) Reset() error { - if err := c.checkConn(); err != nil { + return c.ResetWithSMTPClient(c.smtpClient) +} + +func (c *Client) ResetWithSMTPClient(client *smtp.Client) error { + if err := c.checkConn(client); err != nil { return err } - if err := c.smtpClient.Reset(); err != nil { + if err := client.Reset(); err != nil { return fmt.Errorf("failed to send RSET to SMTP client: %w", err) } @@ -1061,19 +1135,20 @@ func (c *Client) DialAndSend(messages ...*Msg) error { // - An error if the connection fails, if sending the messages fails, or if closing the // connection fails; otherwise, returns nil. func (c *Client) DialAndSendWithContext(ctx context.Context, messages ...*Msg) error { - c.sendMutex.Lock() - defer c.sendMutex.Unlock() - if err := c.DialWithContext(ctx); err != nil { + //c.sendMutex.Lock() + //defer c.sendMutex.Unlock() + client, err := c.SMTPClientFromDialWithContext(ctx) + if err != nil { return fmt.Errorf("dial failed: %w", err) } defer func() { - _ = c.Close() + _ = c.CloseWithSMTPClient(client) }() - if err := c.Send(messages...); err != nil { + if err := c.SendWithSMTPClient(client, messages...); err != nil { return fmt.Errorf("send failed: %w", err) } - if err := c.Close(); err != nil { + if err := c.CloseWithSMTPClient(client); err != nil { return fmt.Errorf("failed to close connection: %w", err) } return nil @@ -1098,16 +1173,17 @@ func (c *Client) DialAndSendWithContext(ctx context.Context, messages ...*Msg) e // Returns: // - An error if the connection check fails, if no supported authentication method is found, // or if the authentication process fails. -func (c *Client) auth() error { +func (c *Client) auth(client *smtp.Client, isEnc bool) error { + var smtpAuth smtp.Auth if c.smtpAuth == nil && c.smtpAuthType != SMTPAuthNoAuth { - hasSMTPAuth, smtpAuthType := c.smtpClient.Extension("AUTH") + hasSMTPAuth, smtpAuthType := client.Extension("AUTH") if !hasSMTPAuth { return fmt.Errorf("server does not support SMTP AUTH") } authType := c.smtpAuthType if c.smtpAuthType == SMTPAuthAutoDiscover { - discoveredType, err := c.authTypeAutoDiscover(smtpAuthType) + discoveredType, err := c.authTypeAutoDiscover(smtpAuthType, isEnc) if err != nil { return err } @@ -1119,74 +1195,74 @@ func (c *Client) auth() error { if !strings.Contains(smtpAuthType, string(SMTPAuthPlain)) { return ErrPlainAuthNotSupported } - c.smtpAuth = smtp.PlainAuth("", c.user, c.pass, c.host, false) + smtpAuth = smtp.PlainAuth("", c.user, c.pass, c.host, false) case SMTPAuthPlainNoEnc: if !strings.Contains(smtpAuthType, string(SMTPAuthPlain)) { return ErrPlainAuthNotSupported } - c.smtpAuth = smtp.PlainAuth("", c.user, c.pass, c.host, true) + smtpAuth = smtp.PlainAuth("", c.user, c.pass, c.host, true) case SMTPAuthLogin: if !strings.Contains(smtpAuthType, string(SMTPAuthLogin)) { return ErrLoginAuthNotSupported } - c.smtpAuth = smtp.LoginAuth(c.user, c.pass, c.host, false) + smtpAuth = smtp.LoginAuth(c.user, c.pass, c.host, false) case SMTPAuthLoginNoEnc: if !strings.Contains(smtpAuthType, string(SMTPAuthLogin)) { return ErrLoginAuthNotSupported } - c.smtpAuth = smtp.LoginAuth(c.user, c.pass, c.host, true) + smtpAuth = smtp.LoginAuth(c.user, c.pass, c.host, true) case SMTPAuthCramMD5: if !strings.Contains(smtpAuthType, string(SMTPAuthCramMD5)) { return ErrCramMD5AuthNotSupported } - c.smtpAuth = smtp.CRAMMD5Auth(c.user, c.pass) + smtpAuth = smtp.CRAMMD5Auth(c.user, c.pass) case SMTPAuthXOAUTH2: if !strings.Contains(smtpAuthType, string(SMTPAuthXOAUTH2)) { return ErrXOauth2AuthNotSupported } - c.smtpAuth = smtp.XOAuth2Auth(c.user, c.pass) + smtpAuth = smtp.XOAuth2Auth(c.user, c.pass) case SMTPAuthSCRAMSHA1: if !strings.Contains(smtpAuthType, string(SMTPAuthSCRAMSHA1)) { return ErrSCRAMSHA1AuthNotSupported } - c.smtpAuth = smtp.ScramSHA1Auth(c.user, c.pass) + smtpAuth = smtp.ScramSHA1Auth(c.user, c.pass) case SMTPAuthSCRAMSHA256: if !strings.Contains(smtpAuthType, string(SMTPAuthSCRAMSHA256)) { return ErrSCRAMSHA256AuthNotSupported } - c.smtpAuth = smtp.ScramSHA256Auth(c.user, c.pass) + smtpAuth = smtp.ScramSHA256Auth(c.user, c.pass) case SMTPAuthSCRAMSHA1PLUS: if !strings.Contains(smtpAuthType, string(SMTPAuthSCRAMSHA1PLUS)) { return ErrSCRAMSHA1PLUSAuthNotSupported } - tlsConnState, err := c.smtpClient.GetTLSConnectionState() + tlsConnState, err := client.GetTLSConnectionState() if err != nil { return err } - c.smtpAuth = smtp.ScramSHA1PlusAuth(c.user, c.pass, tlsConnState) + smtpAuth = smtp.ScramSHA1PlusAuth(c.user, c.pass, tlsConnState) case SMTPAuthSCRAMSHA256PLUS: if !strings.Contains(smtpAuthType, string(SMTPAuthSCRAMSHA256PLUS)) { return ErrSCRAMSHA256PLUSAuthNotSupported } - tlsConnState, err := c.smtpClient.GetTLSConnectionState() + tlsConnState, err := client.GetTLSConnectionState() if err != nil { return err } - c.smtpAuth = smtp.ScramSHA256PlusAuth(c.user, c.pass, tlsConnState) + smtpAuth = smtp.ScramSHA256PlusAuth(c.user, c.pass, tlsConnState) default: return fmt.Errorf("unsupported SMTP AUTH type %q", c.smtpAuthType) } } - if c.smtpAuth != nil { - if err := c.smtpClient.Auth(c.smtpAuth); err != nil { + if smtpAuth != nil { + if err := client.Auth(smtpAuth); err != nil { return fmt.Errorf("SMTP AUTH failed: %w", err) } } return nil } -func (c *Client) authTypeAutoDiscover(supported string) (SMTPAuthType, error) { +func (c *Client) authTypeAutoDiscover(supported string, isEnc bool) (SMTPAuthType, error) { if supported == "" { return "", ErrNoSupportedAuthDiscovered } @@ -1194,7 +1270,7 @@ func (c *Client) authTypeAutoDiscover(supported string) (SMTPAuthType, error) { SMTPAuthSCRAMSHA256PLUS, SMTPAuthSCRAMSHA256, SMTPAuthSCRAMSHA1PLUS, SMTPAuthSCRAMSHA1, SMTPAuthXOAUTH2, SMTPAuthCramMD5, SMTPAuthPlain, SMTPAuthLogin, } - if !c.isEncrypted { + if !isEnc { preferList = []SMTPAuthType{SMTPAuthSCRAMSHA256, SMTPAuthSCRAMSHA1, SMTPAuthXOAUTH2, SMTPAuthCramMD5} } mechs := strings.Split(supported, " ") @@ -1231,13 +1307,13 @@ func sliceContains(slice []string, item string) bool { // // Returns: // - An error if any part of the sending process fails; otherwise, returns nil. -func (c *Client) sendSingleMsg(message *Msg) error { - c.mutex.Lock() - defer c.mutex.Unlock() - escSupport, _ := c.smtpClient.Extension("ENHANCEDSTATUSCODES") +func (c *Client) sendSingleMsg(client *smtp.Client, message *Msg) error { + c.mutex.RLock() + defer c.mutex.RUnlock() + escSupport, _ := client.Extension("ENHANCEDSTATUSCODES") if message.encoding == NoEncoding { - if ok, _ := c.smtpClient.Extension("8BITMIME"); !ok { + if ok, _ := client.Extension("8BITMIME"); !ok { return &SendError{Reason: ErrNoUnencoded, isTemp: false, affectedMsg: message} } } @@ -1260,16 +1336,16 @@ func (c *Client) sendSingleMsg(message *Msg) error { if c.requestDSN { if c.dsnReturnType != "" { - c.smtpClient.SetDSNMailReturnOption(string(c.dsnReturnType)) + client.SetDSNMailReturnOption(string(c.dsnReturnType)) } } - if err = c.smtpClient.Mail(from); err != nil { + if err = client.Mail(from); err != nil { retError := &SendError{ Reason: ErrSMTPMailFrom, errlist: []error{err}, isTemp: isTempError(err), affectedMsg: message, errcode: errorCode(err), enhancedStatusCode: enhancedStatusCode(err, escSupport), } - if resetSendErr := c.smtpClient.Reset(); resetSendErr != nil { + if resetSendErr := client.Reset(); resetSendErr != nil { retError.errlist = append(retError.errlist, resetSendErr) } return retError @@ -1279,9 +1355,9 @@ func (c *Client) sendSingleMsg(message *Msg) error { rcptSendErr.errlist = make([]error, 0) rcptSendErr.rcpt = make([]string, 0) rcptNotifyOpt := strings.Join(c.dsnRcptNotifyType, ",") - c.smtpClient.SetDSNRcptNotifyOption(rcptNotifyOpt) + client.SetDSNRcptNotifyOption(rcptNotifyOpt) for _, rcpt := range rcpts { - if err = c.smtpClient.Rcpt(rcpt); err != nil { + if err = client.Rcpt(rcpt); err != nil { rcptSendErr.Reason = ErrSMTPRcptTo rcptSendErr.errlist = append(rcptSendErr.errlist, err) rcptSendErr.rcpt = append(rcptSendErr.rcpt, rcpt) @@ -1292,12 +1368,12 @@ func (c *Client) sendSingleMsg(message *Msg) error { } } if hasError { - if resetSendErr := c.smtpClient.Reset(); resetSendErr != nil { + if resetSendErr := client.Reset(); resetSendErr != nil { rcptSendErr.errlist = append(rcptSendErr.errlist, resetSendErr) } return rcptSendErr } - writer, err := c.smtpClient.Data() + writer, err := client.Data() if err != nil { return &SendError{ Reason: ErrSMTPData, errlist: []error{err}, isTemp: isTempError(err), @@ -1322,7 +1398,7 @@ func (c *Client) sendSingleMsg(message *Msg) error { } message.isDelivered = true - if err = c.Reset(); err != nil { + if err = c.ResetWithSMTPClient(client); err != nil { return &SendError{ Reason: ErrSMTPReset, errlist: []error{err}, isTemp: isTempError(err), affectedMsg: message, errcode: errorCode(err), @@ -1344,21 +1420,24 @@ func (c *Client) sendSingleMsg(message *Msg) error { // Returns: // - An error if there is no active connection, if the NOOP command fails, or if extending // the deadline fails; otherwise, returns nil. -func (c *Client) checkConn() error { - if c.smtpClient == nil { +func (c *Client) checkConn(client *smtp.Client) error { + if client == nil { return ErrNoActiveConnection } - if !c.smtpClient.HasConnection() { + if !client.HasConnection() { return ErrNoActiveConnection } - if !c.noNoop { - if err := c.smtpClient.Noop(); err != nil { + c.mutex.RLock() + noNoop := c.noNoop + c.mutex.RUnlock() + if !noNoop { + if err := client.Noop(); err != nil { return ErrNoActiveConnection } } - if err := c.smtpClient.UpdateDeadline(c.connTimeout); err != nil { + if err := client.UpdateDeadline(c.connTimeout); err != nil { return ErrDeadlineExtendFailed } return nil @@ -1405,10 +1484,10 @@ func (c *Client) setDefaultHelo() error { // Returns: // - An error if there is no active connection, if STARTTLS is required but not supported, // or if there are issues during the TLS handshake; otherwise, returns nil. -func (c *Client) tls() error { +func (c *Client) tls(client *smtp.Client, isEnc *bool) error { if !c.useSSL && c.tlspolicy != NoTLS { hasStartTLS := false - extension, _ := c.smtpClient.Extension("STARTTLS") + extension, _ := client.Extension("STARTTLS") if c.tlspolicy == TLSMandatory { hasStartTLS = true if !extension { @@ -1422,21 +1501,21 @@ func (c *Client) tls() error { } } if hasStartTLS { - if err := c.smtpClient.StartTLS(c.tlsconfig); err != nil { + if err := client.StartTLS(c.tlsconfig); err != nil { return err } } - tlsConnState, err := c.smtpClient.GetTLSConnectionState() + tlsConnState, err := client.GetTLSConnectionState() if err != nil { switch { case errors.Is(err, smtp.ErrNonTLSConnection): - c.isEncrypted = false + *isEnc = false return nil default: return fmt.Errorf("failed to get TLS connection state: %w", err) } } - c.isEncrypted = tlsConnState.HandshakeComplete + *isEnc = tlsConnState.HandshakeComplete } return nil } diff --git a/client_120.go b/client_120.go index 67c5b5e..16e2d0b 100644 --- a/client_120.go +++ b/client_120.go @@ -9,6 +9,8 @@ package mail import ( "errors" + + "github.com/wneessen/go-mail/smtp" ) // Send attempts to send one or more Msg using the Client connection to the SMTP server. @@ -27,11 +29,15 @@ import ( // Returns: // - An error that aggregates any SendErrors encountered during the sending process; otherwise, returns nil. func (c *Client) Send(messages ...*Msg) (returnErr error) { + return c.SendWithSMTPClient(c.smtpClient, messages...) +} + +func (c *Client) SendWithSMTPClient(client *smtp.Client, messages ...*Msg) (returnErr error) { escSupport := false - if c.smtpClient != nil { - escSupport, _ = c.smtpClient.Extension("ENHANCEDSTATUSCODES") + if client != nil { + escSupport, _ = client.Extension("ENHANCEDSTATUSCODES") } - if err := c.checkConn(); err != nil { + if err := c.checkConn(client); err != nil { returnErr = &SendError{ Reason: ErrConnCheck, errlist: []error{err}, isTemp: isTempError(err), errcode: errorCode(err), enhancedStatusCode: enhancedStatusCode(err, escSupport), @@ -45,7 +51,7 @@ func (c *Client) Send(messages ...*Msg) (returnErr error) { }() for id, message := range messages { - if sendErr := c.sendSingleMsg(message); sendErr != nil { + if sendErr := c.sendSingleMsg(client, message); sendErr != nil { messages[id].sendError = sendErr errs = append(errs, sendErr) }