diff --git a/smtp/smtp.go b/smtp/smtp.go index 444b203..7c39996 100644 --- a/smtp/smtp.go +++ b/smtp/smtp.go @@ -36,7 +36,15 @@ import ( "github.com/wneessen/go-mail/log" ) -var ErrNonTLSConnection = errors.New("connection is not using TLS") +var ( + + // ErrNonTLSConnection is returned when an attempt is made to retrieve TLS state on a non-TLS connection. + ErrNonTLSConnection = errors.New("connection is not using TLS") + + // ErrNoConnection is returned when attempting to perform an operation that requires an established + // connection but none exists. + ErrNoConnection = errors.New("connection is not established") +) // A Client represents a client connection to an SMTP server. type Client struct { @@ -570,10 +578,10 @@ func (c *Client) HasConnection() bool { // UpdateDeadline sets a new deadline on the SMTP connection with the specified timeout duration. func (c *Client) UpdateDeadline(timeout time.Duration) error { c.mutex.Lock() + defer c.mutex.Unlock() if err := c.conn.SetDeadline(time.Now().Add(timeout)); err != nil { return fmt.Errorf("smtp: failed to update deadline: %w", err) } - c.mutex.Unlock() return nil } @@ -583,17 +591,18 @@ func (c *Client) GetTLSConnectionState() (*tls.ConnectionState, error) { c.mutex.RLock() defer c.mutex.RUnlock() + if !c.isConnected { + return nil, ErrNoConnection + + } if !c.tls { return nil, ErrNonTLSConnection } - if c.conn == nil { - return nil, errors.New("smtp: connection is not established") - } if conn, ok := c.conn.(*tls.Conn); ok { cstate := conn.ConnectionState() return &cstate, nil } - return nil, errors.New("smtp: connection is not a TLS connection") + return nil, errors.New("unable to retrieve TLS connection state") } // debugLog checks if the debug flag is set and if so logs the provided message to diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index 451a349..7b53963 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -1640,6 +1640,356 @@ func TestTLSConnState(t *testing.T) { <-serverDone } +func TestClient_GetTLSConnectionState(t *testing.T) { + ln := newLocalListener(t) + defer func() { + _ = ln.Close() + }() + clientDone := make(chan bool) + serverDone := make(chan bool) + go func() { + defer close(serverDone) + c, err := ln.Accept() + if err != nil { + t.Errorf("Server accept: %v", err) + return + } + defer func() { + _ = c.Close() + }() + if err := serverHandle(c, t); err != nil { + t.Errorf("server error: %v", err) + } + }() + go func() { + defer close(clientDone) + c, err := Dial(ln.Addr().String()) + if err != nil { + t.Errorf("Client dial: %v", err) + return + } + defer func() { + _ = c.Quit() + }() + cfg := &tls.Config{ServerName: "example.com"} + testHookStartTLS(cfg) // set the RootCAs + if err := c.StartTLS(cfg); err != nil { + t.Errorf("StartTLS: %v", err) + return + } + cs, err := c.GetTLSConnectionState() + if err != nil { + t.Errorf("failed to get TLSConnectionState: %s", err) + return + } + if cs.Version == 0 || !cs.HandshakeComplete { + t.Errorf("ConnectionState = %#v; expect non-zero Version and HandshakeComplete", cs) + } + }() + <-clientDone + <-serverDone +} + +func TestClient_GetTLSConnectionState_noTLS(t *testing.T) { + ln := newLocalListener(t) + defer func() { + _ = ln.Close() + }() + clientDone := make(chan bool) + serverDone := make(chan bool) + go func() { + defer close(serverDone) + c, err := ln.Accept() + if err != nil { + t.Errorf("Server accept: %v", err) + return + } + defer func() { + _ = c.Close() + }() + if err := serverHandle(c, t); err != nil { + t.Errorf("server error: %v", err) + } + }() + go func() { + defer close(clientDone) + c, err := Dial(ln.Addr().String()) + if err != nil { + t.Errorf("Client dial: %v", err) + return + } + defer func() { + _ = c.Quit() + }() + _, err = c.GetTLSConnectionState() + if err == nil { + t.Error("GetTLSConnectionState: expected error; got nil") + return + } + }() + <-clientDone + <-serverDone +} + +func TestClient_GetTLSConnectionState_noConn(t *testing.T) { + ln := newLocalListener(t) + defer func() { + _ = ln.Close() + }() + clientDone := make(chan bool) + serverDone := make(chan bool) + go func() { + defer close(serverDone) + c, err := ln.Accept() + if err != nil { + t.Errorf("Server accept: %v", err) + return + } + defer func() { + _ = c.Close() + }() + if err := serverHandle(c, t); err != nil { + t.Errorf("server error: %v", err) + } + }() + go func() { + defer close(clientDone) + c, err := Dial(ln.Addr().String()) + if err != nil { + t.Errorf("Client dial: %v", err) + return + } + _ = c.Close() + _, err = c.GetTLSConnectionState() + if err == nil { + t.Error("GetTLSConnectionState: expected error; got nil") + return + } + }() + <-clientDone + <-serverDone +} + +func TestClient_GetTLSConnectionState_unableErr(t *testing.T) { + ln := newLocalListener(t) + defer func() { + _ = ln.Close() + }() + clientDone := make(chan bool) + serverDone := make(chan bool) + go func() { + defer close(serverDone) + c, err := ln.Accept() + if err != nil { + t.Errorf("Server accept: %v", err) + return + } + defer func() { + _ = c.Close() + }() + if err := serverHandle(c, t); err != nil { + t.Errorf("server error: %v", err) + } + }() + go func() { + defer close(clientDone) + c, err := Dial(ln.Addr().String()) + if err != nil { + t.Errorf("Client dial: %v", err) + return + } + defer func() { + _ = c.Quit() + }() + c.tls = true + _, err = c.GetTLSConnectionState() + if err == nil { + t.Error("GetTLSConnectionState: expected error; got nil") + return + } + }() + <-clientDone + <-serverDone +} +func TestClient_HasConnection(t *testing.T) { + ln := newLocalListener(t) + defer func() { + _ = ln.Close() + }() + clientDone := make(chan bool) + serverDone := make(chan bool) + go func() { + defer close(serverDone) + c, err := ln.Accept() + if err != nil { + t.Errorf("Server accept: %v", err) + return + } + defer func() { + _ = c.Close() + }() + if err := serverHandle(c, t); err != nil { + t.Errorf("server error: %v", err) + } + }() + go func() { + defer close(clientDone) + c, err := Dial(ln.Addr().String()) + if err != nil { + t.Errorf("Client dial: %v", err) + return + } + cfg := &tls.Config{ServerName: "example.com"} + testHookStartTLS(cfg) // set the RootCAs + if err := c.StartTLS(cfg); err != nil { + t.Errorf("StartTLS: %v", err) + return + } + if !c.HasConnection() { + t.Error("HasConnection: expected true; got false") + return + } + if err = c.Quit(); err != nil { + t.Errorf("closing connection failed: %s", err) + return + } + if c.HasConnection() { + t.Error("HasConnection: expected false; got true") + } + }() + <-clientDone + <-serverDone +} + +func TestClient_SetDSNMailReturnOption(t *testing.T) { + ln := newLocalListener(t) + defer func() { + _ = ln.Close() + }() + clientDone := make(chan bool) + serverDone := make(chan bool) + go func() { + defer close(serverDone) + c, err := ln.Accept() + if err != nil { + t.Errorf("Server accept: %v", err) + return + } + defer func() { + _ = c.Close() + }() + if err := serverHandle(c, t); err != nil { + t.Errorf("server error: %v", err) + } + }() + go func() { + defer close(clientDone) + c, err := Dial(ln.Addr().String()) + if err != nil { + t.Errorf("Client dial: %v", err) + return + } + defer func() { + _ = c.Quit() + }() + c.SetDSNMailReturnOption("foo") + if c.dsnmrtype != "foo" { + t.Errorf("SetDSNMailReturnOption: expected %s; got %s", "foo", c.dsnrntype) + } + }() + <-clientDone + <-serverDone +} + +func TestClient_SetDSNRcptNotifyOption(t *testing.T) { + ln := newLocalListener(t) + defer func() { + _ = ln.Close() + }() + clientDone := make(chan bool) + serverDone := make(chan bool) + go func() { + defer close(serverDone) + c, err := ln.Accept() + if err != nil { + t.Errorf("Server accept: %v", err) + return + } + defer func() { + _ = c.Close() + }() + if err := serverHandle(c, t); err != nil { + t.Errorf("server error: %v", err) + } + }() + go func() { + defer close(clientDone) + c, err := Dial(ln.Addr().String()) + if err != nil { + t.Errorf("Client dial: %v", err) + return + } + defer func() { + _ = c.Quit() + }() + c.SetDSNRcptNotifyOption("foo") + if c.dsnrntype != "foo" { + t.Errorf("SetDSNMailReturnOption: expected %s; got %s", "foo", c.dsnrntype) + } + }() + <-clientDone + <-serverDone +} + +func TestClient_UpdateDeadline(t *testing.T) { + ln := newLocalListener(t) + defer func() { + _ = ln.Close() + }() + clientDone := make(chan bool) + serverDone := make(chan bool) + go func() { + defer close(serverDone) + c, err := ln.Accept() + if err != nil { + t.Errorf("Server accept: %v", err) + return + } + defer func() { + _ = c.Close() + }() + if err = serverHandle(c, t); err != nil { + t.Errorf("server error: %v", err) + } + }() + go func() { + defer close(clientDone) + c, err := Dial(ln.Addr().String()) + if err != nil { + t.Errorf("Client dial: %v", err) + return + } + defer func() { + _ = c.Close() + }() + if !c.HasConnection() { + t.Error("HasConnection: expected true; got false") + return + } + if err = c.UpdateDeadline(time.Millisecond * 20); err != nil { + t.Errorf("failed to update deadline: %s", err) + return + } + time.Sleep(time.Millisecond * 50) + if !c.HasConnection() { + t.Error("HasConnection: expected true; got false") + return + } + }() + <-clientDone + <-serverDone +} + func newLocalListener(t *testing.T) net.Listener { ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { @@ -1685,6 +2035,8 @@ func serverHandle(c net.Conn, t *testing.T) error { } config := &tls.Config{Certificates: []tls.Certificate{keypair}} return tf(config) + case "QUIT": + return nil default: t.Fatalf("unrecognized command: %q", s.Text()) }