From 800c266ccba97415266df0ccc0cbdf0e71cabcc3 Mon Sep 17 00:00:00 2001 From: Winni Neessen Date: Mon, 11 Nov 2024 19:53:14 +0100 Subject: [PATCH] Refactor test server code for thread safety Moved serverProps outside goroutines to improve code readability and maintainability. Added a RWMutex to serverProps to ensure thread-safe access to EchoBuffer, preventing race conditions during concurrent writes. --- smtp/smtp_test.go | 160 ++++++++++++++++++++++++++-------------------- 1 file changed, 92 insertions(+), 68 deletions(-) diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index 15ea7c5..931b74e 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -31,6 +31,7 @@ import ( "net" "os" "strings" + "sync" "sync/atomic" "testing" "time" @@ -2235,17 +2236,17 @@ func TestClient_Mail(t *testing.T) { serverPort := int(TestServerPortBase + PortAdder.Load()) featureSet := "250-8BITMIME\r\n250 STARTTLS" echoBuffer := bytes.NewBuffer(nil) - go func(buf *bytes.Buffer) { - if err := simpleSMTPServer(ctx, t, &serverProps{ - EchoBuffer: buf, - FeatureSet: featureSet, - ListenPort: serverPort, - }, - ); err != nil { + props := &serverProps{ + EchoBuffer: echoBuffer, + FeatureSet: featureSet, + ListenPort: serverPort, + } + go func() { + if err := simpleSMTPServer(ctx, t, props); err != nil { t.Errorf("failed to start test server: %s", err) return } - }(echoBuffer) + }() time.Sleep(time.Millisecond * 30) client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) @@ -2261,7 +2262,9 @@ func TestClient_Mail(t *testing.T) { t.Errorf("failed to set mail from address: %s", err) } expected := "MAIL FROM: BODY=8BITMIME" + props.BufferMutex.RLock() resp := strings.Split(echoBuffer.String(), "\r\n") + props.BufferMutex.RUnlock() if !strings.EqualFold(resp[5], expected) { t.Errorf("expected mail from command to be %q, but sent %q", expected, resp[5]) } @@ -2273,17 +2276,17 @@ func TestClient_Mail(t *testing.T) { serverPort := int(TestServerPortBase + PortAdder.Load()) featureSet := "250-SMTPUTF8\r\n250 STARTTLS" echoBuffer := bytes.NewBuffer(nil) - go func(buf *bytes.Buffer) { - if err := simpleSMTPServer(ctx, t, &serverProps{ - EchoBuffer: buf, - FeatureSet: featureSet, - ListenPort: serverPort, - }, - ); err != nil { + props := &serverProps{ + EchoBuffer: echoBuffer, + FeatureSet: featureSet, + ListenPort: serverPort, + } + go func() { + if err := simpleSMTPServer(ctx, t, props); err != nil { t.Errorf("failed to start test server: %s", err) return } - }(echoBuffer) + }() time.Sleep(time.Millisecond * 30) client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) @@ -2299,7 +2302,9 @@ func TestClient_Mail(t *testing.T) { t.Errorf("failed to set mail from address: %s", err) } expected := "MAIL FROM: SMTPUTF8" + props.BufferMutex.RLock() resp := strings.Split(echoBuffer.String(), "\r\n") + props.BufferMutex.RUnlock() if !strings.EqualFold(resp[5], expected) { t.Errorf("expected mail from command to be %q, but sent %q", expected, resp[5]) } @@ -2311,17 +2316,17 @@ func TestClient_Mail(t *testing.T) { serverPort := int(TestServerPortBase + PortAdder.Load()) featureSet := "250-SMTPUTF8\r\n250 STARTTLS" echoBuffer := bytes.NewBuffer(nil) - go func(buf *bytes.Buffer) { - if err := simpleSMTPServer(ctx, t, &serverProps{ - EchoBuffer: buf, - FeatureSet: featureSet, - ListenPort: serverPort, - }, - ); err != nil { + props := &serverProps{ + EchoBuffer: echoBuffer, + FeatureSet: featureSet, + ListenPort: serverPort, + } + go func() { + if err := simpleSMTPServer(ctx, t, props); err != nil { t.Errorf("failed to start test server: %s", err) return } - }(echoBuffer) + }() time.Sleep(time.Millisecond * 30) client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) @@ -2337,7 +2342,9 @@ func TestClient_Mail(t *testing.T) { t.Errorf("failed to set mail from address: %s", err) } expected := "MAIL FROM: SMTPUTF8" + props.BufferMutex.RLock() resp := strings.Split(echoBuffer.String(), "\r\n") + props.BufferMutex.RUnlock() if !strings.EqualFold(resp[5], expected) { t.Errorf("expected mail from command to be %q, but sent %q", expected, resp[5]) } @@ -2349,17 +2356,17 @@ func TestClient_Mail(t *testing.T) { serverPort := int(TestServerPortBase + PortAdder.Load()) featureSet := "250-DSN\r\n250 STARTTLS" echoBuffer := bytes.NewBuffer(nil) - go func(buf *bytes.Buffer) { - if err := simpleSMTPServer(ctx, t, &serverProps{ - EchoBuffer: buf, - FeatureSet: featureSet, - ListenPort: serverPort, - }, - ); err != nil { + props := &serverProps{ + EchoBuffer: echoBuffer, + FeatureSet: featureSet, + ListenPort: serverPort, + } + go func() { + if err := simpleSMTPServer(ctx, t, props); err != nil { t.Errorf("failed to start test server: %s", err) return } - }(echoBuffer) + }() time.Sleep(time.Millisecond * 30) client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) @@ -2376,7 +2383,9 @@ func TestClient_Mail(t *testing.T) { t.Errorf("failed to set mail from address: %s", err) } expected := "MAIL FROM: RET=FULL" + props.BufferMutex.RLock() resp := strings.Split(echoBuffer.String(), "\r\n") + props.BufferMutex.RUnlock() if !strings.EqualFold(resp[5], expected) { t.Errorf("expected mail from command to be %q, but sent %q", expected, resp[5]) } @@ -2388,17 +2397,17 @@ func TestClient_Mail(t *testing.T) { serverPort := int(TestServerPortBase + PortAdder.Load()) featureSet := "250-DSN\r\n250-8BITMIME\r\n250-SMTPUTF8\r\n250 STARTTLS" echoBuffer := bytes.NewBuffer(nil) - go func(buf *bytes.Buffer) { - if err := simpleSMTPServer(ctx, t, &serverProps{ - EchoBuffer: buf, - FeatureSet: featureSet, - ListenPort: serverPort, - }, - ); err != nil { + props := &serverProps{ + EchoBuffer: echoBuffer, + FeatureSet: featureSet, + ListenPort: serverPort, + } + go func() { + if err := simpleSMTPServer(ctx, t, props); err != nil { t.Errorf("failed to start test server: %s", err) return } - }(echoBuffer) + }() time.Sleep(time.Millisecond * 30) client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) @@ -2415,7 +2424,9 @@ func TestClient_Mail(t *testing.T) { t.Errorf("failed to set mail from address: %s", err) } expected := "MAIL FROM: BODY=8BITMIME SMTPUTF8 RET=FULL" + props.BufferMutex.RLock() resp := strings.Split(echoBuffer.String(), "\r\n") + props.BufferMutex.RUnlock() if !strings.EqualFold(resp[7], expected) { t.Errorf("expected mail from command to be %q, but sent %q", expected, resp[7]) } @@ -2490,17 +2501,17 @@ func TestClient_Rcpt(t *testing.T) { serverPort := int(TestServerPortBase + PortAdder.Load()) featureSet := "250-DSN\r\n250 STARTTLS" echoBuffer := bytes.NewBuffer(nil) - go func(buf *bytes.Buffer) { - if err := simpleSMTPServer(ctx, t, &serverProps{ - EchoBuffer: buf, - FeatureSet: featureSet, - ListenPort: serverPort, - }, - ); err != nil { + props := &serverProps{ + EchoBuffer: echoBuffer, + FeatureSet: featureSet, + ListenPort: serverPort, + } + go func() { + if err := simpleSMTPServer(ctx, t, props); err != nil { t.Errorf("failed to start test server: %s", err) return } - }(echoBuffer) + }() time.Sleep(time.Millisecond * 30) client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) if err != nil { @@ -2519,7 +2530,9 @@ func TestClient_Rcpt(t *testing.T) { t.Error("recpient address with newlines should fail") } expected := "RCPT TO: NOTIFY=SUCCESS" + props.BufferMutex.RLock() resp := strings.Split(echoBuffer.String(), "\r\n") + props.BufferMutex.RUnlock() if !strings.EqualFold(resp[5], expected) { t.Errorf("expected rcpt to command to be %q, but sent %q", expected, resp[5]) } @@ -2782,17 +2795,17 @@ func TestSendMail(t *testing.T) { serverPort := int(TestServerPortBase + PortAdder.Load()) featureSet := "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS" echoBuffer := bytes.NewBuffer(nil) - go func(buf *bytes.Buffer) { - if err := simpleSMTPServer(ctx, t, &serverProps{ - EchoBuffer: buf, - FeatureSet: featureSet, - ListenPort: serverPort, - }, - ); err != nil { + props := &serverProps{ + EchoBuffer: echoBuffer, + FeatureSet: featureSet, + ListenPort: serverPort, + } + go func() { + if err := simpleSMTPServer(ctx, t, props); err != nil { t.Errorf("failed to start test server: %s", err) return } - }(echoBuffer) + }() time.Sleep(time.Millisecond * 30) addr := fmt.Sprintf("%s:%d", TestServerAddr, serverPort) testHookStartTLS = func(config *tls.Config) { @@ -2806,7 +2819,9 @@ func TestSendMail(t *testing.T) { []byte("test message")); err != nil { t.Fatalf("failed to send mail: %s", err) } + props.BufferMutex.RLock() resp := strings.Split(echoBuffer.String(), "\r\n") + props.BufferMutex.RUnlock() if len(resp)-1 != len(want) { t.Fatalf("expected %d lines, but got %d", len(want), len(resp)) } @@ -2857,17 +2872,17 @@ func TestSendMail(t *testing.T) { serverPort := int(TestServerPortBase + PortAdder.Load()) featureSet := "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS" echoBuffer := bytes.NewBuffer(nil) - go func(buf *bytes.Buffer) { - if err := simpleSMTPServer(ctx, t, &serverProps{ - EchoBuffer: buf, - FeatureSet: featureSet, - ListenPort: serverPort, - }, - ); err != nil { + props := &serverProps{ + EchoBuffer: echoBuffer, + FeatureSet: featureSet, + ListenPort: serverPort, + } + go func() { + if err := simpleSMTPServer(ctx, t, props); err != nil { t.Errorf("failed to start test server: %s", err) return } - }(echoBuffer) + }() time.Sleep(time.Millisecond * 30) addr := fmt.Sprintf("%s:%d", TestServerAddr, serverPort) testHookStartTLS = func(config *tls.Config) { @@ -2887,7 +2902,9 @@ Goodbye.`) if err := SendMail(addr, auth, "valid-from@domain.tld", []string{"valid-to@domain.tld"}, message); err != nil { t.Fatalf("failed to send mail: %s", err) } + props.BufferMutex.RLock() resp := strings.Split(echoBuffer.String(), "\r\n") + props.BufferMutex.RUnlock() if len(resp)-1 != len(want) { t.Errorf("expected %d lines, but got %d", len(want), len(resp)) } @@ -3542,6 +3559,7 @@ func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", " // serverProps represents the configuration properties for the SMTP server. type serverProps struct { + BufferMutex sync.RWMutex EchoBuffer io.Writer FailOnAuth bool FailOnDataInit bool @@ -3640,9 +3658,11 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server t.Logf("failed to write line: %s", err) } if props.EchoBuffer != nil { - if _, err := props.EchoBuffer.Write([]byte(data + "\r\n")); err != nil { - t.Errorf("failed write to echo buffer: %s", err) + props.BufferMutex.Lock() + if _, berr := props.EchoBuffer.Write([]byte(data + "\r\n")); berr != nil { + t.Errorf("failed write to echo buffer: %s", berr) } + props.BufferMutex.Unlock() } _ = writer.Flush() } @@ -3665,9 +3685,11 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server } time.Sleep(time.Millisecond) if props.EchoBuffer != nil { + props.BufferMutex.Lock() if _, berr := props.EchoBuffer.Write([]byte(data)); berr != nil { t.Errorf("failed write to echo buffer: %s", berr) } + props.BufferMutex.Unlock() } var datastring string @@ -3768,9 +3790,11 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server break } if props.EchoBuffer != nil { - if _, err := props.EchoBuffer.Write([]byte(ddata)); err != nil { - t.Errorf("failed write to echo buffer: %s", err) + props.BufferMutex.Lock() + if _, berr := props.EchoBuffer.Write([]byte(ddata)); berr != nil { + t.Errorf("failed write to echo buffer: %s", berr) } + props.BufferMutex.Unlock() } ddata = strings.TrimSpace(ddata) if ddata == "." {