Adding WithDialContextFunc client option

This commit is contained in:
sters 2023-04-19 23:20:33 +09:00
parent 3a528f1d81
commit e757327e1d
No known key found for this signature in database
GPG key ID: 1C0EBF65A4324DEF
2 changed files with 53 additions and 10 deletions

View file

@ -76,6 +76,9 @@ const (
DSNRcptNotifyDelay DSNRcptNotifyOption = "DELAY" DSNRcptNotifyDelay DSNRcptNotifyOption = "DELAY"
) )
// DialContextFunc is a type to define custom DialContext function.
type DialContextFunc func(ctx context.Context, network, address string) (net.Conn, error)
// Client is the SMTP client struct // Client is the SMTP client struct
type Client struct { type Client struct {
// co is the net.Conn that the smtp.Client is based on // co is the net.Conn that the smtp.Client is based on
@ -137,6 +140,9 @@ type Client struct {
// l is a logger that implements the log.Logger interface // l is a logger that implements the log.Logger interface
l log.Logger l log.Logger
// dialContextFunc is a custom DialContext function to dial target SMTP server
dialContextFunc DialContextFunc
} }
// Option returns a function that can be used for grouping Client options // Option returns a function that can be used for grouping Client options
@ -401,6 +407,14 @@ func WithoutNoop() Option {
} }
} }
// WithDialContextFunc overrides the default DialContext for connecting SMTP server
func WithDialContextFunc(f DialContextFunc) Option {
return func(c *Client) error {
c.dialContextFunc = f
return nil
}
}
// TLSPolicy returns the currently set TLSPolicy as string // TLSPolicy returns the currently set TLSPolicy as string
func (c *Client) TLSPolicy() string { func (c *Client) TLSPolicy() string {
return c.tlspolicy.String() return c.tlspolicy.String()
@ -481,18 +495,18 @@ func (c *Client) DialWithContext(pc context.Context) error {
ctx, cfn := context.WithDeadline(pc, time.Now().Add(c.cto)) ctx, cfn := context.WithDeadline(pc, time.Now().Add(c.cto))
defer cfn() defer cfn()
if c.dialContextFunc == nil {
nd := net.Dialer{} nd := net.Dialer{}
c.dialContextFunc = nd.DialContext
var err error
if c.ssl { if c.ssl {
td := tls.Dialer{NetDialer: &nd, Config: c.tlsconfig} td := tls.Dialer{NetDialer: &nd, Config: c.tlsconfig}
c.enc = true c.enc = true
c.co, err = td.DialContext(ctx, "tcp", c.ServerAddr()) c.dialContextFunc = td.DialContext
} }
if !c.ssl {
c.co, err = nd.DialContext(ctx, "tcp", c.ServerAddr())
} }
var err error
c.co, err = c.dialContextFunc(ctx, "tcp", c.ServerAddr())
if err != nil { if err != nil {
return err return err
} }

View file

@ -9,6 +9,7 @@ import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
"net"
"os" "os"
"strconv" "strconv"
"strings" "strings"
@ -108,6 +109,9 @@ func TestNewClientWithOptions(t *testing.T) {
{"WithoutNoop()", WithoutNoop(), false}, {"WithoutNoop()", WithoutNoop(), false},
{"WithDebugLog()", WithDebugLog(), false}, {"WithDebugLog()", WithDebugLog(), false},
{"WithLogger()", WithLogger(log.New(os.Stderr, log.LevelDebug)), false}, {"WithLogger()", WithLogger(log.New(os.Stderr, log.LevelDebug)), false},
{"WithDialContextFunc()", WithDialContextFunc(func(ctx context.Context, network, address string) (net.Conn, error) {
return nil, nil
}), false},
{ {
"WithDSNRcptNotifyType() NEVER combination", "WithDSNRcptNotifyType() NEVER combination",
@ -703,6 +707,31 @@ func TestClient_DialWithContextOptions(t *testing.T) {
} }
} }
// TestClient_DialWithContextOptionDialContextFunc tests the DialWithContext method plus
// use dialContextFunc option for the Client object
func TestClient_DialWithContextOptionDialContextFunc(t *testing.T) {
c, err := getTestConnection(true)
if err != nil {
t.Skipf("failed to create test client: %s. Skipping tests", err)
}
called := false
c.dialContextFunc = func(ctx context.Context, network, address string) (net.Conn, error) {
called = true
return (&net.Dialer{}).DialContext(ctx, network, address)
}
ctx := context.Background()
if err := c.DialWithContext(ctx); err != nil {
t.Errorf("failed to dial with context: %s", err)
return
}
if called == false {
t.Errorf("dialContextFunc supposed to be called but not called")
}
}
// TestClient_DialSendClose tests the Dial(), Send() and Close() method of Client // TestClient_DialSendClose tests the Dial(), Send() and Close() method of Client
func TestClient_DialSendClose(t *testing.T) { func TestClient_DialSendClose(t *testing.T) {
if os.Getenv("TEST_ALLOW_SEND") == "" { if os.Getenv("TEST_ALLOW_SEND") == "" {