Compare commits

..

1 commit

Author SHA1 Message Date
Michael Fuchs
bf9d55cdd1
Merge 4b60557518 into 077c85bea0 2024-09-26 15:14:50 +00:00
6 changed files with 104 additions and 392 deletions

View file

@ -12,7 +12,6 @@ import (
"net" "net"
"os" "os"
"strings" "strings"
"sync"
"time" "time"
"github.com/wneessen/go-mail/log" "github.com/wneessen/go-mail/log"
@ -88,12 +87,12 @@ type DialContextFunc func(ctx context.Context, network, address string) (net.Con
// Client is the SMTP client struct // Client is the SMTP client struct
type Client struct { type Client struct {
// connection is the net.Conn that the smtp.Client is based on
connection net.Conn
// Timeout for the SMTP server connection // Timeout for the SMTP server connection
connTimeout time.Duration connTimeout time.Duration
// dialContextFunc is a custom DialContext function to dial target SMTP server
dialContextFunc DialContextFunc
// dsn indicates that we want to use DSN for the Client // dsn indicates that we want to use DSN for the Client
dsn bool dsn bool
@ -103,9 +102,11 @@ type Client struct {
// dsnrntype defines the DSNRcptNotifyOption in case DSN is enabled // dsnrntype defines the DSNRcptNotifyOption in case DSN is enabled
dsnrntype []string dsnrntype []string
// fallbackPort is used as an alternative port number in case the primary port is unavailable or // isEncrypted indicates if a Client connection is encrypted or not
// fails to bind. isEncrypted bool
fallbackPort int
// noNoop indicates the Noop is to be skipped
noNoop bool
// HELO/EHLO string for the greeting the target SMTP server // HELO/EHLO string for the greeting the target SMTP server
helo string helo string
@ -113,24 +114,12 @@ type Client struct {
// Hostname of the target SMTP server to connect to // Hostname of the target SMTP server to connect to
host string host string
// isEncrypted indicates if a Client connection is encrypted or not
isEncrypted bool
// logger is a logger that implements the log.Logger interface
logger log.Logger
// mutex is used to synchronize access to shared resources, ensuring that only one goroutine can
// modify them at a time.
mutex sync.RWMutex
// noNoop indicates the Noop is to be skipped
noNoop bool
// pass is the corresponding SMTP AUTH password // pass is the corresponding SMTP AUTH password
pass string pass string
// port specifies the network port number on which the server listens for incoming connections. // Port of the SMTP server to connect to
port int port int
fallbackPort int
// smtpAuth is a pointer to smtp.Auth // smtpAuth is a pointer to smtp.Auth
smtpAuth smtp.Auth smtpAuth smtp.Auth
@ -141,20 +130,26 @@ type Client struct {
// smtpClient is the smtp.Client that is set up when using the Dial*() methods // smtpClient is the smtp.Client that is set up when using the Dial*() methods
smtpClient *smtp.Client smtpClient *smtp.Client
// Use SSL for the connection
useSSL bool
// tlspolicy sets the client to use the provided TLSPolicy for the STARTTLS protocol // tlspolicy sets the client to use the provided TLSPolicy for the STARTTLS protocol
tlspolicy TLSPolicy tlspolicy TLSPolicy
// tlsconfig represents the tls.Config setting for the STARTTLS connection // tlsconfig represents the tls.Config setting for the STARTTLS connection
tlsconfig *tls.Config tlsconfig *tls.Config
// useDebugLog enables the debug logging on the SMTP client
useDebugLog bool
// user is the SMTP AUTH username // user is the SMTP AUTH username
user string user string
// Use SSL for the connection // useDebugLog enables the debug logging on the SMTP client
useSSL bool useDebugLog bool
// logger is a logger that implements the log.Logger interface
logger log.Logger
// dialContextFunc is a custom DialContext function to dial target SMTP server
dialContextFunc DialContextFunc
} }
// Option returns a function that can be used for grouping Client options // Option returns a function that can be used for grouping Client options
@ -555,9 +550,6 @@ func (c *Client) SetLogger(logger log.Logger) {
// SetTLSConfig overrides the current *tls.Config with the given *tls.Config value // SetTLSConfig overrides the current *tls.Config with the given *tls.Config value
func (c *Client) SetTLSConfig(tlsconfig *tls.Config) error { func (c *Client) SetTLSConfig(tlsconfig *tls.Config) error {
c.mutex.Lock()
defer c.mutex.Unlock()
if tlsconfig == nil { if tlsconfig == nil {
return ErrInvalidTLSConfig return ErrInvalidTLSConfig
} }
@ -597,9 +589,6 @@ func (c *Client) setDefaultHelo() error {
// DialWithContext establishes a connection to the SMTP server with a given context.Context // DialWithContext establishes a connection to the SMTP server with a given context.Context
func (c *Client) DialWithContext(dialCtx context.Context) error { func (c *Client) DialWithContext(dialCtx context.Context) error {
c.mutex.Lock()
defer c.mutex.Unlock()
ctx, cancel := context.WithDeadline(dialCtx, time.Now().Add(c.connTimeout)) ctx, cancel := context.WithDeadline(dialCtx, time.Now().Add(c.connTimeout))
defer cancel() defer cancel()
@ -613,16 +602,17 @@ func (c *Client) DialWithContext(dialCtx context.Context) error {
c.dialContextFunc = tlsDialer.DialContext c.dialContextFunc = tlsDialer.DialContext
} }
} }
connection, err := c.dialContextFunc(ctx, "tcp", c.ServerAddr()) var err error
c.connection, err = c.dialContextFunc(ctx, "tcp", c.ServerAddr())
if err != nil && c.fallbackPort != 0 { if err != nil && c.fallbackPort != 0 {
// TODO: should we somehow log or append the previous error? // TODO: should we somehow log or append the previous error?
connection, err = c.dialContextFunc(ctx, "tcp", c.serverFallbackAddr()) c.connection, err = c.dialContextFunc(ctx, "tcp", c.serverFallbackAddr())
} }
if err != nil { if err != nil {
return err return err
} }
client, err := smtp.NewClient(connection, c.host) client, err := smtp.NewClient(c.connection, c.host)
if err != nil { if err != nil {
return err return err
} }
@ -701,7 +691,7 @@ func (c *Client) DialAndSendWithContext(ctx context.Context, messages ...*Msg) e
// checkConn makes sure that a required server connection is available and extends the // checkConn makes sure that a required server connection is available and extends the
// connection deadline // connection deadline
func (c *Client) checkConn() error { func (c *Client) checkConn() error {
if !c.smtpClient.HasConnection() { if c.connection == nil {
return ErrNoActiveConnection return ErrNoActiveConnection
} }
@ -711,7 +701,7 @@ func (c *Client) checkConn() error {
} }
} }
if err := c.smtpClient.UpdateDeadline(c.connTimeout); err != nil { if err := c.connection.SetDeadline(time.Now().Add(c.connTimeout)); err != nil {
return ErrDeadlineExtendFailed return ErrDeadlineExtendFailed
} }
return nil return nil
@ -725,7 +715,7 @@ func (c *Client) serverFallbackAddr() string {
// tls tries to make sure that the STARTTLS requirements are satisfied // tls tries to make sure that the STARTTLS requirements are satisfied
func (c *Client) tls() error { func (c *Client) tls() error {
if !c.smtpClient.HasConnection() { if c.connection == nil {
return ErrNoActiveConnection return ErrNoActiveConnection
} }
if !c.useSSL && c.tlspolicy != NoTLS { if !c.useSSL && c.tlspolicy != NoTLS {
@ -801,9 +791,6 @@ func (c *Client) auth() error {
// sendSingleMsg sends out a single message and returns an error if the transmission/delivery fails. // sendSingleMsg sends out a single message and returns an error if the transmission/delivery fails.
// It is invoked by the public Send methods // It is invoked by the public Send methods
func (c *Client) sendSingleMsg(message *Msg) error { func (c *Client) sendSingleMsg(message *Msg) error {
c.mutex.Lock()
defer c.mutex.Unlock()
if message.encoding == NoEncoding { if message.encoding == NoEncoding {
if ok, _ := c.smtpClient.Extension("8BITMIME"); !ok { if ok, _ := c.smtpClient.Extension("8BITMIME"); !ok {
return &SendError{Reason: ErrNoUnencoded, isTemp: false, affectedMsg: message} return &SendError{Reason: ErrNoUnencoded, isTemp: false, affectedMsg: message}

View file

@ -15,7 +15,6 @@ import (
"os" "os"
"strconv" "strconv"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
@ -624,12 +623,11 @@ func TestClient_DialWithContext(t *testing.T) {
t.Errorf("failed to dial with context: %s", err) t.Errorf("failed to dial with context: %s", err)
return return
} }
if c.connection == nil {
t.Errorf("DialWithContext didn't fail but no connection found.")
}
if c.smtpClient == nil { if c.smtpClient == nil {
t.Errorf("DialWithContext didn't fail but no SMTP client found.") t.Errorf("DialWithContext didn't fail but no SMTP client found.")
return
}
if !c.smtpClient.HasConnection() {
t.Errorf("DialWithContext didn't fail but no connection found.")
} }
if err := c.Close(); err != nil { if err := c.Close(); err != nil {
t.Errorf("failed to close connection: %s", err) t.Errorf("failed to close connection: %s", err)
@ -646,18 +644,17 @@ func TestClient_DialWithContext_Fallback(t *testing.T) {
c.SetTLSPortPolicy(TLSOpportunistic) c.SetTLSPortPolicy(TLSOpportunistic)
c.port = 999 c.port = 999
ctx := context.Background() ctx := context.Background()
if err = c.DialWithContext(ctx); err != nil { if err := c.DialWithContext(ctx); err != nil {
t.Errorf("failed to dial with context: %s", err) t.Errorf("failed to dial with context: %s", err)
return return
} }
if c.smtpClient == nil { if c.connection == nil {
t.Errorf("DialWithContext didn't fail but no SMTP client found.")
return
}
if !c.smtpClient.HasConnection() {
t.Errorf("DialWithContext didn't fail but no connection found.") t.Errorf("DialWithContext didn't fail but no connection found.")
} }
if err = c.Close(); err != nil { if c.smtpClient == nil {
t.Errorf("DialWithContext didn't fail but no SMTP client found.")
}
if err := c.Close(); err != nil {
t.Errorf("failed to close connection: %s", err) t.Errorf("failed to close connection: %s", err)
} }
@ -677,19 +674,18 @@ func TestClient_DialWithContext_Debug(t *testing.T) {
t.Skipf("failed to create test client: %s. Skipping tests", err) t.Skipf("failed to create test client: %s. Skipping tests", err)
} }
ctx := context.Background() ctx := context.Background()
if err = c.DialWithContext(ctx); err != nil { if err := c.DialWithContext(ctx); err != nil {
t.Errorf("failed to dial with context: %s", err) t.Errorf("failed to dial with context: %s", err)
return return
} }
if c.smtpClient == nil { if c.connection == nil {
t.Errorf("DialWithContext didn't fail but no SMTP client found.")
return
}
if !c.smtpClient.HasConnection() {
t.Errorf("DialWithContext didn't fail but no connection found.") t.Errorf("DialWithContext didn't fail but no connection found.")
} }
if c.smtpClient == nil {
t.Errorf("DialWithContext didn't fail but no SMTP client found.")
}
c.SetDebugLog(true) c.SetDebugLog(true)
if err = c.Close(); err != nil { if err := c.Close(); err != nil {
t.Errorf("failed to close connection: %s", err) t.Errorf("failed to close connection: %s", err)
} }
} }
@ -702,20 +698,19 @@ func TestClient_DialWithContext_Debug_custom(t *testing.T) {
t.Skipf("failed to create test client: %s. Skipping tests", err) t.Skipf("failed to create test client: %s. Skipping tests", err)
} }
ctx := context.Background() ctx := context.Background()
if err = c.DialWithContext(ctx); err != nil { if err := c.DialWithContext(ctx); err != nil {
t.Errorf("failed to dial with context: %s", err) t.Errorf("failed to dial with context: %s", err)
return return
} }
if c.connection == nil {
t.Errorf("DialWithContext didn't fail but no connection found.")
}
if c.smtpClient == nil { if c.smtpClient == nil {
t.Errorf("DialWithContext didn't fail but no SMTP client found.") t.Errorf("DialWithContext didn't fail but no SMTP client found.")
return
}
if !c.smtpClient.HasConnection() {
t.Errorf("DialWithContext didn't fail but no connection found.")
} }
c.SetDebugLog(true) c.SetDebugLog(true)
c.SetLogger(log.New(os.Stderr, log.LevelDebug)) c.SetLogger(log.New(os.Stderr, log.LevelDebug))
if err = c.Close(); err != nil { if err := c.Close(); err != nil {
t.Errorf("failed to close connection: %s", err) t.Errorf("failed to close connection: %s", err)
} }
} }
@ -727,9 +722,10 @@ func TestClient_DialWithContextInvalidHost(t *testing.T) {
if err != nil { if err != nil {
t.Skipf("failed to create test client: %s. Skipping tests", err) t.Skipf("failed to create test client: %s. Skipping tests", err)
} }
c.connection = nil
c.host = "invalid.addr" c.host = "invalid.addr"
ctx := context.Background() ctx := context.Background()
if err = c.DialWithContext(ctx); err == nil { if err := c.DialWithContext(ctx); err == nil {
t.Errorf("dial succeeded but was supposed to fail") t.Errorf("dial succeeded but was supposed to fail")
return return
} }
@ -742,9 +738,10 @@ func TestClient_DialWithContextInvalidHELO(t *testing.T) {
if err != nil { if err != nil {
t.Skipf("failed to create test client: %s. Skipping tests", err) t.Skipf("failed to create test client: %s. Skipping tests", err)
} }
c.connection = nil
c.helo = "" c.helo = ""
ctx := context.Background() ctx := context.Background()
if err = c.DialWithContext(ctx); err == nil { if err := c.DialWithContext(ctx); err == nil {
t.Errorf("dial succeeded but was supposed to fail") t.Errorf("dial succeeded but was supposed to fail")
return return
} }
@ -761,7 +758,7 @@ func TestClient_DialWithContextInvalidAuth(t *testing.T) {
c.pass = "invalid" c.pass = "invalid"
c.SetSMTPAuthCustom(smtp.LoginAuth("invalid", "invalid", "invalid")) c.SetSMTPAuthCustom(smtp.LoginAuth("invalid", "invalid", "invalid"))
ctx := context.Background() ctx := context.Background()
if err = c.DialWithContext(ctx); err == nil { if err := c.DialWithContext(ctx); err == nil {
t.Errorf("dial succeeded but was supposed to fail") t.Errorf("dial succeeded but was supposed to fail")
return return
} }
@ -773,7 +770,8 @@ func TestClient_checkConn(t *testing.T) {
if err != nil { if err != nil {
t.Skipf("failed to create test client: %s. Skipping tests", err) t.Skipf("failed to create test client: %s. Skipping tests", err)
} }
if err = c.checkConn(); err == nil { c.connection = nil
if err := c.checkConn(); err == nil {
t.Errorf("connCheck() should fail but succeeded") t.Errorf("connCheck() should fail but succeeded")
} }
} }
@ -804,23 +802,21 @@ func TestClient_DialWithContextOptions(t *testing.T) {
} }
ctx := context.Background() ctx := context.Background()
if err = c.DialWithContext(ctx); err != nil && !tt.sf { if err := c.DialWithContext(ctx); err != nil && !tt.sf {
t.Errorf("failed to dial with context: %s", err) t.Errorf("failed to dial with context: %s", err)
return return
} }
if !tt.sf { if !tt.sf {
if c.connection == nil && !tt.sf {
t.Errorf("DialWithContext didn't fail but no connection found.")
}
if c.smtpClient == nil && !tt.sf { if c.smtpClient == nil && !tt.sf {
t.Errorf("DialWithContext didn't fail but no SMTP client found.") t.Errorf("DialWithContext didn't fail but no SMTP client found.")
return
} }
if !c.smtpClient.HasConnection() && !tt.sf { if err := c.Reset(); err != nil {
t.Errorf("DialWithContext didn't fail but no connection found.")
return
}
if err = c.Reset(); err != nil {
t.Errorf("failed to reset connection: %s", err) t.Errorf("failed to reset connection: %s", err)
} }
if err = c.Close(); err != nil { if err := c.Close(); err != nil {
t.Errorf("failed to close connection: %s", err) t.Errorf("failed to close connection: %s", err)
} }
} }
@ -1015,15 +1011,17 @@ func TestClient_DialSendCloseBroken(t *testing.T) {
} }
if tt.closestart { if tt.closestart {
_ = c.smtpClient.Close() _ = c.smtpClient.Close()
_ = c.connection.Close()
} }
if err = c.Send(m); err != nil && !tt.sf { if err := c.Send(m); err != nil && !tt.sf {
t.Errorf("Send() failed: %s", err) t.Errorf("Send() failed: %s", err)
return return
} }
if tt.closeearly { if tt.closeearly {
_ = c.smtpClient.Close() _ = c.smtpClient.Close()
_ = c.connection.Close()
} }
if err = c.Close(); err != nil && !tt.sf { if err := c.Close(); err != nil && !tt.sf {
t.Errorf("Close() failed: %s", err) t.Errorf("Close() failed: %s", err)
return return
} }
@ -1073,15 +1071,17 @@ func TestClient_DialSendCloseBrokenWithDSN(t *testing.T) {
} }
if tt.closestart { if tt.closestart {
_ = c.smtpClient.Close() _ = c.smtpClient.Close()
_ = c.connection.Close()
} }
if err = c.Send(m); err != nil && !tt.sf { if err := c.Send(m); err != nil && !tt.sf {
t.Errorf("Send() failed: %s", err) t.Errorf("Send() failed: %s", err)
return return
} }
if tt.closeearly { if tt.closeearly {
_ = c.smtpClient.Close() _ = c.smtpClient.Close()
_ = c.connection.Close()
} }
if err = c.Close(); err != nil && !tt.sf { if err := c.Close(); err != nil && !tt.sf {
t.Errorf("Close() failed: %s", err) t.Errorf("Close() failed: %s", err)
return return
} }
@ -1728,114 +1728,6 @@ func TestClient_SendErrorReset(t *testing.T) {
} }
} }
func TestClient_DialSendConcurrent_online(t *testing.T) {
if os.Getenv("TEST_ALLOW_SEND") == "" {
t.Skipf("TEST_ALLOW_SEND is not set. Skipping mail sending test")
}
client, err := getTestConnection(true)
if err != nil {
t.Skipf("failed to create test client: %s. Skipping tests", err)
}
var messages []*Msg
for i := 0; i < 10; i++ {
message := NewMsg()
if err := message.FromFormat("go-mail Test Mailer", os.Getenv("TEST_FROM")); err != nil {
t.Errorf("failed to set FROM address: %s", err)
return
}
if err := message.To(TestRcpt); err != nil {
t.Errorf("failed to set TO address: %s", err)
return
}
message.Subject(fmt.Sprintf("Test subject for mail %d", i))
message.SetBodyString(TypeTextPlain, fmt.Sprintf("This is the test body of the mail no. %d", i))
message.SetMessageID()
messages = append(messages, message)
}
if err = client.DialWithContext(context.Background()); err != nil {
t.Errorf("failed to dial to test server: %s", err)
}
wg := sync.WaitGroup{}
for id, message := range messages {
wg.Add(1)
go func(curMsg *Msg, curID int) {
defer wg.Done()
if goroutineErr := client.Send(curMsg); err != nil {
t.Errorf("failed to send message with ID %d: %s", curID, goroutineErr)
}
}(message, id)
}
wg.Wait()
if err = client.Close(); err != nil {
t.Errorf("failed to close server connection: %s", err)
}
}
func TestClient_DialSendConcurrent_local(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
serverPort := TestServerPortBase + 20
featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
go func() {
if err := simpleSMTPServer(ctx, featureSet, false, serverPort); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 500)
client, err := NewClient(TestServerAddr, WithPort(serverPort),
WithTLSPortPolicy(NoTLS), WithSMTPAuth(SMTPAuthPlain),
WithUsername("toni@tester.com"),
WithPassword("V3ryS3cr3t+"))
if err != nil {
t.Errorf("unable to create new client: %s", err)
}
var messages []*Msg
for i := 0; i < 20; i++ {
message := NewMsg()
if err := message.From("valid-from@domain.tld"); err != nil {
t.Errorf("failed to set FROM address: %s", err)
return
}
if err := message.To("valid-to@domain.tld"); err != nil {
t.Errorf("failed to set TO address: %s", err)
return
}
message.Subject("Test subject")
message.SetBodyString(TypeTextPlain, "Test body")
message.SetMessageIDWithValue("this.is.a.message.id")
messages = append(messages, message)
}
if err = client.DialWithContext(context.Background()); err != nil {
t.Errorf("failed to dial to test server: %s", err)
}
wg := sync.WaitGroup{}
for id, message := range messages {
wg.Add(1)
go func(curMsg *Msg, curID int) {
defer wg.Done()
if goroutineErr := client.Send(curMsg); err != nil {
t.Errorf("failed to send message with ID %d: %s", curID, goroutineErr)
}
}(message, id)
}
wg.Wait()
if err = client.Close(); err != nil {
t.Errorf("failed to close server connection: %s", err)
}
}
// getTestConnection takes environment variables to establish a connection to a real // getTestConnection takes environment variables to establish a connection to a real
// SMTP server to test all functionality that requires a connection // SMTP server to test all functionality that requires a connection
func getTestConnection(auth bool) (*Client, error) { func getTestConnection(auth bool) (*Client, error) {
@ -2021,72 +1913,6 @@ func getTestConnectionWithDSN(auth bool) (*Client, error) {
} }
func TestXOAuth2OK(t *testing.T) { func TestXOAuth2OK(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
serverPort := TestServerPortBase + 30
featureSet := "250-AUTH XOAUTH2\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
go func() {
if err := simpleSMTPServer(ctx, featureSet, false, serverPort); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 500)
c, err := NewClient("127.0.0.1",
WithPort(serverPort),
WithTLSPortPolicy(TLSOpportunistic),
WithSMTPAuth(SMTPAuthXOAUTH2),
WithUsername("user"),
WithPassword("token"))
if err != nil {
t.Fatalf("unable to create new client: %v", err)
}
if err = c.DialWithContext(context.Background()); err != nil {
t.Fatalf("unexpected dial error: %v", err)
}
if err = c.Close(); err != nil {
t.Fatalf("disconnect from test server failed: %v", err)
}
}
func TestXOAuth2Unsupported(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
serverPort := TestServerPortBase + 31
featureSet := "250-AUTH LOGIN PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
go func() {
if err := simpleSMTPServer(ctx, featureSet, false, serverPort); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 500)
c, err := NewClient("127.0.0.1",
WithPort(serverPort),
WithTLSPolicy(TLSOpportunistic),
WithSMTPAuth(SMTPAuthXOAUTH2),
WithUsername("user"),
WithPassword("token"))
if err != nil {
t.Fatalf("unable to create new client: %v", err)
}
if err = c.DialWithContext(context.Background()); err == nil {
t.Fatal("expected dial error got nil")
} else {
if !errors.Is(err, ErrXOauth2AuthNotSupported) {
t.Fatalf("expected %v; got %v", ErrXOauth2AuthNotSupported, err)
}
}
if err = c.Close(); err != nil {
t.Fatalf("disconnect from test server failed: %v", err)
}
}
func TestXOAuth2OK_faker(t *testing.T) {
server := []string{ server := []string{
"220 Fake server ready ESMTP", "220 Fake server ready ESMTP",
"250-fake.server", "250-fake.server",
@ -2126,7 +1952,7 @@ func TestXOAuth2OK_faker(t *testing.T) {
} }
} }
func TestXOAuth2Unsupported_faker(t *testing.T) { func TestXOAuth2Unsupported(t *testing.T) {
server := []string{ server := []string{
"220 Fake server ready ESMTP", "220 Fake server ready ESMTP",
"250-fake.server", "250-fake.server",
@ -2259,6 +2085,7 @@ func handleTestServerConnection(connection net.Conn, featureSet string, failRese
data, err := reader.ReadString('\n') data, err := reader.ReadString('\n')
if err != nil { if err != nil {
fmt.Printf("unable to read from connection: %s\n", err)
return return
} }
if !strings.HasPrefix(data, "EHLO") && !strings.HasPrefix(data, "HELO") { if !strings.HasPrefix(data, "EHLO") && !strings.HasPrefix(data, "HELO") {
@ -2266,15 +2093,19 @@ func handleTestServerConnection(connection net.Conn, featureSet string, failRese
return return
} }
if err = writeLine("250-localhost.localdomain\r\n" + featureSet); err != nil { if err = writeLine("250-localhost.localdomain\r\n" + featureSet); err != nil {
fmt.Printf("unable to write to connection: %s\n", err)
return return
} }
for { for {
data, err = reader.ReadString('\n') data, err = reader.ReadString('\n')
if err != nil { if err != nil {
if errors.Is(err, io.EOF) {
break
}
fmt.Println("Error reading data:", err)
break break
} }
time.Sleep(time.Millisecond)
var datastring string var datastring string
data = strings.TrimSpace(data) data = strings.TrimSpace(data)
@ -2297,13 +2128,6 @@ func handleTestServerConnection(connection net.Conn, featureSet string, failRese
break break
} }
writeOK() writeOK()
case strings.HasPrefix(data, "AUTH XOAUTH2"):
auth := strings.TrimPrefix(data, "AUTH XOAUTH2 ")
if !strings.EqualFold(auth, "dXNlcj11c2VyAWF1dGg9QmVhcmVyIHRva2VuAQE=") {
_ = writeLine("535 5.7.8 Error: authentication failed")
break
}
_ = writeLine("235 2.7.0 Authentication successful")
case strings.HasPrefix(data, "AUTH PLAIN"): case strings.HasPrefix(data, "AUTH PLAIN"):
auth := strings.TrimPrefix(data, "AUTH PLAIN ") auth := strings.TrimPrefix(data, "AUTH PLAIN ")
if !strings.EqualFold(auth, "AHRvbmlAdGVzdGVyLmNvbQBWM3J5UzNjcjN0Kw==") { if !strings.EqualFold(auth, "AHRvbmlAdGVzdGVyLmNvbQBWM3J5UzNjcjN0Kw==") {

View file

@ -2,8 +2,8 @@
// //
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
//go:build !go1.20 //go:build go1.19 && !go1.20
// +build !go1.20 // +build go1.19,!go1.20
package mail package mail

View file

@ -30,57 +30,34 @@ import (
"net/textproto" "net/textproto"
"os" "os"
"strings" "strings"
"sync"
"time"
"github.com/wneessen/go-mail/log" "github.com/wneessen/go-mail/log"
) )
// A Client represents a client connection to an SMTP server. // A Client represents a client connection to an SMTP server.
type Client struct { type Client struct {
// Text is the textproto.Conn used by the Client. It is exported to allow for clients to add extensions. // Text is the textproto.Conn used by the Client. It is exported to allow for
// clients to add extensions.
Text *textproto.Conn Text *textproto.Conn
// keep a reference to the connection so it can be used to create a TLS
// auth supported auth mechanisms // connection later
auth []string
// keep a reference to the connection so it can be used to create a TLS connection later
conn net.Conn conn net.Conn
// whether the Client is using TLS
// debug logging is enabled
debug bool
// didHello indicates whether we've said HELO/EHLO
didHello bool
// dsnmrtype defines the mail return option in case DSN is enabled
dsnmrtype string
// dsnrntype defines the recipient notify option in case DSN is enabled
dsnrntype string
// ext is a map of supported extensions
ext map[string]string
// helloError is the error from the hello
helloError error
// localName is the name to use in HELO/EHLO
localName string // the name to use in HELO/EHLO
// logger will be used for debug logging
logger log.Logger
// mutex is used to synchronize access to shared resources, ensuring that only one goroutine can access
// the resource at a time.
mutex sync.RWMutex
// tls indicates whether the Client is using TLS
tls bool tls bool
// serverName denotes the name of the server to which the application will connect. Used for
// identification and routing.
serverName string serverName string
// map of supported extensions
ext map[string]string
// supported auth mechanisms
auth []string
localName string // the name to use in HELO/EHLO
didHello bool // whether we've said HELO/EHLO
helloError error // the error from the hello
// debug logging
debug bool // debug logging is enabled
logger log.Logger // logger will be used for debug logging
// DSN support
dsnmrtype string // dsnmrtype defines the mail return option in case DSN is enabled
dsnrntype string // dsnrntype defines the recipient notify option in case DSN is enabled
} }
// Dial returns a new [Client] connected to an SMTP server at addr. // Dial returns a new [Client] connected to an SMTP server at addr.
@ -117,10 +94,7 @@ func NewClient(conn net.Conn, host string) (*Client, error) {
// Close closes the connection. // Close closes the connection.
func (c *Client) Close() error { func (c *Client) Close() error {
c.mutex.Lock() return c.Text.Close()
err := c.Text.Close()
c.mutex.Unlock()
return err
} }
// hello runs a hello exchange if needed. // hello runs a hello exchange if needed.
@ -147,39 +121,28 @@ func (c *Client) Hello(localName string) error {
if c.didHello { if c.didHello {
return errors.New("smtp: Hello called after other methods") return errors.New("smtp: Hello called after other methods")
} }
c.mutex.Lock()
c.localName = localName c.localName = localName
c.mutex.Unlock()
return c.hello() return c.hello()
} }
// cmd is a convenience function that sends a command and returns the response // cmd is a convenience function that sends a command and returns the response
func (c *Client) cmd(expectCode int, format string, args ...interface{}) (int, string, error) { func (c *Client) cmd(expectCode int, format string, args ...interface{}) (int, string, error) {
c.mutex.Lock()
c.debugLog(log.DirClientToServer, format, args...) c.debugLog(log.DirClientToServer, format, args...)
id, err := c.Text.Cmd(format, args...) id, err := c.Text.Cmd(format, args...)
if err != nil { if err != nil {
c.mutex.Unlock()
return 0, "", err return 0, "", err
} }
c.Text.StartResponse(id) c.Text.StartResponse(id)
defer c.Text.EndResponse(id)
code, msg, err := c.Text.ReadResponse(expectCode) code, msg, err := c.Text.ReadResponse(expectCode)
c.debugLog(log.DirServerToClient, "%d %s", code, msg) c.debugLog(log.DirServerToClient, "%d %s", code, msg)
c.Text.EndResponse(id)
c.mutex.Unlock()
return code, msg, err return code, msg, err
} }
// helo sends the HELO greeting to the server. It should be used only when the // helo sends the HELO greeting to the server. It should be used only when the
// server does not support ehlo. // server does not support ehlo.
func (c *Client) helo() error { func (c *Client) helo() error {
c.mutex.Lock()
c.ext = nil c.ext = nil
c.mutex.Unlock()
_, _, err := c.cmd(250, "HELO %s", c.localName) _, _, err := c.cmd(250, "HELO %s", c.localName)
return err return err
} }
@ -194,13 +157,9 @@ func (c *Client) StartTLS(config *tls.Config) error {
if err != nil { if err != nil {
return err return err
} }
c.mutex.Lock()
c.conn = tls.Client(c.conn, config) c.conn = tls.Client(c.conn, config)
c.Text = textproto.NewConn(c.conn) c.Text = textproto.NewConn(c.conn)
c.tls = true c.tls = true
c.mutex.Unlock()
return c.ehlo() return c.ehlo()
} }
@ -208,15 +167,11 @@ func (c *Client) StartTLS(config *tls.Config) error {
// The return values are their zero values if [Client.StartTLS] did // The return values are their zero values if [Client.StartTLS] did
// not succeed. // not succeed.
func (c *Client) TLSConnectionState() (state tls.ConnectionState, ok bool) { func (c *Client) TLSConnectionState() (state tls.ConnectionState, ok bool) {
c.mutex.RLock()
defer c.mutex.RUnlock()
tc, ok := c.conn.(*tls.Conn) tc, ok := c.conn.(*tls.Conn)
if !ok { if !ok {
return return
} }
state, ok = tc.ConnectionState(), true return tc.ConnectionState(), true
return
} }
// Verify checks the validity of an email address on the server. // Verify checks the validity of an email address on the server.
@ -302,8 +257,6 @@ func (c *Client) Mail(from string) error {
return err return err
} }
cmdStr := "MAIL FROM:<%s>" cmdStr := "MAIL FROM:<%s>"
c.mutex.RLock()
if c.ext != nil { if c.ext != nil {
if _, ok := c.ext["8BITMIME"]; ok { if _, ok := c.ext["8BITMIME"]; ok {
cmdStr += " BODY=8BITMIME" cmdStr += " BODY=8BITMIME"
@ -316,8 +269,6 @@ func (c *Client) Mail(from string) error {
cmdStr += fmt.Sprintf(" RET=%s", c.dsnmrtype) cmdStr += fmt.Sprintf(" RET=%s", c.dsnmrtype)
} }
} }
c.mutex.RUnlock()
_, _, err := c.cmd(250, cmdStr, from) _, _, err := c.cmd(250, cmdStr, from)
return err return err
} }
@ -329,11 +280,7 @@ func (c *Client) Rcpt(to string) error {
if err := validateLine(to); err != nil { if err := validateLine(to); err != nil {
return err return err
} }
c.mutex.RLock()
_, ok := c.ext["DSN"] _, ok := c.ext["DSN"]
c.mutex.RUnlock()
if ok && c.dsnrntype != "" { if ok && c.dsnrntype != "" {
_, _, err := c.cmd(25, "RCPT TO:<%s> NOTIFY=%s", to, c.dsnrntype) _, _, err := c.cmd(25, "RCPT TO:<%s> NOTIFY=%s", to, c.dsnrntype)
return err return err
@ -347,23 +294,12 @@ type dataCloser struct {
io.WriteCloser io.WriteCloser
} }
// Close releases the lock, closes the WriteCloser, waits for a response, and then returns any error encountered.
func (d *dataCloser) Close() error { func (d *dataCloser) Close() error {
d.c.mutex.Lock()
_ = d.WriteCloser.Close() _ = d.WriteCloser.Close()
_, _, err := d.c.Text.ReadResponse(250) _, _, err := d.c.Text.ReadResponse(250)
d.c.mutex.Unlock()
return err return err
} }
// Write writes data to the underlying WriteCloser while ensuring thread-safety by locking and unlocking a mutex.
func (d *dataCloser) Write(p []byte) (n int, err error) {
d.c.mutex.Lock()
n, err = d.WriteCloser.Write(p)
d.c.mutex.Unlock()
return
}
// Data issues a DATA command to the server and returns a writer that // Data issues a DATA command to the server and returns a writer that
// can be used to write the mail headers and body. The caller should // can be used to write the mail headers and body. The caller should
// close the writer before calling any more methods on c. A call to // close the writer before calling any more methods on c. A call to
@ -373,14 +309,7 @@ func (c *Client) Data() (io.WriteCloser, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
datacloser := &dataCloser{} return &dataCloser{c, c.Text.DotWriter()}, nil
c.mutex.Lock()
datacloser.c = c
datacloser.WriteCloser = c.Text.DotWriter()
c.mutex.Unlock()
return datacloser, nil
} }
var testHookStartTLS func(*tls.Config) // nil, except for tests var testHookStartTLS func(*tls.Config) // nil, except for tests
@ -476,10 +405,7 @@ func (c *Client) Extension(ext string) (bool, string) {
return false, "" return false, ""
} }
ext = strings.ToUpper(ext) ext = strings.ToUpper(ext)
c.mutex.RLock()
param, ok := c.ext[ext] param, ok := c.ext[ext]
c.mutex.RUnlock()
return ok, param return ok, param
} }
@ -512,11 +438,7 @@ func (c *Client) Quit() error {
if err != nil { if err != nil {
return err return err
} }
c.mutex.Lock() return c.Text.Close()
err = c.Text.Close()
c.mutex.Unlock()
return err
} }
// SetDebugLog enables the debug logging for incoming and outgoing SMTP messages // SetDebugLog enables the debug logging for incoming and outgoing SMTP messages
@ -550,21 +472,6 @@ func (c *Client) SetDSNRcptNotifyOption(d string) {
c.dsnrntype = d c.dsnrntype = d
} }
// HasConnection checks if the client has an active connection.
// Returns true if the `conn` field is not nil, indicating an active connection.
func (c *Client) HasConnection() bool {
return c.conn != nil
}
func (c *Client) UpdateDeadline(timeout time.Duration) error {
c.mutex.Lock()
if err := c.conn.SetDeadline(time.Now().Add(timeout)); err != nil {
return fmt.Errorf("smtp: failed to update deadline: %w", err)
}
c.mutex.Unlock()
return nil
}
// debugLog checks if the debug flag is set and if so logs the provided message to // debugLog checks if the debug flag is set and if so logs the provided message to
// the log.Logger interface // the log.Logger interface
func (c *Client) debugLog(d log.Direction, f string, a ...interface{}) { func (c *Client) debugLog(d log.Direction, f string, a ...interface{}) {

View file

@ -25,9 +25,6 @@ func (c *Client) ehlo() error {
if err != nil { if err != nil {
return err return err
} }
c.mutex.Lock()
defer c.mutex.Unlock()
ext := make(map[string]string) ext := make(map[string]string)
extList := strings.Split(msg, "\n") extList := strings.Split(msg, "\n")
if len(extList) > 1 { if len(extList) > 1 {

View file

@ -22,15 +22,12 @@ import "strings"
// should be the preferred greeting for servers that support it. // should be the preferred greeting for servers that support it.
// //
// Backport of: https://github.com/golang/go/commit/4d8db00641cc9ff4f44de7df9b8c4f4a4f9416ee#diff-4f6f6bdb9891d4dd271f9f31430420a2e44018fe4ee539576faf458bebb3cee4 // Backport of: https://github.com/golang/go/commit/4d8db00641cc9ff4f44de7df9b8c4f4a4f9416ee#diff-4f6f6bdb9891d4dd271f9f31430420a2e44018fe4ee539576faf458bebb3cee4
// to guarantee backwards compatibility with Go 1.16/1.17 // to guarantee backwards compatibility with Go 1.16/1.17:w
func (c *Client) ehlo() error { func (c *Client) ehlo() error {
_, msg, err := c.cmd(250, "EHLO %s", c.localName) _, msg, err := c.cmd(250, "EHLO %s", c.localName)
if err != nil { if err != nil {
return err return err
} }
c.mutex.Lock()
defer c.mutex.Unlock()
ext := make(map[string]string) ext := make(map[string]string)
extList := strings.Split(msg, "\n") extList := strings.Split(msg, "\n")
if len(extList) > 1 { if len(extList) > 1 {