diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index 9dc3c10..6281666 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -1793,6 +1793,80 @@ func TestClient_StartTLS(t *testing.T) { }) } +func TestClient_TLSConnectionState(t *testing.T) { + t.Run("normal TLS connection should return a state", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-STARTTLS\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Errorf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + tlsConfig := &tls.Config{InsecureSkipVerify: true, MinVersion: tls.VersionTLS12} + if err = client.StartTLS(tlsConfig); err != nil { + t.Errorf("failed to initialize STARTTLS session: %s", err) + } + state, ok := client.TLSConnectionState() + if !ok { + t.Errorf("failed to get TLS connection state") + } + if state.Version < tls.VersionTLS12 { + t.Errorf("TLS connection state version is %d, should be >= %d", state.Version, tls.VersionTLS12) + } + }) + t.Run("no TLS state on non-TLS connection", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-STARTTLS\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Errorf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + _, ok := client.TLSConnectionState() + if ok { + t.Error("non-TLS connection should not have TLS connection state") + } + }) +} + // Issue 17794: don't send a trailing space on AUTH command when there's no password. func TestClient_Auth_trimSpace(t *testing.T) { server := "220 hello world\r\n" +