mirror of
https://github.com/wneessen/go-mail.git
synced 2024-11-22 22:00:49 +01:00
Merge pull request #128 from sters/main
Adding WithDialContextFunc client option
This commit is contained in:
commit
13c8d0a32c
2 changed files with 53 additions and 10 deletions
24
client.go
24
client.go
|
@ -76,6 +76,9 @@ const (
|
||||||
DSNRcptNotifyDelay DSNRcptNotifyOption = "DELAY"
|
DSNRcptNotifyDelay DSNRcptNotifyOption = "DELAY"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// DialContextFunc is a type to define custom DialContext function.
|
||||||
|
type DialContextFunc func(ctx context.Context, network, address string) (net.Conn, error)
|
||||||
|
|
||||||
// Client is the SMTP client struct
|
// Client is the SMTP client struct
|
||||||
type Client struct {
|
type Client struct {
|
||||||
// co is the net.Conn that the smtp.Client is based on
|
// co is the net.Conn that the smtp.Client is based on
|
||||||
|
@ -137,6 +140,9 @@ type Client struct {
|
||||||
|
|
||||||
// l is a logger that implements the log.Logger interface
|
// l is a logger that implements the log.Logger interface
|
||||||
l log.Logger
|
l log.Logger
|
||||||
|
|
||||||
|
// dialContextFunc is a custom DialContext function to dial target SMTP server
|
||||||
|
dialContextFunc DialContextFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
// Option returns a function that can be used for grouping Client options
|
// Option returns a function that can be used for grouping Client options
|
||||||
|
@ -401,6 +407,14 @@ func WithoutNoop() Option {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithDialContextFunc overrides the default DialContext for connecting SMTP server
|
||||||
|
func WithDialContextFunc(f DialContextFunc) Option {
|
||||||
|
return func(c *Client) error {
|
||||||
|
c.dialContextFunc = f
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TLSPolicy returns the currently set TLSPolicy as string
|
// TLSPolicy returns the currently set TLSPolicy as string
|
||||||
func (c *Client) TLSPolicy() string {
|
func (c *Client) TLSPolicy() string {
|
||||||
return c.tlspolicy.String()
|
return c.tlspolicy.String()
|
||||||
|
@ -481,18 +495,18 @@ func (c *Client) DialWithContext(pc context.Context) error {
|
||||||
ctx, cfn := context.WithDeadline(pc, time.Now().Add(c.cto))
|
ctx, cfn := context.WithDeadline(pc, time.Now().Add(c.cto))
|
||||||
defer cfn()
|
defer cfn()
|
||||||
|
|
||||||
|
if c.dialContextFunc == nil {
|
||||||
nd := net.Dialer{}
|
nd := net.Dialer{}
|
||||||
|
c.dialContextFunc = nd.DialContext
|
||||||
|
|
||||||
var err error
|
|
||||||
if c.ssl {
|
if c.ssl {
|
||||||
td := tls.Dialer{NetDialer: &nd, Config: c.tlsconfig}
|
td := tls.Dialer{NetDialer: &nd, Config: c.tlsconfig}
|
||||||
|
|
||||||
c.enc = true
|
c.enc = true
|
||||||
c.co, err = td.DialContext(ctx, "tcp", c.ServerAddr())
|
c.dialContextFunc = td.DialContext
|
||||||
}
|
}
|
||||||
if !c.ssl {
|
|
||||||
c.co, err = nd.DialContext(ctx, "tcp", c.ServerAddr())
|
|
||||||
}
|
}
|
||||||
|
var err error
|
||||||
|
c.co, err = c.dialContextFunc(ctx, "tcp", c.ServerAddr())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -108,6 +109,9 @@ func TestNewClientWithOptions(t *testing.T) {
|
||||||
{"WithoutNoop()", WithoutNoop(), false},
|
{"WithoutNoop()", WithoutNoop(), false},
|
||||||
{"WithDebugLog()", WithDebugLog(), false},
|
{"WithDebugLog()", WithDebugLog(), false},
|
||||||
{"WithLogger()", WithLogger(log.New(os.Stderr, log.LevelDebug)), false},
|
{"WithLogger()", WithLogger(log.New(os.Stderr, log.LevelDebug)), false},
|
||||||
|
{"WithDialContextFunc()", WithDialContextFunc(func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
return nil, nil
|
||||||
|
}), false},
|
||||||
|
|
||||||
{
|
{
|
||||||
"WithDSNRcptNotifyType() NEVER combination",
|
"WithDSNRcptNotifyType() NEVER combination",
|
||||||
|
@ -703,6 +707,31 @@ func TestClient_DialWithContextOptions(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestClient_DialWithContextOptionDialContextFunc tests the DialWithContext method plus
|
||||||
|
// use dialContextFunc option for the Client object
|
||||||
|
func TestClient_DialWithContextOptionDialContextFunc(t *testing.T) {
|
||||||
|
c, err := getTestConnection(true)
|
||||||
|
if err != nil {
|
||||||
|
t.Skipf("failed to create test client: %s. Skipping tests", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
called := false
|
||||||
|
c.dialContextFunc = func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
called = true
|
||||||
|
return (&net.Dialer{}).DialContext(ctx, network, address)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
if err := c.DialWithContext(ctx); err != nil {
|
||||||
|
t.Errorf("failed to dial with context: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if called == false {
|
||||||
|
t.Errorf("dialContextFunc supposed to be called but not called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TestClient_DialSendClose tests the Dial(), Send() and Close() method of Client
|
// TestClient_DialSendClose tests the Dial(), Send() and Close() method of Client
|
||||||
func TestClient_DialSendClose(t *testing.T) {
|
func TestClient_DialSendClose(t *testing.T) {
|
||||||
if os.Getenv("TEST_ALLOW_SEND") == "" {
|
if os.Getenv("TEST_ALLOW_SEND") == "" {
|
||||||
|
|
Loading…
Reference in a new issue