diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index 48a4515..3fb84de 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -3396,6 +3396,122 @@ func TestClient_UpdateDeadline(t *testing.T) { }) } +func TestClient_GetTLSConnectionState(t *testing.T) { + t.Run("get state on sane client connection", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + tlsConfig := getTLSConfig(t) + tlsConfig.MinVersion = tls.VersionTLS12 + tlsConfig.MaxVersion = tls.VersionTLS12 + if err = client.StartTLS(tlsConfig); err != nil { + t.Fatalf("failed to start TLS on client: %s", err) + } + state, err := client.GetTLSConnectionState() + if err != nil { + t.Fatalf("failed to get TLS connection state: %s", err) + } + if state == nil { + t.Error("expected TLS connection state to be non-nil") + } + if state.Version != tls.VersionTLS12 { + t.Errorf("expected TLS connection state version to be %d, got: %d", tls.VersionTLS12, state.Version) + } + }) + t.Run("get state on no connection", func(t *testing.T) { + client := &Client{} + _, err := client.GetTLSConnectionState() + if err == nil { + t.Fatal("expected client to have no tls connection state") + } + }) + t.Run("get state on non-tls client connection", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250 DSN" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + _, err = client.GetTLSConnectionState() + if err == nil { + t.Error("expected client to have no tls connection state") + } + }) + t.Run("fail to get state on non-tls connection with tls flag set", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250 DSN" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + client.tls = true + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + _, err = client.GetTLSConnectionState() + if err == nil { + t.Error("expected client to have no tls connection state") + } + }) +} + // faker is a struct embedding io.ReadWriter to simulate network connections for testing purposes. type faker struct { io.ReadWriter