Merge pull request #128 from sters/main

Adding WithDialContextFunc client option
This commit is contained in:
Winni Neessen 2023-04-20 10:32:36 +02:00 committed by GitHub
commit 13c8d0a32c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 53 additions and 10 deletions

View file

@ -76,6 +76,9 @@ const (
DSNRcptNotifyDelay DSNRcptNotifyOption = "DELAY"
)
// DialContextFunc is a type to define custom DialContext function.
type DialContextFunc func(ctx context.Context, network, address string) (net.Conn, error)
// Client is the SMTP client struct
type Client struct {
// co is the net.Conn that the smtp.Client is based on
@ -137,6 +140,9 @@ type Client struct {
// l is a logger that implements the log.Logger interface
l log.Logger
// dialContextFunc is a custom DialContext function to dial target SMTP server
dialContextFunc DialContextFunc
}
// Option returns a function that can be used for grouping Client options
@ -401,6 +407,14 @@ func WithoutNoop() Option {
}
}
// WithDialContextFunc overrides the default DialContext for connecting SMTP server
func WithDialContextFunc(f DialContextFunc) Option {
return func(c *Client) error {
c.dialContextFunc = f
return nil
}
}
// TLSPolicy returns the currently set TLSPolicy as string
func (c *Client) TLSPolicy() string {
return c.tlspolicy.String()
@ -481,18 +495,18 @@ func (c *Client) DialWithContext(pc context.Context) error {
ctx, cfn := context.WithDeadline(pc, time.Now().Add(c.cto))
defer cfn()
nd := net.Dialer{}
if c.dialContextFunc == nil {
nd := net.Dialer{}
c.dialContextFunc = nd.DialContext
if c.ssl {
td := tls.Dialer{NetDialer: &nd, Config: c.tlsconfig}
c.enc = true
c.dialContextFunc = td.DialContext
}
}
var err error
if c.ssl {
td := tls.Dialer{NetDialer: &nd, Config: c.tlsconfig}
c.enc = true
c.co, err = td.DialContext(ctx, "tcp", c.ServerAddr())
}
if !c.ssl {
c.co, err = nd.DialContext(ctx, "tcp", c.ServerAddr())
}
c.co, err = c.dialContextFunc(ctx, "tcp", c.ServerAddr())
if err != nil {
return err
}

View file

@ -9,6 +9,7 @@ import (
"crypto/tls"
"errors"
"fmt"
"net"
"os"
"strconv"
"strings"
@ -108,6 +109,9 @@ func TestNewClientWithOptions(t *testing.T) {
{"WithoutNoop()", WithoutNoop(), false},
{"WithDebugLog()", WithDebugLog(), false},
{"WithLogger()", WithLogger(log.New(os.Stderr, log.LevelDebug)), false},
{"WithDialContextFunc()", WithDialContextFunc(func(ctx context.Context, network, address string) (net.Conn, error) {
return nil, nil
}), false},
{
"WithDSNRcptNotifyType() NEVER combination",
@ -703,6 +707,31 @@ func TestClient_DialWithContextOptions(t *testing.T) {
}
}
// TestClient_DialWithContextOptionDialContextFunc tests the DialWithContext method plus
// use dialContextFunc option for the Client object
func TestClient_DialWithContextOptionDialContextFunc(t *testing.T) {
c, err := getTestConnection(true)
if err != nil {
t.Skipf("failed to create test client: %s. Skipping tests", err)
}
called := false
c.dialContextFunc = func(ctx context.Context, network, address string) (net.Conn, error) {
called = true
return (&net.Dialer{}).DialContext(ctx, network, address)
}
ctx := context.Background()
if err := c.DialWithContext(ctx); err != nil {
t.Errorf("failed to dial with context: %s", err)
return
}
if called == false {
t.Errorf("dialContextFunc supposed to be called but not called")
}
}
// TestClient_DialSendClose tests the Dial(), Send() and Close() method of Client
func TestClient_DialSendClose(t *testing.T) {
if os.Getenv("TEST_ALLOW_SEND") == "" {