diff --git a/client.go b/client.go index 88dafc4..d2e9f95 100644 --- a/client.go +++ b/client.go @@ -28,8 +28,11 @@ type Client struct { // Use SSL for the connection ssl bool - // Sets the client cto use STARTTTLS for the connection (is disabled when SSL is set) - starttls bool + // tlspolicy sets the client to use the provided TLSPolicy for the STARTTLS protocol + tlspolicy TLSPolicy + + // tlsconfig represents the tls.Config setting for the STARTTLS connection + tlsconfig *tls.Config // Timeout for the SMTP server connection cto time.Duration @@ -48,16 +51,18 @@ var ( // ErrNoHostname should be used if a Client has no hostname set ErrNoHostname = errors.New("hostname for client cannot be empty") - // ErrInvalidHostname should be used if a Client has an invalid hostname set - //ErrInvalidHostname = errors.New("hostname for client is invalid") + // ErrNoSTARTTLS should be used if the target server does not support the STARTTLS protocol + ErrNoSTARTTLS = errors.New("target host does not support STARTTLS") ) // NewClient returns a new Session client object func NewClient(h string, o ...Option) (*Client, error) { c := &Client{ - host: h, - port: DefaultPort, - cto: DefaultTimeout, + host: h, + port: DefaultPort, + cto: DefaultTimeout, + tlspolicy: TLSMandatory, + tlsconfig: &tls.Config{ServerName: h}, } // Set default HELO/EHLO hostname @@ -110,6 +115,36 @@ func WithHELO(h string) Option { } } +// TLSPolicy returns the currently set TLSPolicy as string +func (c *Client) TLSPolicy() string { + return fmt.Sprintf("%s", c.tlspolicy) +} + +// SetTLSPolicy overrides the current TLSPolicy with the given TLSPolicy value +func (c *Client) SetTLSPolicy(p TLSPolicy) { + c.tlspolicy = p +} + +// Send sends out the mail message +func (c *Client) Send() error { + return nil +} + +// Close closes the connection cto the SMTP server +func (c *Client) Close() error { + return c.sc.Close() +} + +// setDefaultHelo retrieves the current hostname and sets it as HELO/EHLO hostname +func (c *Client) setDefaultHelo() error { + hn, err := os.Hostname() + if err != nil { + return fmt.Errorf("failed cto read local hostname: %w", err) + } + c.helo = hn + return nil +} + // Dial establishes a connection cto the SMTP server with a default context.Background func (c *Client) Dial() error { ctx := context.Background() @@ -143,25 +178,14 @@ func (c *Client) DialWithContext(uctx context.Context) error { return err } - return nil -} - -// Send sends out the mail message -func (c *Client) Send() error { - return nil -} - -// Close closes the connection cto the SMTP server -func (c *Client) Close() error { - return c.sc.Close() -} - -// setDefaultHelo retrieves the current hostname and sets it as HELO/EHLO hostname -func (c *Client) setDefaultHelo() error { - hn, err := os.Hostname() - if err != nil { - return fmt.Errorf("failed cto read local hostname: %w", err) + if !c.ssl && c.tlspolicy != NoTLS { + if ok, _ := c.sc.Extension("STARTTLS"); !ok { + return ErrNoSTARTTLS + } + if err := c.sc.StartTLS(c.tlsconfig); err != nil { + return err + } } - c.helo = hn + return nil } diff --git a/client_test.go b/client_test.go index 572e094..5ec156b 100644 --- a/client_test.go +++ b/client_test.go @@ -2,12 +2,13 @@ package mail import ( "testing" + "time" ) // DefaultHost is used as default hostname for the Client const DefaultHost = "localhost" -// TestWithHELo tests the WithHELO() option for the NewClient() method +// TestWithHELO tests the WithHELO() option for the NewClient() method func TestWithHELO(t *testing.T) { tests := []struct { name string @@ -29,3 +30,74 @@ func TestWithHELO(t *testing.T) { }) } } + +// TestWithPort tests the WithPort() option for the NewClient() method +func TestWithPort(t *testing.T) { + tests := []struct { + name string + value int + want int + }{ + {"set port to 25", 25, 25}, + {"set port to 465", 465, 465}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, err := NewClient(DefaultHost, WithPort(tt.value)) + if err != nil { + t.Errorf("failed to create new client: %s", err) + return + } + if c.port != tt.want { + t.Errorf("failed to set custom port. Want: %d, got: %d", tt.want, c.port) + } + }) + } +} + +// TestWithTimeout tests the WithTimeout() option for the NewClient() method +func TestWithTimeout(t *testing.T) { + tests := []struct { + name string + value time.Duration + want time.Duration + }{ + {"set timeout to 5s", time.Second * 5, time.Second * 5}, + {"set timeout to 30s", time.Second * 30, time.Second * 30}, + {"set timeout to 1m", time.Minute, time.Minute}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, err := NewClient(DefaultHost, WithTimeout(tt.value)) + if err != nil { + t.Errorf("failed to create new client: %s", err) + return + } + if c.cto != tt.want { + t.Errorf("failed to set custom timeout. Want: %d, got: %d", tt.want, c.cto) + } + }) + } +} + +// TestWithSSL tests the WithSSL() option for the NewClient() method +func TestWithSSL(t *testing.T) { + tests := []struct { + name string + want bool + }{ + {"set SSL to true", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, err := NewClient(DefaultHost, WithSSL()) + if err != nil { + t.Errorf("failed to create new client: %s", err) + return + } + if c.ssl != tt.want { + t.Errorf("failed to set SSL. Want: %t, got: %t", tt.want, c.ssl) + } + }) + } +} diff --git a/cmd/main.go b/cmd/main.go index f43139a..a0497e8 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -9,7 +9,7 @@ import ( ) func main() { - c, err := mail.NewClient("localhost", mail.WithTimeout(time.Millisecond*500)) + c, err := mail.NewClient("manjaro-vm.fritz.box", mail.WithTimeout(time.Millisecond*500)) if err != nil { fmt.Printf("failed to create new client: %s\n", err) os.Exit(1) @@ -23,9 +23,5 @@ func main() { os.Exit(1) } fmt.Printf("Client: %+v\n", c) - time.Sleep(time.Millisecond * 1500) - if err := c.Close(); err != nil { - fmt.Printf("failed to close SMTP connection: %s\n", err) - os.Exit(1) - } + fmt.Printf("StartTLS policy: %s\n", c.TLSPolicy()) } diff --git a/tls.go b/tls.go index e41a4ff..62414d9 100644 --- a/tls.go +++ b/tls.go @@ -17,3 +17,17 @@ const ( // NoTLS forces the transaction cto be not encrypted NoTLS ) + +// String is a standard method to convert a TLSPolicy into a printable format +func (p TLSPolicy) String() string { + switch p { + case TLSMandatory: + return "TLSMandatory" + case TLSOpportunistic: + return "TLSOpportunistic" + case NoTLS: + return "NoTLS" + default: + return "UnknownPolicy" + } +}