diff --git a/client.go b/client.go index bc3ac60..ae54fd4 100644 --- a/client.go +++ b/client.go @@ -76,6 +76,9 @@ const ( DSNRcptNotifyDelay DSNRcptNotifyOption = "DELAY" ) +// DialContextFunc is a type to define custom DialContext function. +type DialContextFunc func(ctx context.Context, network, address string) (net.Conn, error) + // Client is the SMTP client struct type Client struct { // co is the net.Conn that the smtp.Client is based on @@ -137,6 +140,9 @@ type Client struct { // l is a logger that implements the log.Logger interface l log.Logger + + // dialContextFunc is a custom DialContext function to dial target SMTP server + dialContextFunc DialContextFunc } // Option returns a function that can be used for grouping Client options @@ -401,6 +407,14 @@ func WithoutNoop() Option { } } +// WithDialContextFunc overrides the default DialContext for connecting SMTP server +func WithDialContextFunc(f DialContextFunc) Option { + return func(c *Client) error { + c.dialContextFunc = f + return nil + } +} + // TLSPolicy returns the currently set TLSPolicy as string func (c *Client) TLSPolicy() string { return c.tlspolicy.String() @@ -481,18 +495,18 @@ func (c *Client) DialWithContext(pc context.Context) error { ctx, cfn := context.WithDeadline(pc, time.Now().Add(c.cto)) defer cfn() - nd := net.Dialer{} + if c.dialContextFunc == nil { + nd := net.Dialer{} + c.dialContextFunc = nd.DialContext + if c.ssl { + td := tls.Dialer{NetDialer: &nd, Config: c.tlsconfig} + c.enc = true + c.dialContextFunc = td.DialContext + } + } var err error - if c.ssl { - td := tls.Dialer{NetDialer: &nd, Config: c.tlsconfig} - - c.enc = true - c.co, err = td.DialContext(ctx, "tcp", c.ServerAddr()) - } - if !c.ssl { - c.co, err = nd.DialContext(ctx, "tcp", c.ServerAddr()) - } + c.co, err = c.dialContextFunc(ctx, "tcp", c.ServerAddr()) if err != nil { return err } diff --git a/client_test.go b/client_test.go index b407f55..8bad1e0 100644 --- a/client_test.go +++ b/client_test.go @@ -9,6 +9,7 @@ import ( "crypto/tls" "errors" "fmt" + "net" "os" "strconv" "strings" @@ -108,6 +109,9 @@ func TestNewClientWithOptions(t *testing.T) { {"WithoutNoop()", WithoutNoop(), false}, {"WithDebugLog()", WithDebugLog(), false}, {"WithLogger()", WithLogger(log.New(os.Stderr, log.LevelDebug)), false}, + {"WithDialContextFunc()", WithDialContextFunc(func(ctx context.Context, network, address string) (net.Conn, error) { + return nil, nil + }), false}, { "WithDSNRcptNotifyType() NEVER combination", @@ -703,6 +707,31 @@ func TestClient_DialWithContextOptions(t *testing.T) { } } +// TestClient_DialWithContextOptionDialContextFunc tests the DialWithContext method plus +// use dialContextFunc option for the Client object +func TestClient_DialWithContextOptionDialContextFunc(t *testing.T) { + c, err := getTestConnection(true) + if err != nil { + t.Skipf("failed to create test client: %s. Skipping tests", err) + } + + called := false + c.dialContextFunc = func(ctx context.Context, network, address string) (net.Conn, error) { + called = true + return (&net.Dialer{}).DialContext(ctx, network, address) + } + + ctx := context.Background() + if err := c.DialWithContext(ctx); err != nil { + t.Errorf("failed to dial with context: %s", err) + return + } + + if called == false { + t.Errorf("dialContextFunc supposed to be called but not called") + } +} + // TestClient_DialSendClose tests the Dial(), Send() and Close() method of Client func TestClient_DialSendClose(t *testing.T) { if os.Getenv("TEST_ALLOW_SEND") == "" {