Refactor test server code for thread safety

Moved serverProps outside goroutines to improve code readability and maintainability. Added a RWMutex to serverProps to ensure thread-safe access to EchoBuffer, preventing race conditions during concurrent writes.
This commit is contained in:
Winni Neessen 2024-11-11 19:53:14 +01:00
parent f7bdd8fffc
commit 800c266ccb
Signed by: wneessen
GPG key ID: 385AC9889632126E

View file

@ -31,6 +31,7 @@ import (
"net" "net"
"os" "os"
"strings" "strings"
"sync"
"sync/atomic" "sync/atomic"
"testing" "testing"
"time" "time"
@ -2235,17 +2236,17 @@ func TestClient_Mail(t *testing.T) {
serverPort := int(TestServerPortBase + PortAdder.Load()) serverPort := int(TestServerPortBase + PortAdder.Load())
featureSet := "250-8BITMIME\r\n250 STARTTLS" featureSet := "250-8BITMIME\r\n250 STARTTLS"
echoBuffer := bytes.NewBuffer(nil) echoBuffer := bytes.NewBuffer(nil)
go func(buf *bytes.Buffer) { props := &serverProps{
if err := simpleSMTPServer(ctx, t, &serverProps{ EchoBuffer: echoBuffer,
EchoBuffer: buf,
FeatureSet: featureSet, FeatureSet: featureSet,
ListenPort: serverPort, ListenPort: serverPort,
}, }
); err != nil { go func() {
if err := simpleSMTPServer(ctx, t, props); err != nil {
t.Errorf("failed to start test server: %s", err) t.Errorf("failed to start test server: %s", err)
return return
} }
}(echoBuffer) }()
time.Sleep(time.Millisecond * 30) time.Sleep(time.Millisecond * 30)
client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort))
@ -2261,7 +2262,9 @@ func TestClient_Mail(t *testing.T) {
t.Errorf("failed to set mail from address: %s", err) t.Errorf("failed to set mail from address: %s", err)
} }
expected := "MAIL FROM:<valid-from@domain.tld> BODY=8BITMIME" expected := "MAIL FROM:<valid-from@domain.tld> BODY=8BITMIME"
props.BufferMutex.RLock()
resp := strings.Split(echoBuffer.String(), "\r\n") resp := strings.Split(echoBuffer.String(), "\r\n")
props.BufferMutex.RUnlock()
if !strings.EqualFold(resp[5], expected) { if !strings.EqualFold(resp[5], expected) {
t.Errorf("expected mail from command to be %q, but sent %q", expected, resp[5]) t.Errorf("expected mail from command to be %q, but sent %q", expected, resp[5])
} }
@ -2273,17 +2276,17 @@ func TestClient_Mail(t *testing.T) {
serverPort := int(TestServerPortBase + PortAdder.Load()) serverPort := int(TestServerPortBase + PortAdder.Load())
featureSet := "250-SMTPUTF8\r\n250 STARTTLS" featureSet := "250-SMTPUTF8\r\n250 STARTTLS"
echoBuffer := bytes.NewBuffer(nil) echoBuffer := bytes.NewBuffer(nil)
go func(buf *bytes.Buffer) { props := &serverProps{
if err := simpleSMTPServer(ctx, t, &serverProps{ EchoBuffer: echoBuffer,
EchoBuffer: buf,
FeatureSet: featureSet, FeatureSet: featureSet,
ListenPort: serverPort, ListenPort: serverPort,
}, }
); err != nil { go func() {
if err := simpleSMTPServer(ctx, t, props); err != nil {
t.Errorf("failed to start test server: %s", err) t.Errorf("failed to start test server: %s", err)
return return
} }
}(echoBuffer) }()
time.Sleep(time.Millisecond * 30) time.Sleep(time.Millisecond * 30)
client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort))
@ -2299,7 +2302,9 @@ func TestClient_Mail(t *testing.T) {
t.Errorf("failed to set mail from address: %s", err) t.Errorf("failed to set mail from address: %s", err)
} }
expected := "MAIL FROM:<valid-from@domain.tld> SMTPUTF8" expected := "MAIL FROM:<valid-from@domain.tld> SMTPUTF8"
props.BufferMutex.RLock()
resp := strings.Split(echoBuffer.String(), "\r\n") resp := strings.Split(echoBuffer.String(), "\r\n")
props.BufferMutex.RUnlock()
if !strings.EqualFold(resp[5], expected) { if !strings.EqualFold(resp[5], expected) {
t.Errorf("expected mail from command to be %q, but sent %q", expected, resp[5]) t.Errorf("expected mail from command to be %q, but sent %q", expected, resp[5])
} }
@ -2311,17 +2316,17 @@ func TestClient_Mail(t *testing.T) {
serverPort := int(TestServerPortBase + PortAdder.Load()) serverPort := int(TestServerPortBase + PortAdder.Load())
featureSet := "250-SMTPUTF8\r\n250 STARTTLS" featureSet := "250-SMTPUTF8\r\n250 STARTTLS"
echoBuffer := bytes.NewBuffer(nil) echoBuffer := bytes.NewBuffer(nil)
go func(buf *bytes.Buffer) { props := &serverProps{
if err := simpleSMTPServer(ctx, t, &serverProps{ EchoBuffer: echoBuffer,
EchoBuffer: buf,
FeatureSet: featureSet, FeatureSet: featureSet,
ListenPort: serverPort, ListenPort: serverPort,
}, }
); err != nil { go func() {
if err := simpleSMTPServer(ctx, t, props); err != nil {
t.Errorf("failed to start test server: %s", err) t.Errorf("failed to start test server: %s", err)
return return
} }
}(echoBuffer) }()
time.Sleep(time.Millisecond * 30) time.Sleep(time.Millisecond * 30)
client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort))
@ -2337,7 +2342,9 @@ func TestClient_Mail(t *testing.T) {
t.Errorf("failed to set mail from address: %s", err) t.Errorf("failed to set mail from address: %s", err)
} }
expected := "MAIL FROM:<valid-from+📧@domain.tld> SMTPUTF8" expected := "MAIL FROM:<valid-from+📧@domain.tld> SMTPUTF8"
props.BufferMutex.RLock()
resp := strings.Split(echoBuffer.String(), "\r\n") resp := strings.Split(echoBuffer.String(), "\r\n")
props.BufferMutex.RUnlock()
if !strings.EqualFold(resp[5], expected) { if !strings.EqualFold(resp[5], expected) {
t.Errorf("expected mail from command to be %q, but sent %q", expected, resp[5]) t.Errorf("expected mail from command to be %q, but sent %q", expected, resp[5])
} }
@ -2349,17 +2356,17 @@ func TestClient_Mail(t *testing.T) {
serverPort := int(TestServerPortBase + PortAdder.Load()) serverPort := int(TestServerPortBase + PortAdder.Load())
featureSet := "250-DSN\r\n250 STARTTLS" featureSet := "250-DSN\r\n250 STARTTLS"
echoBuffer := bytes.NewBuffer(nil) echoBuffer := bytes.NewBuffer(nil)
go func(buf *bytes.Buffer) { props := &serverProps{
if err := simpleSMTPServer(ctx, t, &serverProps{ EchoBuffer: echoBuffer,
EchoBuffer: buf,
FeatureSet: featureSet, FeatureSet: featureSet,
ListenPort: serverPort, ListenPort: serverPort,
}, }
); err != nil { go func() {
if err := simpleSMTPServer(ctx, t, props); err != nil {
t.Errorf("failed to start test server: %s", err) t.Errorf("failed to start test server: %s", err)
return return
} }
}(echoBuffer) }()
time.Sleep(time.Millisecond * 30) time.Sleep(time.Millisecond * 30)
client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort))
@ -2376,7 +2383,9 @@ func TestClient_Mail(t *testing.T) {
t.Errorf("failed to set mail from address: %s", err) t.Errorf("failed to set mail from address: %s", err)
} }
expected := "MAIL FROM:<valid-from@domain.tld> RET=FULL" expected := "MAIL FROM:<valid-from@domain.tld> RET=FULL"
props.BufferMutex.RLock()
resp := strings.Split(echoBuffer.String(), "\r\n") resp := strings.Split(echoBuffer.String(), "\r\n")
props.BufferMutex.RUnlock()
if !strings.EqualFold(resp[5], expected) { if !strings.EqualFold(resp[5], expected) {
t.Errorf("expected mail from command to be %q, but sent %q", expected, resp[5]) t.Errorf("expected mail from command to be %q, but sent %q", expected, resp[5])
} }
@ -2388,17 +2397,17 @@ func TestClient_Mail(t *testing.T) {
serverPort := int(TestServerPortBase + PortAdder.Load()) serverPort := int(TestServerPortBase + PortAdder.Load())
featureSet := "250-DSN\r\n250-8BITMIME\r\n250-SMTPUTF8\r\n250 STARTTLS" featureSet := "250-DSN\r\n250-8BITMIME\r\n250-SMTPUTF8\r\n250 STARTTLS"
echoBuffer := bytes.NewBuffer(nil) echoBuffer := bytes.NewBuffer(nil)
go func(buf *bytes.Buffer) { props := &serverProps{
if err := simpleSMTPServer(ctx, t, &serverProps{ EchoBuffer: echoBuffer,
EchoBuffer: buf,
FeatureSet: featureSet, FeatureSet: featureSet,
ListenPort: serverPort, ListenPort: serverPort,
}, }
); err != nil { go func() {
if err := simpleSMTPServer(ctx, t, props); err != nil {
t.Errorf("failed to start test server: %s", err) t.Errorf("failed to start test server: %s", err)
return return
} }
}(echoBuffer) }()
time.Sleep(time.Millisecond * 30) time.Sleep(time.Millisecond * 30)
client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort))
@ -2415,7 +2424,9 @@ func TestClient_Mail(t *testing.T) {
t.Errorf("failed to set mail from address: %s", err) t.Errorf("failed to set mail from address: %s", err)
} }
expected := "MAIL FROM:<valid-from@domain.tld> BODY=8BITMIME SMTPUTF8 RET=FULL" expected := "MAIL FROM:<valid-from@domain.tld> BODY=8BITMIME SMTPUTF8 RET=FULL"
props.BufferMutex.RLock()
resp := strings.Split(echoBuffer.String(), "\r\n") resp := strings.Split(echoBuffer.String(), "\r\n")
props.BufferMutex.RUnlock()
if !strings.EqualFold(resp[7], expected) { if !strings.EqualFold(resp[7], expected) {
t.Errorf("expected mail from command to be %q, but sent %q", expected, resp[7]) t.Errorf("expected mail from command to be %q, but sent %q", expected, resp[7])
} }
@ -2490,17 +2501,17 @@ func TestClient_Rcpt(t *testing.T) {
serverPort := int(TestServerPortBase + PortAdder.Load()) serverPort := int(TestServerPortBase + PortAdder.Load())
featureSet := "250-DSN\r\n250 STARTTLS" featureSet := "250-DSN\r\n250 STARTTLS"
echoBuffer := bytes.NewBuffer(nil) echoBuffer := bytes.NewBuffer(nil)
go func(buf *bytes.Buffer) { props := &serverProps{
if err := simpleSMTPServer(ctx, t, &serverProps{ EchoBuffer: echoBuffer,
EchoBuffer: buf,
FeatureSet: featureSet, FeatureSet: featureSet,
ListenPort: serverPort, ListenPort: serverPort,
}, }
); err != nil { go func() {
if err := simpleSMTPServer(ctx, t, props); err != nil {
t.Errorf("failed to start test server: %s", err) t.Errorf("failed to start test server: %s", err)
return return
} }
}(echoBuffer) }()
time.Sleep(time.Millisecond * 30) time.Sleep(time.Millisecond * 30)
client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort))
if err != nil { if err != nil {
@ -2519,7 +2530,9 @@ func TestClient_Rcpt(t *testing.T) {
t.Error("recpient address with newlines should fail") t.Error("recpient address with newlines should fail")
} }
expected := "RCPT TO:<valid-to@domain.tld> NOTIFY=SUCCESS" expected := "RCPT TO:<valid-to@domain.tld> NOTIFY=SUCCESS"
props.BufferMutex.RLock()
resp := strings.Split(echoBuffer.String(), "\r\n") resp := strings.Split(echoBuffer.String(), "\r\n")
props.BufferMutex.RUnlock()
if !strings.EqualFold(resp[5], expected) { if !strings.EqualFold(resp[5], expected) {
t.Errorf("expected rcpt to command to be %q, but sent %q", expected, resp[5]) t.Errorf("expected rcpt to command to be %q, but sent %q", expected, resp[5])
} }
@ -2782,17 +2795,17 @@ func TestSendMail(t *testing.T) {
serverPort := int(TestServerPortBase + PortAdder.Load()) serverPort := int(TestServerPortBase + PortAdder.Load())
featureSet := "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS" featureSet := "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS"
echoBuffer := bytes.NewBuffer(nil) echoBuffer := bytes.NewBuffer(nil)
go func(buf *bytes.Buffer) { props := &serverProps{
if err := simpleSMTPServer(ctx, t, &serverProps{ EchoBuffer: echoBuffer,
EchoBuffer: buf,
FeatureSet: featureSet, FeatureSet: featureSet,
ListenPort: serverPort, ListenPort: serverPort,
}, }
); err != nil { go func() {
if err := simpleSMTPServer(ctx, t, props); err != nil {
t.Errorf("failed to start test server: %s", err) t.Errorf("failed to start test server: %s", err)
return return
} }
}(echoBuffer) }()
time.Sleep(time.Millisecond * 30) time.Sleep(time.Millisecond * 30)
addr := fmt.Sprintf("%s:%d", TestServerAddr, serverPort) addr := fmt.Sprintf("%s:%d", TestServerAddr, serverPort)
testHookStartTLS = func(config *tls.Config) { testHookStartTLS = func(config *tls.Config) {
@ -2806,7 +2819,9 @@ func TestSendMail(t *testing.T) {
[]byte("test message")); err != nil { []byte("test message")); err != nil {
t.Fatalf("failed to send mail: %s", err) t.Fatalf("failed to send mail: %s", err)
} }
props.BufferMutex.RLock()
resp := strings.Split(echoBuffer.String(), "\r\n") resp := strings.Split(echoBuffer.String(), "\r\n")
props.BufferMutex.RUnlock()
if len(resp)-1 != len(want) { if len(resp)-1 != len(want) {
t.Fatalf("expected %d lines, but got %d", len(want), len(resp)) t.Fatalf("expected %d lines, but got %d", len(want), len(resp))
} }
@ -2857,17 +2872,17 @@ func TestSendMail(t *testing.T) {
serverPort := int(TestServerPortBase + PortAdder.Load()) serverPort := int(TestServerPortBase + PortAdder.Load())
featureSet := "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS" featureSet := "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS"
echoBuffer := bytes.NewBuffer(nil) echoBuffer := bytes.NewBuffer(nil)
go func(buf *bytes.Buffer) { props := &serverProps{
if err := simpleSMTPServer(ctx, t, &serverProps{ EchoBuffer: echoBuffer,
EchoBuffer: buf,
FeatureSet: featureSet, FeatureSet: featureSet,
ListenPort: serverPort, ListenPort: serverPort,
}, }
); err != nil { go func() {
if err := simpleSMTPServer(ctx, t, props); err != nil {
t.Errorf("failed to start test server: %s", err) t.Errorf("failed to start test server: %s", err)
return return
} }
}(echoBuffer) }()
time.Sleep(time.Millisecond * 30) time.Sleep(time.Millisecond * 30)
addr := fmt.Sprintf("%s:%d", TestServerAddr, serverPort) addr := fmt.Sprintf("%s:%d", TestServerAddr, serverPort)
testHookStartTLS = func(config *tls.Config) { testHookStartTLS = func(config *tls.Config) {
@ -2887,7 +2902,9 @@ Goodbye.`)
if err := SendMail(addr, auth, "valid-from@domain.tld", []string{"valid-to@domain.tld"}, message); err != nil { if err := SendMail(addr, auth, "valid-from@domain.tld", []string{"valid-to@domain.tld"}, message); err != nil {
t.Fatalf("failed to send mail: %s", err) t.Fatalf("failed to send mail: %s", err)
} }
props.BufferMutex.RLock()
resp := strings.Split(echoBuffer.String(), "\r\n") resp := strings.Split(echoBuffer.String(), "\r\n")
props.BufferMutex.RUnlock()
if len(resp)-1 != len(want) { if len(resp)-1 != len(want) {
t.Errorf("expected %d lines, but got %d", len(want), len(resp)) t.Errorf("expected %d lines, but got %d", len(want), len(resp))
} }
@ -3542,6 +3559,7 @@ func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "
// serverProps represents the configuration properties for the SMTP server. // serverProps represents the configuration properties for the SMTP server.
type serverProps struct { type serverProps struct {
BufferMutex sync.RWMutex
EchoBuffer io.Writer EchoBuffer io.Writer
FailOnAuth bool FailOnAuth bool
FailOnDataInit bool FailOnDataInit bool
@ -3640,9 +3658,11 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server
t.Logf("failed to write line: %s", err) t.Logf("failed to write line: %s", err)
} }
if props.EchoBuffer != nil { if props.EchoBuffer != nil {
if _, err := props.EchoBuffer.Write([]byte(data + "\r\n")); err != nil { props.BufferMutex.Lock()
t.Errorf("failed write to echo buffer: %s", err) if _, berr := props.EchoBuffer.Write([]byte(data + "\r\n")); berr != nil {
t.Errorf("failed write to echo buffer: %s", berr)
} }
props.BufferMutex.Unlock()
} }
_ = writer.Flush() _ = writer.Flush()
} }
@ -3665,9 +3685,11 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server
} }
time.Sleep(time.Millisecond) time.Sleep(time.Millisecond)
if props.EchoBuffer != nil { if props.EchoBuffer != nil {
props.BufferMutex.Lock()
if _, berr := props.EchoBuffer.Write([]byte(data)); berr != nil { if _, berr := props.EchoBuffer.Write([]byte(data)); berr != nil {
t.Errorf("failed write to echo buffer: %s", berr) t.Errorf("failed write to echo buffer: %s", berr)
} }
props.BufferMutex.Unlock()
} }
var datastring string var datastring string
@ -3768,9 +3790,11 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server
break break
} }
if props.EchoBuffer != nil { if props.EchoBuffer != nil {
if _, err := props.EchoBuffer.Write([]byte(ddata)); err != nil { props.BufferMutex.Lock()
t.Errorf("failed write to echo buffer: %s", err) if _, berr := props.EchoBuffer.Write([]byte(ddata)); berr != nil {
t.Errorf("failed write to echo buffer: %s", berr)
} }
props.BufferMutex.Unlock()
} }
ddata = strings.TrimSpace(ddata) ddata = strings.TrimSpace(ddata)
if ddata == "." { if ddata == "." {