diff --git a/client.go b/client.go index 2a20c96..14a3c03 100644 --- a/client.go +++ b/client.go @@ -246,8 +246,12 @@ func (c *Client) SetTLSPolicy(p TLSPolicy) { } // SetTLSConfig overrides the current *tls.Config with the given *tls.Config value -func (c *Client) SetTLSConfig(co *tls.Config) { +func (c *Client) SetTLSConfig(co *tls.Config) error { + if co == nil { + return ErrInvalidTLSConfig + } c.tlsconfig = co + return nil } // SetUsername overrides the current username string with the given value diff --git a/client_test.go b/client_test.go index e4e874f..132382f 100644 --- a/client_test.go +++ b/client_test.go @@ -2,6 +2,7 @@ package mail import ( "crypto/tls" + "fmt" "net/smtp" "testing" "time" @@ -50,6 +51,10 @@ func TestNewClient(t *testing.T) { t.Errorf("failed to create new client. TLS config min versino expected: %d, got: %d", DefaultTLSMinVersion, c.tlsconfig.MinVersion) } + if c.ServerAddr() != fmt.Sprintf("%s:%d", tt.host, c.port) { + t.Errorf("failed to create new client. c.ServerAddr() expected: %s, got: %s", + fmt.Sprintf("%s:%d", tt.host, c.port), c.ServerAddr()) + } }) } } @@ -69,6 +74,7 @@ func TestNewClientWithOptions(t *testing.T) { {"WithTimeout()", WithTimeout(-10), true}, {"WithSSL()", WithSSL(), false}, {"WithHELO()", WithHELO(host), false}, + {"WithHELO(); helo is empty", WithHELO(""), true}, {"WithTLSPolicy()", WithTLSPolicy(TLSOpportunistic), false}, {"WithTLSConfig()", WithTLSConfig(&tls.Config{}), false}, {"WithTLSConfig(); config is nil", WithTLSConfig(nil), true}, @@ -120,14 +126,17 @@ func TestWithPort(t *testing.T) { name string value int want int + sf bool }{ - {"set port to 25", 25, 25}, - {"set port to 465", 465, 465}, + {"set port to 25", 25, 25, false}, + {"set port to 465", 465, 465, false}, + {"set port to 100000", 100000, 25, true}, + {"set port to -10", -10, 25, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(DefaultHost, WithPort(tt.value)) - if err != nil { + if err != nil && !tt.sf { t.Errorf("failed to create new client: %s", err) return } @@ -144,15 +153,17 @@ func TestWithTimeout(t *testing.T) { name string value time.Duration want time.Duration + sf bool }{ - {"set timeout to 5s", time.Second * 5, time.Second * 5}, - {"set timeout to 30s", time.Second * 30, time.Second * 30}, - {"set timeout to 1m", time.Minute, time.Minute}, + {"set timeout to 5s", time.Second * 5, time.Second * 5, false}, + {"set timeout to 30s", time.Second * 30, time.Second * 30, false}, + {"set timeout to 1m", time.Minute, time.Minute, false}, + {"set timeout to 0", 0, DefaultTimeout, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(DefaultHost, WithTimeout(tt.value)) - if err != nil { + if err != nil && !tt.sf { t.Errorf("failed to create new client: %s", err) return } @@ -163,47 +174,27 @@ func TestWithTimeout(t *testing.T) { } } -// TestWithSSL tests the WithSSL() option for the NewClient() method -func TestWithSSL(t *testing.T) { - tests := []struct { - name string - want bool - }{ - {"set SSL to true", true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - c, err := NewClient(DefaultHost, WithSSL()) - if err != nil { - t.Errorf("failed to create new client: %s", err) - return - } - if c.ssl != tt.want { - t.Errorf("failed to set SSL. Want: %t, got: %t", tt.want, c.ssl) - } - }) - } -} - // TestWithTLSPolicy tests the WithTLSPolicy() option for the NewClient() method func TestWithTLSPolicy(t *testing.T) { tests := []struct { name string value TLSPolicy - want TLSPolicy + want string + sf bool }{ - {"Policy: TLSMandatory", TLSMandatory, TLSMandatory}, - {"Policy: TLSOpportunistic", TLSOpportunistic, TLSOpportunistic}, - {"Policy: NoTLS", NoTLS, NoTLS}, + {"Policy: TLSMandatory", TLSMandatory, TLSMandatory.String(), false}, + {"Policy: TLSOpportunistic", TLSOpportunistic, TLSOpportunistic.String(), false}, + {"Policy: NoTLS", NoTLS, NoTLS.String(), false}, + {"Policy: Invalid", -1, "UnknownPolicy", true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(DefaultHost, WithTLSPolicy(tt.value)) - if err != nil { + if err != nil && !tt.sf { t.Errorf("failed to create new client: %s", err) return } - if c.tlspolicy != tt.want { + if c.tlspolicy.String() != tt.want { t.Errorf("failed to set TLSPolicy. Want: %s, got: %s", tt.want, c.tlspolicy) } }) @@ -215,11 +206,13 @@ func TestSetTLSPolicy(t *testing.T) { tests := []struct { name string value TLSPolicy - want TLSPolicy + want string + sf bool }{ - {"Policy: TLSMandatory", TLSMandatory, TLSMandatory}, - {"Policy: TLSOpportunistic", TLSOpportunistic, TLSOpportunistic}, - {"Policy: NoTLS", NoTLS, NoTLS}, + {"Policy: TLSMandatory", TLSMandatory, TLSMandatory.String(), false}, + {"Policy: TLSOpportunistic", TLSOpportunistic, TLSOpportunistic.String(), false}, + {"Policy: NoTLS", NoTLS, NoTLS.String(), false}, + {"Policy: Invalid", -1, "UnknownPolicy", true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -229,7 +222,7 @@ func TestSetTLSPolicy(t *testing.T) { return } c.SetTLSPolicy(tt.value) - if c.tlspolicy != tt.want { + if c.tlspolicy.String() != tt.want { t.Errorf("failed to set TLSPolicy. Want: %s, got: %s", tt.want, c.tlspolicy) } })