mirror of
https://github.com/wneessen/go-mail.git
synced 2024-12-18 08:50:39 +01:00
Add tests for new tls and connection handling methods
This commit introduces tests for various TLS-related methods such as GetTLSConnectionState, HasConnection, SetDSNMailReturnOption, SetDSNRcptNotifyOption, and UpdateDeadline. It also modifies the error handling logic in smtp.go to include new error types and improves the mutex handling in UpdateDeadline.
This commit is contained in:
parent
9163943684
commit
159c1bf850
2 changed files with 367 additions and 6 deletions
21
smtp/smtp.go
21
smtp/smtp.go
|
@ -36,7 +36,15 @@ import (
|
|||
"github.com/wneessen/go-mail/log"
|
||||
)
|
||||
|
||||
var ErrNonTLSConnection = errors.New("connection is not using TLS")
|
||||
var (
|
||||
|
||||
// ErrNonTLSConnection is returned when an attempt is made to retrieve TLS state on a non-TLS connection.
|
||||
ErrNonTLSConnection = errors.New("connection is not using TLS")
|
||||
|
||||
// ErrNoConnection is returned when attempting to perform an operation that requires an established
|
||||
// connection but none exists.
|
||||
ErrNoConnection = errors.New("connection is not established")
|
||||
)
|
||||
|
||||
// A Client represents a client connection to an SMTP server.
|
||||
type Client struct {
|
||||
|
@ -570,10 +578,10 @@ func (c *Client) HasConnection() bool {
|
|||
// UpdateDeadline sets a new deadline on the SMTP connection with the specified timeout duration.
|
||||
func (c *Client) UpdateDeadline(timeout time.Duration) error {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
if err := c.conn.SetDeadline(time.Now().Add(timeout)); err != nil {
|
||||
return fmt.Errorf("smtp: failed to update deadline: %w", err)
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -583,17 +591,18 @@ func (c *Client) GetTLSConnectionState() (*tls.ConnectionState, error) {
|
|||
c.mutex.RLock()
|
||||
defer c.mutex.RUnlock()
|
||||
|
||||
if !c.isConnected {
|
||||
return nil, ErrNoConnection
|
||||
|
||||
}
|
||||
if !c.tls {
|
||||
return nil, ErrNonTLSConnection
|
||||
}
|
||||
if c.conn == nil {
|
||||
return nil, errors.New("smtp: connection is not established")
|
||||
}
|
||||
if conn, ok := c.conn.(*tls.Conn); ok {
|
||||
cstate := conn.ConnectionState()
|
||||
return &cstate, nil
|
||||
}
|
||||
return nil, errors.New("smtp: connection is not a TLS connection")
|
||||
return nil, errors.New("unable to retrieve TLS connection state")
|
||||
}
|
||||
|
||||
// debugLog checks if the debug flag is set and if so logs the provided message to
|
||||
|
|
|
@ -1640,6 +1640,356 @@ func TestTLSConnState(t *testing.T) {
|
|||
<-serverDone
|
||||
}
|
||||
|
||||
func TestClient_GetTLSConnectionState(t *testing.T) {
|
||||
ln := newLocalListener(t)
|
||||
defer func() {
|
||||
_ = ln.Close()
|
||||
}()
|
||||
clientDone := make(chan bool)
|
||||
serverDone := make(chan bool)
|
||||
go func() {
|
||||
defer close(serverDone)
|
||||
c, err := ln.Accept()
|
||||
if err != nil {
|
||||
t.Errorf("Server accept: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = c.Close()
|
||||
}()
|
||||
if err := serverHandle(c, t); err != nil {
|
||||
t.Errorf("server error: %v", err)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
defer close(clientDone)
|
||||
c, err := Dial(ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Errorf("Client dial: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = c.Quit()
|
||||
}()
|
||||
cfg := &tls.Config{ServerName: "example.com"}
|
||||
testHookStartTLS(cfg) // set the RootCAs
|
||||
if err := c.StartTLS(cfg); err != nil {
|
||||
t.Errorf("StartTLS: %v", err)
|
||||
return
|
||||
}
|
||||
cs, err := c.GetTLSConnectionState()
|
||||
if err != nil {
|
||||
t.Errorf("failed to get TLSConnectionState: %s", err)
|
||||
return
|
||||
}
|
||||
if cs.Version == 0 || !cs.HandshakeComplete {
|
||||
t.Errorf("ConnectionState = %#v; expect non-zero Version and HandshakeComplete", cs)
|
||||
}
|
||||
}()
|
||||
<-clientDone
|
||||
<-serverDone
|
||||
}
|
||||
|
||||
func TestClient_GetTLSConnectionState_noTLS(t *testing.T) {
|
||||
ln := newLocalListener(t)
|
||||
defer func() {
|
||||
_ = ln.Close()
|
||||
}()
|
||||
clientDone := make(chan bool)
|
||||
serverDone := make(chan bool)
|
||||
go func() {
|
||||
defer close(serverDone)
|
||||
c, err := ln.Accept()
|
||||
if err != nil {
|
||||
t.Errorf("Server accept: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = c.Close()
|
||||
}()
|
||||
if err := serverHandle(c, t); err != nil {
|
||||
t.Errorf("server error: %v", err)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
defer close(clientDone)
|
||||
c, err := Dial(ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Errorf("Client dial: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = c.Quit()
|
||||
}()
|
||||
_, err = c.GetTLSConnectionState()
|
||||
if err == nil {
|
||||
t.Error("GetTLSConnectionState: expected error; got nil")
|
||||
return
|
||||
}
|
||||
}()
|
||||
<-clientDone
|
||||
<-serverDone
|
||||
}
|
||||
|
||||
func TestClient_GetTLSConnectionState_noConn(t *testing.T) {
|
||||
ln := newLocalListener(t)
|
||||
defer func() {
|
||||
_ = ln.Close()
|
||||
}()
|
||||
clientDone := make(chan bool)
|
||||
serverDone := make(chan bool)
|
||||
go func() {
|
||||
defer close(serverDone)
|
||||
c, err := ln.Accept()
|
||||
if err != nil {
|
||||
t.Errorf("Server accept: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = c.Close()
|
||||
}()
|
||||
if err := serverHandle(c, t); err != nil {
|
||||
t.Errorf("server error: %v", err)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
defer close(clientDone)
|
||||
c, err := Dial(ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Errorf("Client dial: %v", err)
|
||||
return
|
||||
}
|
||||
_ = c.Close()
|
||||
_, err = c.GetTLSConnectionState()
|
||||
if err == nil {
|
||||
t.Error("GetTLSConnectionState: expected error; got nil")
|
||||
return
|
||||
}
|
||||
}()
|
||||
<-clientDone
|
||||
<-serverDone
|
||||
}
|
||||
|
||||
func TestClient_GetTLSConnectionState_unableErr(t *testing.T) {
|
||||
ln := newLocalListener(t)
|
||||
defer func() {
|
||||
_ = ln.Close()
|
||||
}()
|
||||
clientDone := make(chan bool)
|
||||
serverDone := make(chan bool)
|
||||
go func() {
|
||||
defer close(serverDone)
|
||||
c, err := ln.Accept()
|
||||
if err != nil {
|
||||
t.Errorf("Server accept: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = c.Close()
|
||||
}()
|
||||
if err := serverHandle(c, t); err != nil {
|
||||
t.Errorf("server error: %v", err)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
defer close(clientDone)
|
||||
c, err := Dial(ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Errorf("Client dial: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = c.Quit()
|
||||
}()
|
||||
c.tls = true
|
||||
_, err = c.GetTLSConnectionState()
|
||||
if err == nil {
|
||||
t.Error("GetTLSConnectionState: expected error; got nil")
|
||||
return
|
||||
}
|
||||
}()
|
||||
<-clientDone
|
||||
<-serverDone
|
||||
}
|
||||
func TestClient_HasConnection(t *testing.T) {
|
||||
ln := newLocalListener(t)
|
||||
defer func() {
|
||||
_ = ln.Close()
|
||||
}()
|
||||
clientDone := make(chan bool)
|
||||
serverDone := make(chan bool)
|
||||
go func() {
|
||||
defer close(serverDone)
|
||||
c, err := ln.Accept()
|
||||
if err != nil {
|
||||
t.Errorf("Server accept: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = c.Close()
|
||||
}()
|
||||
if err := serverHandle(c, t); err != nil {
|
||||
t.Errorf("server error: %v", err)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
defer close(clientDone)
|
||||
c, err := Dial(ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Errorf("Client dial: %v", err)
|
||||
return
|
||||
}
|
||||
cfg := &tls.Config{ServerName: "example.com"}
|
||||
testHookStartTLS(cfg) // set the RootCAs
|
||||
if err := c.StartTLS(cfg); err != nil {
|
||||
t.Errorf("StartTLS: %v", err)
|
||||
return
|
||||
}
|
||||
if !c.HasConnection() {
|
||||
t.Error("HasConnection: expected true; got false")
|
||||
return
|
||||
}
|
||||
if err = c.Quit(); err != nil {
|
||||
t.Errorf("closing connection failed: %s", err)
|
||||
return
|
||||
}
|
||||
if c.HasConnection() {
|
||||
t.Error("HasConnection: expected false; got true")
|
||||
}
|
||||
}()
|
||||
<-clientDone
|
||||
<-serverDone
|
||||
}
|
||||
|
||||
func TestClient_SetDSNMailReturnOption(t *testing.T) {
|
||||
ln := newLocalListener(t)
|
||||
defer func() {
|
||||
_ = ln.Close()
|
||||
}()
|
||||
clientDone := make(chan bool)
|
||||
serverDone := make(chan bool)
|
||||
go func() {
|
||||
defer close(serverDone)
|
||||
c, err := ln.Accept()
|
||||
if err != nil {
|
||||
t.Errorf("Server accept: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = c.Close()
|
||||
}()
|
||||
if err := serverHandle(c, t); err != nil {
|
||||
t.Errorf("server error: %v", err)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
defer close(clientDone)
|
||||
c, err := Dial(ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Errorf("Client dial: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = c.Quit()
|
||||
}()
|
||||
c.SetDSNMailReturnOption("foo")
|
||||
if c.dsnmrtype != "foo" {
|
||||
t.Errorf("SetDSNMailReturnOption: expected %s; got %s", "foo", c.dsnrntype)
|
||||
}
|
||||
}()
|
||||
<-clientDone
|
||||
<-serverDone
|
||||
}
|
||||
|
||||
func TestClient_SetDSNRcptNotifyOption(t *testing.T) {
|
||||
ln := newLocalListener(t)
|
||||
defer func() {
|
||||
_ = ln.Close()
|
||||
}()
|
||||
clientDone := make(chan bool)
|
||||
serverDone := make(chan bool)
|
||||
go func() {
|
||||
defer close(serverDone)
|
||||
c, err := ln.Accept()
|
||||
if err != nil {
|
||||
t.Errorf("Server accept: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = c.Close()
|
||||
}()
|
||||
if err := serverHandle(c, t); err != nil {
|
||||
t.Errorf("server error: %v", err)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
defer close(clientDone)
|
||||
c, err := Dial(ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Errorf("Client dial: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = c.Quit()
|
||||
}()
|
||||
c.SetDSNRcptNotifyOption("foo")
|
||||
if c.dsnrntype != "foo" {
|
||||
t.Errorf("SetDSNMailReturnOption: expected %s; got %s", "foo", c.dsnrntype)
|
||||
}
|
||||
}()
|
||||
<-clientDone
|
||||
<-serverDone
|
||||
}
|
||||
|
||||
func TestClient_UpdateDeadline(t *testing.T) {
|
||||
ln := newLocalListener(t)
|
||||
defer func() {
|
||||
_ = ln.Close()
|
||||
}()
|
||||
clientDone := make(chan bool)
|
||||
serverDone := make(chan bool)
|
||||
go func() {
|
||||
defer close(serverDone)
|
||||
c, err := ln.Accept()
|
||||
if err != nil {
|
||||
t.Errorf("Server accept: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = c.Close()
|
||||
}()
|
||||
if err = serverHandle(c, t); err != nil {
|
||||
t.Errorf("server error: %v", err)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
defer close(clientDone)
|
||||
c, err := Dial(ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Errorf("Client dial: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = c.Close()
|
||||
}()
|
||||
if !c.HasConnection() {
|
||||
t.Error("HasConnection: expected true; got false")
|
||||
return
|
||||
}
|
||||
if err = c.UpdateDeadline(time.Millisecond * 20); err != nil {
|
||||
t.Errorf("failed to update deadline: %s", err)
|
||||
return
|
||||
}
|
||||
time.Sleep(time.Millisecond * 50)
|
||||
if !c.HasConnection() {
|
||||
t.Error("HasConnection: expected true; got false")
|
||||
return
|
||||
}
|
||||
}()
|
||||
<-clientDone
|
||||
<-serverDone
|
||||
}
|
||||
|
||||
func newLocalListener(t *testing.T) net.Listener {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
|
@ -1685,6 +2035,8 @@ func serverHandle(c net.Conn, t *testing.T) error {
|
|||
}
|
||||
config := &tls.Config{Certificates: []tls.Certificate{keypair}}
|
||||
return tf(config)
|
||||
case "QUIT":
|
||||
return nil
|
||||
default:
|
||||
t.Fatalf("unrecognized command: %q", s.Text())
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue