mirror of
https://github.com/wneessen/go-mail.git
synced 2024-11-15 02:12:55 +01:00
Merge pull request #307 from wneessen/feature/269_goroutineconcurrency-safety
go-mail goroutine-/thread-safety
This commit is contained in:
commit
65a91a2711
6 changed files with 389 additions and 101 deletions
69
client.go
69
client.go
|
@ -12,6 +12,7 @@ import (
|
|||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/wneessen/go-mail/log"
|
||||
|
@ -87,12 +88,12 @@ type DialContextFunc func(ctx context.Context, network, address string) (net.Con
|
|||
|
||||
// Client is the SMTP client struct
|
||||
type Client struct {
|
||||
// connection is the net.Conn that the smtp.Client is based on
|
||||
connection net.Conn
|
||||
|
||||
// Timeout for the SMTP server connection
|
||||
connTimeout time.Duration
|
||||
|
||||
// dialContextFunc is a custom DialContext function to dial target SMTP server
|
||||
dialContextFunc DialContextFunc
|
||||
|
||||
// dsn indicates that we want to use DSN for the Client
|
||||
dsn bool
|
||||
|
||||
|
@ -102,11 +103,9 @@ type Client struct {
|
|||
// dsnrntype defines the DSNRcptNotifyOption in case DSN is enabled
|
||||
dsnrntype []string
|
||||
|
||||
// isEncrypted indicates if a Client connection is encrypted or not
|
||||
isEncrypted bool
|
||||
|
||||
// noNoop indicates the Noop is to be skipped
|
||||
noNoop bool
|
||||
// fallbackPort is used as an alternative port number in case the primary port is unavailable or
|
||||
// fails to bind.
|
||||
fallbackPort int
|
||||
|
||||
// HELO/EHLO string for the greeting the target SMTP server
|
||||
helo string
|
||||
|
@ -114,12 +113,24 @@ type Client struct {
|
|||
// Hostname of the target SMTP server to connect to
|
||||
host string
|
||||
|
||||
// isEncrypted indicates if a Client connection is encrypted or not
|
||||
isEncrypted bool
|
||||
|
||||
// logger is a logger that implements the log.Logger interface
|
||||
logger log.Logger
|
||||
|
||||
// mutex is used to synchronize access to shared resources, ensuring that only one goroutine can
|
||||
// modify them at a time.
|
||||
mutex sync.RWMutex
|
||||
|
||||
// noNoop indicates the Noop is to be skipped
|
||||
noNoop bool
|
||||
|
||||
// pass is the corresponding SMTP AUTH password
|
||||
pass string
|
||||
|
||||
// Port of the SMTP server to connect to
|
||||
port int
|
||||
fallbackPort int
|
||||
// port specifies the network port number on which the server listens for incoming connections.
|
||||
port int
|
||||
|
||||
// smtpAuth is a pointer to smtp.Auth
|
||||
smtpAuth smtp.Auth
|
||||
|
@ -130,26 +141,20 @@ type Client struct {
|
|||
// smtpClient is the smtp.Client that is set up when using the Dial*() methods
|
||||
smtpClient *smtp.Client
|
||||
|
||||
// Use SSL for the connection
|
||||
useSSL bool
|
||||
|
||||
// tlspolicy sets the client to use the provided TLSPolicy for the STARTTLS protocol
|
||||
tlspolicy TLSPolicy
|
||||
|
||||
// tlsconfig represents the tls.Config setting for the STARTTLS connection
|
||||
tlsconfig *tls.Config
|
||||
|
||||
// user is the SMTP AUTH username
|
||||
user string
|
||||
|
||||
// useDebugLog enables the debug logging on the SMTP client
|
||||
useDebugLog bool
|
||||
|
||||
// logger is a logger that implements the log.Logger interface
|
||||
logger log.Logger
|
||||
// user is the SMTP AUTH username
|
||||
user string
|
||||
|
||||
// dialContextFunc is a custom DialContext function to dial target SMTP server
|
||||
dialContextFunc DialContextFunc
|
||||
// Use SSL for the connection
|
||||
useSSL bool
|
||||
}
|
||||
|
||||
// Option returns a function that can be used for grouping Client options
|
||||
|
@ -550,6 +555,9 @@ func (c *Client) SetLogger(logger log.Logger) {
|
|||
|
||||
// SetTLSConfig overrides the current *tls.Config with the given *tls.Config value
|
||||
func (c *Client) SetTLSConfig(tlsconfig *tls.Config) error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
if tlsconfig == nil {
|
||||
return ErrInvalidTLSConfig
|
||||
}
|
||||
|
@ -589,6 +597,9 @@ func (c *Client) setDefaultHelo() error {
|
|||
|
||||
// DialWithContext establishes a connection to the SMTP server with a given context.Context
|
||||
func (c *Client) DialWithContext(dialCtx context.Context) error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
ctx, cancel := context.WithDeadline(dialCtx, time.Now().Add(c.connTimeout))
|
||||
defer cancel()
|
||||
|
||||
|
@ -602,17 +613,16 @@ func (c *Client) DialWithContext(dialCtx context.Context) error {
|
|||
c.dialContextFunc = tlsDialer.DialContext
|
||||
}
|
||||
}
|
||||
var err error
|
||||
c.connection, err = c.dialContextFunc(ctx, "tcp", c.ServerAddr())
|
||||
connection, err := c.dialContextFunc(ctx, "tcp", c.ServerAddr())
|
||||
if err != nil && c.fallbackPort != 0 {
|
||||
// TODO: should we somehow log or append the previous error?
|
||||
c.connection, err = c.dialContextFunc(ctx, "tcp", c.serverFallbackAddr())
|
||||
connection, err = c.dialContextFunc(ctx, "tcp", c.serverFallbackAddr())
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client, err := smtp.NewClient(c.connection, c.host)
|
||||
client, err := smtp.NewClient(connection, c.host)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -691,7 +701,7 @@ func (c *Client) DialAndSendWithContext(ctx context.Context, messages ...*Msg) e
|
|||
// checkConn makes sure that a required server connection is available and extends the
|
||||
// connection deadline
|
||||
func (c *Client) checkConn() error {
|
||||
if c.connection == nil {
|
||||
if !c.smtpClient.HasConnection() {
|
||||
return ErrNoActiveConnection
|
||||
}
|
||||
|
||||
|
@ -701,7 +711,7 @@ func (c *Client) checkConn() error {
|
|||
}
|
||||
}
|
||||
|
||||
if err := c.connection.SetDeadline(time.Now().Add(c.connTimeout)); err != nil {
|
||||
if err := c.smtpClient.UpdateDeadline(c.connTimeout); err != nil {
|
||||
return ErrDeadlineExtendFailed
|
||||
}
|
||||
return nil
|
||||
|
@ -715,7 +725,7 @@ func (c *Client) serverFallbackAddr() string {
|
|||
|
||||
// tls tries to make sure that the STARTTLS requirements are satisfied
|
||||
func (c *Client) tls() error {
|
||||
if c.connection == nil {
|
||||
if !c.smtpClient.HasConnection() {
|
||||
return ErrNoActiveConnection
|
||||
}
|
||||
if !c.useSSL && c.tlspolicy != NoTLS {
|
||||
|
@ -791,6 +801,9 @@ func (c *Client) auth() error {
|
|||
// sendSingleMsg sends out a single message and returns an error if the transmission/delivery fails.
|
||||
// It is invoked by the public Send methods
|
||||
func (c *Client) sendSingleMsg(message *Msg) error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
if message.encoding == NoEncoding {
|
||||
if ok, _ := c.smtpClient.Extension("8BITMIME"); !ok {
|
||||
return &SendError{Reason: ErrNoUnencoded, isTemp: false, affectedMsg: message}
|
||||
|
|
268
client_test.go
268
client_test.go
|
@ -15,6 +15,7 @@ import (
|
|||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -623,11 +624,12 @@ func TestClient_DialWithContext(t *testing.T) {
|
|||
t.Errorf("failed to dial with context: %s", err)
|
||||
return
|
||||
}
|
||||
if c.connection == nil {
|
||||
t.Errorf("DialWithContext didn't fail but no connection found.")
|
||||
}
|
||||
if c.smtpClient == nil {
|
||||
t.Errorf("DialWithContext didn't fail but no SMTP client found.")
|
||||
return
|
||||
}
|
||||
if !c.smtpClient.HasConnection() {
|
||||
t.Errorf("DialWithContext didn't fail but no connection found.")
|
||||
}
|
||||
if err := c.Close(); err != nil {
|
||||
t.Errorf("failed to close connection: %s", err)
|
||||
|
@ -644,17 +646,18 @@ func TestClient_DialWithContext_Fallback(t *testing.T) {
|
|||
c.SetTLSPortPolicy(TLSOpportunistic)
|
||||
c.port = 999
|
||||
ctx := context.Background()
|
||||
if err := c.DialWithContext(ctx); err != nil {
|
||||
if err = c.DialWithContext(ctx); err != nil {
|
||||
t.Errorf("failed to dial with context: %s", err)
|
||||
return
|
||||
}
|
||||
if c.connection == nil {
|
||||
t.Errorf("DialWithContext didn't fail but no connection found.")
|
||||
}
|
||||
if c.smtpClient == nil {
|
||||
t.Errorf("DialWithContext didn't fail but no SMTP client found.")
|
||||
return
|
||||
}
|
||||
if err := c.Close(); err != nil {
|
||||
if !c.smtpClient.HasConnection() {
|
||||
t.Errorf("DialWithContext didn't fail but no connection found.")
|
||||
}
|
||||
if err = c.Close(); err != nil {
|
||||
t.Errorf("failed to close connection: %s", err)
|
||||
}
|
||||
|
||||
|
@ -674,18 +677,19 @@ func TestClient_DialWithContext_Debug(t *testing.T) {
|
|||
t.Skipf("failed to create test client: %s. Skipping tests", err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
if err := c.DialWithContext(ctx); err != nil {
|
||||
if err = c.DialWithContext(ctx); err != nil {
|
||||
t.Errorf("failed to dial with context: %s", err)
|
||||
return
|
||||
}
|
||||
if c.connection == nil {
|
||||
t.Errorf("DialWithContext didn't fail but no connection found.")
|
||||
}
|
||||
if c.smtpClient == nil {
|
||||
t.Errorf("DialWithContext didn't fail but no SMTP client found.")
|
||||
return
|
||||
}
|
||||
if !c.smtpClient.HasConnection() {
|
||||
t.Errorf("DialWithContext didn't fail but no connection found.")
|
||||
}
|
||||
c.SetDebugLog(true)
|
||||
if err := c.Close(); err != nil {
|
||||
if err = c.Close(); err != nil {
|
||||
t.Errorf("failed to close connection: %s", err)
|
||||
}
|
||||
}
|
||||
|
@ -698,19 +702,20 @@ func TestClient_DialWithContext_Debug_custom(t *testing.T) {
|
|||
t.Skipf("failed to create test client: %s. Skipping tests", err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
if err := c.DialWithContext(ctx); err != nil {
|
||||
if err = c.DialWithContext(ctx); err != nil {
|
||||
t.Errorf("failed to dial with context: %s", err)
|
||||
return
|
||||
}
|
||||
if c.connection == nil {
|
||||
t.Errorf("DialWithContext didn't fail but no connection found.")
|
||||
}
|
||||
if c.smtpClient == nil {
|
||||
t.Errorf("DialWithContext didn't fail but no SMTP client found.")
|
||||
return
|
||||
}
|
||||
if !c.smtpClient.HasConnection() {
|
||||
t.Errorf("DialWithContext didn't fail but no connection found.")
|
||||
}
|
||||
c.SetDebugLog(true)
|
||||
c.SetLogger(log.New(os.Stderr, log.LevelDebug))
|
||||
if err := c.Close(); err != nil {
|
||||
if err = c.Close(); err != nil {
|
||||
t.Errorf("failed to close connection: %s", err)
|
||||
}
|
||||
}
|
||||
|
@ -722,10 +727,9 @@ func TestClient_DialWithContextInvalidHost(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Skipf("failed to create test client: %s. Skipping tests", err)
|
||||
}
|
||||
c.connection = nil
|
||||
c.host = "invalid.addr"
|
||||
ctx := context.Background()
|
||||
if err := c.DialWithContext(ctx); err == nil {
|
||||
if err = c.DialWithContext(ctx); err == nil {
|
||||
t.Errorf("dial succeeded but was supposed to fail")
|
||||
return
|
||||
}
|
||||
|
@ -738,10 +742,9 @@ func TestClient_DialWithContextInvalidHELO(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Skipf("failed to create test client: %s. Skipping tests", err)
|
||||
}
|
||||
c.connection = nil
|
||||
c.helo = ""
|
||||
ctx := context.Background()
|
||||
if err := c.DialWithContext(ctx); err == nil {
|
||||
if err = c.DialWithContext(ctx); err == nil {
|
||||
t.Errorf("dial succeeded but was supposed to fail")
|
||||
return
|
||||
}
|
||||
|
@ -758,7 +761,7 @@ func TestClient_DialWithContextInvalidAuth(t *testing.T) {
|
|||
c.pass = "invalid"
|
||||
c.SetSMTPAuthCustom(smtp.LoginAuth("invalid", "invalid", "invalid"))
|
||||
ctx := context.Background()
|
||||
if err := c.DialWithContext(ctx); err == nil {
|
||||
if err = c.DialWithContext(ctx); err == nil {
|
||||
t.Errorf("dial succeeded but was supposed to fail")
|
||||
return
|
||||
}
|
||||
|
@ -770,8 +773,7 @@ func TestClient_checkConn(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Skipf("failed to create test client: %s. Skipping tests", err)
|
||||
}
|
||||
c.connection = nil
|
||||
if err := c.checkConn(); err == nil {
|
||||
if err = c.checkConn(); err == nil {
|
||||
t.Errorf("connCheck() should fail but succeeded")
|
||||
}
|
||||
}
|
||||
|
@ -802,21 +804,23 @@ func TestClient_DialWithContextOptions(t *testing.T) {
|
|||
}
|
||||
|
||||
ctx := context.Background()
|
||||
if err := c.DialWithContext(ctx); err != nil && !tt.sf {
|
||||
if err = c.DialWithContext(ctx); err != nil && !tt.sf {
|
||||
t.Errorf("failed to dial with context: %s", err)
|
||||
return
|
||||
}
|
||||
if !tt.sf {
|
||||
if c.connection == nil && !tt.sf {
|
||||
t.Errorf("DialWithContext didn't fail but no connection found.")
|
||||
}
|
||||
if c.smtpClient == nil && !tt.sf {
|
||||
t.Errorf("DialWithContext didn't fail but no SMTP client found.")
|
||||
return
|
||||
}
|
||||
if err := c.Reset(); err != nil {
|
||||
if !c.smtpClient.HasConnection() && !tt.sf {
|
||||
t.Errorf("DialWithContext didn't fail but no connection found.")
|
||||
return
|
||||
}
|
||||
if err = c.Reset(); err != nil {
|
||||
t.Errorf("failed to reset connection: %s", err)
|
||||
}
|
||||
if err := c.Close(); err != nil {
|
||||
if err = c.Close(); err != nil {
|
||||
t.Errorf("failed to close connection: %s", err)
|
||||
}
|
||||
}
|
||||
|
@ -1011,17 +1015,15 @@ func TestClient_DialSendCloseBroken(t *testing.T) {
|
|||
}
|
||||
if tt.closestart {
|
||||
_ = c.smtpClient.Close()
|
||||
_ = c.connection.Close()
|
||||
}
|
||||
if err := c.Send(m); err != nil && !tt.sf {
|
||||
if err = c.Send(m); err != nil && !tt.sf {
|
||||
t.Errorf("Send() failed: %s", err)
|
||||
return
|
||||
}
|
||||
if tt.closeearly {
|
||||
_ = c.smtpClient.Close()
|
||||
_ = c.connection.Close()
|
||||
}
|
||||
if err := c.Close(); err != nil && !tt.sf {
|
||||
if err = c.Close(); err != nil && !tt.sf {
|
||||
t.Errorf("Close() failed: %s", err)
|
||||
return
|
||||
}
|
||||
|
@ -1071,17 +1073,15 @@ func TestClient_DialSendCloseBrokenWithDSN(t *testing.T) {
|
|||
}
|
||||
if tt.closestart {
|
||||
_ = c.smtpClient.Close()
|
||||
_ = c.connection.Close()
|
||||
}
|
||||
if err := c.Send(m); err != nil && !tt.sf {
|
||||
if err = c.Send(m); err != nil && !tt.sf {
|
||||
t.Errorf("Send() failed: %s", err)
|
||||
return
|
||||
}
|
||||
if tt.closeearly {
|
||||
_ = c.smtpClient.Close()
|
||||
_ = c.connection.Close()
|
||||
}
|
||||
if err := c.Close(); err != nil && !tt.sf {
|
||||
if err = c.Close(); err != nil && !tt.sf {
|
||||
t.Errorf("Close() failed: %s", err)
|
||||
return
|
||||
}
|
||||
|
@ -1728,6 +1728,114 @@ func TestClient_SendErrorReset(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestClient_DialSendConcurrent_online(t *testing.T) {
|
||||
if os.Getenv("TEST_ALLOW_SEND") == "" {
|
||||
t.Skipf("TEST_ALLOW_SEND is not set. Skipping mail sending test")
|
||||
}
|
||||
|
||||
client, err := getTestConnection(true)
|
||||
if err != nil {
|
||||
t.Skipf("failed to create test client: %s. Skipping tests", err)
|
||||
}
|
||||
|
||||
var messages []*Msg
|
||||
for i := 0; i < 10; i++ {
|
||||
message := NewMsg()
|
||||
if err := message.FromFormat("go-mail Test Mailer", os.Getenv("TEST_FROM")); err != nil {
|
||||
t.Errorf("failed to set FROM address: %s", err)
|
||||
return
|
||||
}
|
||||
if err := message.To(TestRcpt); err != nil {
|
||||
t.Errorf("failed to set TO address: %s", err)
|
||||
return
|
||||
}
|
||||
message.Subject(fmt.Sprintf("Test subject for mail %d", i))
|
||||
message.SetBodyString(TypeTextPlain, fmt.Sprintf("This is the test body of the mail no. %d", i))
|
||||
message.SetMessageID()
|
||||
messages = append(messages, message)
|
||||
}
|
||||
|
||||
if err = client.DialWithContext(context.Background()); err != nil {
|
||||
t.Errorf("failed to dial to test server: %s", err)
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
for id, message := range messages {
|
||||
wg.Add(1)
|
||||
go func(curMsg *Msg, curID int) {
|
||||
defer wg.Done()
|
||||
if goroutineErr := client.Send(curMsg); err != nil {
|
||||
t.Errorf("failed to send message with ID %d: %s", curID, goroutineErr)
|
||||
}
|
||||
}(message, id)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if err = client.Close(); err != nil {
|
||||
t.Errorf("failed to close server connection: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_DialSendConcurrent_local(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
serverPort := TestServerPortBase + 20
|
||||
featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
|
||||
go func() {
|
||||
if err := simpleSMTPServer(ctx, featureSet, false, serverPort); err != nil {
|
||||
t.Errorf("failed to start test server: %s", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
time.Sleep(time.Millisecond * 500)
|
||||
|
||||
client, err := NewClient(TestServerAddr, WithPort(serverPort),
|
||||
WithTLSPortPolicy(NoTLS), WithSMTPAuth(SMTPAuthPlain),
|
||||
WithUsername("toni@tester.com"),
|
||||
WithPassword("V3ryS3cr3t+"))
|
||||
if err != nil {
|
||||
t.Errorf("unable to create new client: %s", err)
|
||||
}
|
||||
|
||||
var messages []*Msg
|
||||
for i := 0; i < 20; i++ {
|
||||
message := NewMsg()
|
||||
if err := message.From("valid-from@domain.tld"); err != nil {
|
||||
t.Errorf("failed to set FROM address: %s", err)
|
||||
return
|
||||
}
|
||||
if err := message.To("valid-to@domain.tld"); err != nil {
|
||||
t.Errorf("failed to set TO address: %s", err)
|
||||
return
|
||||
}
|
||||
message.Subject("Test subject")
|
||||
message.SetBodyString(TypeTextPlain, "Test body")
|
||||
message.SetMessageIDWithValue("this.is.a.message.id")
|
||||
messages = append(messages, message)
|
||||
}
|
||||
|
||||
if err = client.DialWithContext(context.Background()); err != nil {
|
||||
t.Errorf("failed to dial to test server: %s", err)
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
for id, message := range messages {
|
||||
wg.Add(1)
|
||||
go func(curMsg *Msg, curID int) {
|
||||
defer wg.Done()
|
||||
if goroutineErr := client.Send(curMsg); err != nil {
|
||||
t.Errorf("failed to send message with ID %d: %s", curID, goroutineErr)
|
||||
}
|
||||
}(message, id)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if err = client.Close(); err != nil {
|
||||
t.Errorf("failed to close server connection: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// getTestConnection takes environment variables to establish a connection to a real
|
||||
// SMTP server to test all functionality that requires a connection
|
||||
func getTestConnection(auth bool) (*Client, error) {
|
||||
|
@ -1913,6 +2021,72 @@ func getTestConnectionWithDSN(auth bool) (*Client, error) {
|
|||
}
|
||||
|
||||
func TestXOAuth2OK(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
serverPort := TestServerPortBase + 30
|
||||
featureSet := "250-AUTH XOAUTH2\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
|
||||
go func() {
|
||||
if err := simpleSMTPServer(ctx, featureSet, false, serverPort); err != nil {
|
||||
t.Errorf("failed to start test server: %s", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
time.Sleep(time.Millisecond * 500)
|
||||
|
||||
c, err := NewClient("127.0.0.1",
|
||||
WithPort(serverPort),
|
||||
WithTLSPortPolicy(TLSOpportunistic),
|
||||
WithSMTPAuth(SMTPAuthXOAUTH2),
|
||||
WithUsername("user"),
|
||||
WithPassword("token"))
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create new client: %v", err)
|
||||
}
|
||||
if err = c.DialWithContext(context.Background()); err != nil {
|
||||
t.Fatalf("unexpected dial error: %v", err)
|
||||
}
|
||||
if err = c.Close(); err != nil {
|
||||
t.Fatalf("disconnect from test server failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestXOAuth2Unsupported(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
serverPort := TestServerPortBase + 31
|
||||
featureSet := "250-AUTH LOGIN PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
|
||||
go func() {
|
||||
if err := simpleSMTPServer(ctx, featureSet, false, serverPort); err != nil {
|
||||
t.Errorf("failed to start test server: %s", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
time.Sleep(time.Millisecond * 500)
|
||||
|
||||
c, err := NewClient("127.0.0.1",
|
||||
WithPort(serverPort),
|
||||
WithTLSPolicy(TLSOpportunistic),
|
||||
WithSMTPAuth(SMTPAuthXOAUTH2),
|
||||
WithUsername("user"),
|
||||
WithPassword("token"))
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create new client: %v", err)
|
||||
}
|
||||
if err = c.DialWithContext(context.Background()); err == nil {
|
||||
t.Fatal("expected dial error got nil")
|
||||
} else {
|
||||
if !errors.Is(err, ErrXOauth2AuthNotSupported) {
|
||||
t.Fatalf("expected %v; got %v", ErrXOauth2AuthNotSupported, err)
|
||||
}
|
||||
}
|
||||
if err = c.Close(); err != nil {
|
||||
t.Fatalf("disconnect from test server failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestXOAuth2OK_faker(t *testing.T) {
|
||||
server := []string{
|
||||
"220 Fake server ready ESMTP",
|
||||
"250-fake.server",
|
||||
|
@ -1952,7 +2126,7 @@ func TestXOAuth2OK(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestXOAuth2Unsupported(t *testing.T) {
|
||||
func TestXOAuth2Unsupported_faker(t *testing.T) {
|
||||
server := []string{
|
||||
"220 Fake server ready ESMTP",
|
||||
"250-fake.server",
|
||||
|
@ -2085,7 +2259,6 @@ func handleTestServerConnection(connection net.Conn, featureSet string, failRese
|
|||
|
||||
data, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
fmt.Printf("unable to read from connection: %s\n", err)
|
||||
return
|
||||
}
|
||||
if !strings.HasPrefix(data, "EHLO") && !strings.HasPrefix(data, "HELO") {
|
||||
|
@ -2093,19 +2266,15 @@ func handleTestServerConnection(connection net.Conn, featureSet string, failRese
|
|||
return
|
||||
}
|
||||
if err = writeLine("250-localhost.localdomain\r\n" + featureSet); err != nil {
|
||||
fmt.Printf("unable to write to connection: %s\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
data, err = reader.ReadString('\n')
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
fmt.Println("Error reading data:", err)
|
||||
break
|
||||
}
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
var datastring string
|
||||
data = strings.TrimSpace(data)
|
||||
|
@ -2128,6 +2297,13 @@ func handleTestServerConnection(connection net.Conn, featureSet string, failRese
|
|||
break
|
||||
}
|
||||
writeOK()
|
||||
case strings.HasPrefix(data, "AUTH XOAUTH2"):
|
||||
auth := strings.TrimPrefix(data, "AUTH XOAUTH2 ")
|
||||
if !strings.EqualFold(auth, "dXNlcj11c2VyAWF1dGg9QmVhcmVyIHRva2VuAQE=") {
|
||||
_ = writeLine("535 5.7.8 Error: authentication failed")
|
||||
break
|
||||
}
|
||||
_ = writeLine("235 2.7.0 Authentication successful")
|
||||
case strings.HasPrefix(data, "AUTH PLAIN"):
|
||||
auth := strings.TrimPrefix(data, "AUTH PLAIN ")
|
||||
if !strings.EqualFold(auth, "AHRvbmlAdGVzdGVyLmNvbQBWM3J5UzNjcjN0Kw==") {
|
||||
|
|
|
@ -2,8 +2,8 @@
|
|||
//
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
//go:build go1.19 && !go1.20
|
||||
// +build go1.19,!go1.20
|
||||
//go:build !go1.20
|
||||
// +build !go1.20
|
||||
|
||||
package mail
|
||||
|
||||
|
|
141
smtp/smtp.go
141
smtp/smtp.go
|
@ -30,34 +30,57 @@ import (
|
|||
"net/textproto"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/wneessen/go-mail/log"
|
||||
)
|
||||
|
||||
// A Client represents a client connection to an SMTP server.
|
||||
type Client struct {
|
||||
// Text is the textproto.Conn used by the Client. It is exported to allow for
|
||||
// clients to add extensions.
|
||||
// Text is the textproto.Conn used by the Client. It is exported to allow for clients to add extensions.
|
||||
Text *textproto.Conn
|
||||
// keep a reference to the connection so it can be used to create a TLS
|
||||
// connection later
|
||||
|
||||
// auth supported auth mechanisms
|
||||
auth []string
|
||||
|
||||
// keep a reference to the connection so it can be used to create a TLS connection later
|
||||
conn net.Conn
|
||||
// whether the Client is using TLS
|
||||
tls bool
|
||||
serverName string
|
||||
// map of supported extensions
|
||||
|
||||
// debug logging is enabled
|
||||
debug bool
|
||||
|
||||
// didHello indicates whether we've said HELO/EHLO
|
||||
didHello bool
|
||||
|
||||
// dsnmrtype defines the mail return option in case DSN is enabled
|
||||
dsnmrtype string
|
||||
|
||||
// dsnrntype defines the recipient notify option in case DSN is enabled
|
||||
dsnrntype string
|
||||
|
||||
// ext is a map of supported extensions
|
||||
ext map[string]string
|
||||
// supported auth mechanisms
|
||||
auth []string
|
||||
localName string // the name to use in HELO/EHLO
|
||||
didHello bool // whether we've said HELO/EHLO
|
||||
helloError error // the error from the hello
|
||||
// debug logging
|
||||
debug bool // debug logging is enabled
|
||||
logger log.Logger // logger will be used for debug logging
|
||||
// DSN support
|
||||
dsnmrtype string // dsnmrtype defines the mail return option in case DSN is enabled
|
||||
dsnrntype string // dsnrntype defines the recipient notify option in case DSN is enabled
|
||||
|
||||
// helloError is the error from the hello
|
||||
helloError error
|
||||
|
||||
// localName is the name to use in HELO/EHLO
|
||||
localName string // the name to use in HELO/EHLO
|
||||
|
||||
// logger will be used for debug logging
|
||||
logger log.Logger
|
||||
|
||||
// mutex is used to synchronize access to shared resources, ensuring that only one goroutine can access
|
||||
// the resource at a time.
|
||||
mutex sync.RWMutex
|
||||
|
||||
// tls indicates whether the Client is using TLS
|
||||
tls bool
|
||||
|
||||
// serverName denotes the name of the server to which the application will connect. Used for
|
||||
// identification and routing.
|
||||
serverName string
|
||||
}
|
||||
|
||||
// Dial returns a new [Client] connected to an SMTP server at addr.
|
||||
|
@ -94,7 +117,10 @@ func NewClient(conn net.Conn, host string) (*Client, error) {
|
|||
|
||||
// Close closes the connection.
|
||||
func (c *Client) Close() error {
|
||||
return c.Text.Close()
|
||||
c.mutex.Lock()
|
||||
err := c.Text.Close()
|
||||
c.mutex.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
// hello runs a hello exchange if needed.
|
||||
|
@ -121,28 +147,39 @@ func (c *Client) Hello(localName string) error {
|
|||
if c.didHello {
|
||||
return errors.New("smtp: Hello called after other methods")
|
||||
}
|
||||
|
||||
c.mutex.Lock()
|
||||
c.localName = localName
|
||||
c.mutex.Unlock()
|
||||
|
||||
return c.hello()
|
||||
}
|
||||
|
||||
// cmd is a convenience function that sends a command and returns the response
|
||||
func (c *Client) cmd(expectCode int, format string, args ...interface{}) (int, string, error) {
|
||||
c.mutex.Lock()
|
||||
|
||||
c.debugLog(log.DirClientToServer, format, args...)
|
||||
id, err := c.Text.Cmd(format, args...)
|
||||
if err != nil {
|
||||
c.mutex.Unlock()
|
||||
return 0, "", err
|
||||
}
|
||||
c.Text.StartResponse(id)
|
||||
defer c.Text.EndResponse(id)
|
||||
code, msg, err := c.Text.ReadResponse(expectCode)
|
||||
c.debugLog(log.DirServerToClient, "%d %s", code, msg)
|
||||
c.Text.EndResponse(id)
|
||||
c.mutex.Unlock()
|
||||
return code, msg, err
|
||||
}
|
||||
|
||||
// helo sends the HELO greeting to the server. It should be used only when the
|
||||
// server does not support ehlo.
|
||||
func (c *Client) helo() error {
|
||||
c.mutex.Lock()
|
||||
c.ext = nil
|
||||
c.mutex.Unlock()
|
||||
|
||||
_, _, err := c.cmd(250, "HELO %s", c.localName)
|
||||
return err
|
||||
}
|
||||
|
@ -157,9 +194,13 @@ func (c *Client) StartTLS(config *tls.Config) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.mutex.Lock()
|
||||
c.conn = tls.Client(c.conn, config)
|
||||
c.Text = textproto.NewConn(c.conn)
|
||||
c.tls = true
|
||||
c.mutex.Unlock()
|
||||
|
||||
return c.ehlo()
|
||||
}
|
||||
|
||||
|
@ -167,11 +208,15 @@ func (c *Client) StartTLS(config *tls.Config) error {
|
|||
// The return values are their zero values if [Client.StartTLS] did
|
||||
// not succeed.
|
||||
func (c *Client) TLSConnectionState() (state tls.ConnectionState, ok bool) {
|
||||
c.mutex.RLock()
|
||||
defer c.mutex.RUnlock()
|
||||
|
||||
tc, ok := c.conn.(*tls.Conn)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
return tc.ConnectionState(), true
|
||||
state, ok = tc.ConnectionState(), true
|
||||
return
|
||||
}
|
||||
|
||||
// Verify checks the validity of an email address on the server.
|
||||
|
@ -257,6 +302,8 @@ func (c *Client) Mail(from string) error {
|
|||
return err
|
||||
}
|
||||
cmdStr := "MAIL FROM:<%s>"
|
||||
|
||||
c.mutex.RLock()
|
||||
if c.ext != nil {
|
||||
if _, ok := c.ext["8BITMIME"]; ok {
|
||||
cmdStr += " BODY=8BITMIME"
|
||||
|
@ -269,6 +316,8 @@ func (c *Client) Mail(from string) error {
|
|||
cmdStr += fmt.Sprintf(" RET=%s", c.dsnmrtype)
|
||||
}
|
||||
}
|
||||
c.mutex.RUnlock()
|
||||
|
||||
_, _, err := c.cmd(250, cmdStr, from)
|
||||
return err
|
||||
}
|
||||
|
@ -280,7 +329,11 @@ func (c *Client) Rcpt(to string) error {
|
|||
if err := validateLine(to); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.mutex.RLock()
|
||||
_, ok := c.ext["DSN"]
|
||||
c.mutex.RUnlock()
|
||||
|
||||
if ok && c.dsnrntype != "" {
|
||||
_, _, err := c.cmd(25, "RCPT TO:<%s> NOTIFY=%s", to, c.dsnrntype)
|
||||
return err
|
||||
|
@ -294,12 +347,23 @@ type dataCloser struct {
|
|||
io.WriteCloser
|
||||
}
|
||||
|
||||
// Close releases the lock, closes the WriteCloser, waits for a response, and then returns any error encountered.
|
||||
func (d *dataCloser) Close() error {
|
||||
d.c.mutex.Lock()
|
||||
_ = d.WriteCloser.Close()
|
||||
_, _, err := d.c.Text.ReadResponse(250)
|
||||
d.c.mutex.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
// Write writes data to the underlying WriteCloser while ensuring thread-safety by locking and unlocking a mutex.
|
||||
func (d *dataCloser) Write(p []byte) (n int, err error) {
|
||||
d.c.mutex.Lock()
|
||||
n, err = d.WriteCloser.Write(p)
|
||||
d.c.mutex.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// Data issues a DATA command to the server and returns a writer that
|
||||
// can be used to write the mail headers and body. The caller should
|
||||
// close the writer before calling any more methods on c. A call to
|
||||
|
@ -309,7 +373,14 @@ func (c *Client) Data() (io.WriteCloser, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dataCloser{c, c.Text.DotWriter()}, nil
|
||||
datacloser := &dataCloser{}
|
||||
|
||||
c.mutex.Lock()
|
||||
datacloser.c = c
|
||||
datacloser.WriteCloser = c.Text.DotWriter()
|
||||
c.mutex.Unlock()
|
||||
|
||||
return datacloser, nil
|
||||
}
|
||||
|
||||
var testHookStartTLS func(*tls.Config) // nil, except for tests
|
||||
|
@ -405,7 +476,10 @@ func (c *Client) Extension(ext string) (bool, string) {
|
|||
return false, ""
|
||||
}
|
||||
ext = strings.ToUpper(ext)
|
||||
|
||||
c.mutex.RLock()
|
||||
param, ok := c.ext[ext]
|
||||
c.mutex.RUnlock()
|
||||
return ok, param
|
||||
}
|
||||
|
||||
|
@ -438,7 +512,11 @@ func (c *Client) Quit() error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.Text.Close()
|
||||
c.mutex.Lock()
|
||||
err = c.Text.Close()
|
||||
c.mutex.Unlock()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// SetDebugLog enables the debug logging for incoming and outgoing SMTP messages
|
||||
|
@ -472,6 +550,21 @@ func (c *Client) SetDSNRcptNotifyOption(d string) {
|
|||
c.dsnrntype = d
|
||||
}
|
||||
|
||||
// HasConnection checks if the client has an active connection.
|
||||
// Returns true if the `conn` field is not nil, indicating an active connection.
|
||||
func (c *Client) HasConnection() bool {
|
||||
return c.conn != nil
|
||||
}
|
||||
|
||||
func (c *Client) UpdateDeadline(timeout time.Duration) error {
|
||||
c.mutex.Lock()
|
||||
if err := c.conn.SetDeadline(time.Now().Add(timeout)); err != nil {
|
||||
return fmt.Errorf("smtp: failed to update deadline: %w", err)
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// debugLog checks if the debug flag is set and if so logs the provided message to
|
||||
// the log.Logger interface
|
||||
func (c *Client) debugLog(d log.Direction, f string, a ...interface{}) {
|
||||
|
|
|
@ -25,6 +25,9 @@ func (c *Client) ehlo() error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
ext := make(map[string]string)
|
||||
extList := strings.Split(msg, "\n")
|
||||
if len(extList) > 1 {
|
||||
|
|
|
@ -22,12 +22,15 @@ import "strings"
|
|||
// should be the preferred greeting for servers that support it.
|
||||
//
|
||||
// Backport of: https://github.com/golang/go/commit/4d8db00641cc9ff4f44de7df9b8c4f4a4f9416ee#diff-4f6f6bdb9891d4dd271f9f31430420a2e44018fe4ee539576faf458bebb3cee4
|
||||
// to guarantee backwards compatibility with Go 1.16/1.17:w
|
||||
// to guarantee backwards compatibility with Go 1.16/1.17
|
||||
func (c *Client) ehlo() error {
|
||||
_, msg, err := c.cmd(250, "EHLO %s", c.localName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
ext := make(map[string]string)
|
||||
extList := strings.Split(msg, "\n")
|
||||
if len(extList) > 1 {
|
||||
|
|
Loading…
Reference in a new issue