diff --git a/sime.go b/smime.go similarity index 81% rename from sime.go rename to smime.go index de2decc..e92d283 100644 --- a/sime.go +++ b/smime.go @@ -16,9 +16,6 @@ var ( // ErrInvalidKeyPair should be used if key pair is invalid ErrInvalidKeyPair = errors.New("invalid key pair") - // ErrInvalidCertificate should be used if a certificate is invalid - ErrInvalidCertificate = errors.New("invalid certificate") - // ErrCouldNotInitialize should be used if the signed data could not initialize ErrCouldNotInitialize = errors.New("could not initialize signed data") @@ -34,9 +31,8 @@ var ( // SMime is used to sign messages with S/MIME type SMime struct { - privateKey *rsa.PrivateKey - certificate *x509.Certificate - parentCertificates []*x509.Certificate + privateKey *rsa.PrivateKey + certificate *x509.Certificate } // NewSMime construct a new instance of SMime with a provided *tls.Certificate @@ -45,19 +41,9 @@ func newSMime(keyPair *tls.Certificate) (*SMime, error) { return nil, ErrInvalidKeyPair } - parentCertificates := make([]*x509.Certificate, 0) - for _, cert := range keyPair.Certificate[1:] { - c, err := x509.ParseCertificate(cert) - if err != nil { - return nil, ErrInvalidCertificate - } - parentCertificates = append(parentCertificates, c) - } - return &SMime{ - privateKey: keyPair.PrivateKey.(*rsa.PrivateKey), - certificate: keyPair.Leaf, - parentCertificates: parentCertificates, + privateKey: keyPair.PrivateKey.(*rsa.PrivateKey), + certificate: keyPair.Leaf, }, nil } @@ -72,7 +58,7 @@ func (sm *SMime) signMessage(message string) (*string, error) { return nil, ErrCouldNotInitialize } - if err = signedData.AddSignerChain(sm.certificate, sm.privateKey, sm.parentCertificates, pkcs7.SignerInfoConfig{}); err != nil { + if err = signedData.AddSigner(sm.certificate, sm.privateKey, pkcs7.SignerInfoConfig{}); err != nil { return nil, ErrCouldNotAddSigner } @@ -92,7 +78,7 @@ func (sm *SMime) signMessage(message string) (*string, error) { } // createMessage prepares the message that will be used for the sign method later -func (sm *SMime) createMessage(encoding Encoding, contentType ContentType, charset Charset, body []byte) string { +func (sm *SMime) prepareMessage(encoding Encoding, contentType ContentType, charset Charset, body []byte) string { return fmt.Sprintf("Content-Transfer-Encoding: %v\r\nContent-Type: %v; charset=%v\r\n\r\n%v", encoding, contentType, charset, string(body)) } diff --git a/smime_test.go b/smime_test.go index bd0282d..cfd8518 100644 --- a/smime_test.go +++ b/smime_test.go @@ -1,7 +1,8 @@ package mail import ( - "errors" + "bytes" + "encoding/base64" "fmt" "strings" "testing" @@ -25,9 +26,6 @@ func TestNewSMime(t *testing.T) { if sMime.certificate != keyPair.Leaf { t.Errorf("NewSMime() did not return the same leaf certificate") } - if len(sMime.parentCertificates) != len(keyPair.Certificate)-1 { - t.Errorf("NewSMime() did not return the same number of parentCertificates") - } } // TestSign tests the sign method @@ -41,11 +39,20 @@ func TestSign(t *testing.T) { if err != nil { t.Errorf("Error creating new SMime from keyPair: %s", err) } - fmt.Println(sMime) + + message := "This is a test message" + singedMessage, err := sMime.signMessage(message) + if err != nil { + t.Errorf("Error creating singed message: %s", err) + } + + if *singedMessage == message { + t.Errorf("Sign() did not work") + } } -// TestCreateMessage tests the createMessage method -func TestCreateMessage(t *testing.T) { +// TestPrepareMessage tests the createMessage method +func TestPrepareMessage(t *testing.T) { keyPair, err := getDummyCertificate() if err != nil { t.Errorf("Error getting dummy certificate: %s", err) @@ -60,7 +67,7 @@ func TestCreateMessage(t *testing.T) { contentType := TypeTextPlain charset := CharsetUTF8 body := []byte("This is the body!") - result := sMime.createMessage(encoding, contentType, body) + result := sMime.prepareMessage(encoding, contentType, charset, body) if !strings.Contains(result, encoding.String()) { t.Errorf("createMessage() did not return the correct encoding") @@ -71,86 +78,113 @@ func TestCreateMessage(t *testing.T) { if !strings.Contains(result, string(body)) { t.Errorf("createMessage() did not return the correct body") } - if result != fmt.Sprintf("Content-Transfer-Encoding: %v\r\nContent-Type: %v; charset=%v\r\n\r\n%v", encoding, contentType, charset, string(body)) { + if result != fmt.Sprintf("Content-Transfer-Encoding: %s\r\nContent-Type: %s; charset=%s\r\n\r\n%s", encoding, contentType, charset, string(body)) { t.Errorf("createMessage() did not sucessfully create the message") } } // TestEncodeToPEM tests the encodeToPEM method func TestEncodeToPEM(t *testing.T) { + message := []byte("This is a test message") - keyPair, err := getDummyCertificate() + pemMessage, err := encodeToPEM(message) if err != nil { - t.Errorf("Error getting dummy certificate: %s", err) + t.Errorf("Error encoding message: %s", err) } - sMime, err := newSMime(keyPair) - if err != nil { - t.Errorf("Error creating new SMime from keyPair: %s", err) + base64Encoded := base64.StdEncoding.EncodeToString(message) + if *pemMessage != base64Encoded { + t.Errorf("encodeToPEM() did not work") } - fmt.Println(sMime) } // TestBytesFromLines tests the bytesFromLines method func TestBytesFromLines(t *testing.T) { + ls := lines{ + {line: []byte("Hello"), endOfLine: []byte("\n")}, + {line: []byte("World"), endOfLine: []byte("\n")}, + } + expected := []byte("Hello\nWorld\n") + result := ls.bytesFromLines([]byte("\n")) + if !bytes.Equal(result, expected) { + t.Errorf("Expected %s, but got %s", expected, result) + } +} + +// FuzzBytesFromLines tests the bytesFromLines method with fuzzing +func FuzzBytesFromLines(f *testing.F) { + f.Add([]byte("Hello"), []byte("\n")) + f.Fuzz(func(t *testing.T, lineData, sep []byte) { + ls := lines{ + {line: lineData, endOfLine: sep}, + } + _ = ls.bytesFromLines(sep) + }) } // TestParseLines tests the parseLines method func TestParseLines(t *testing.T) { + input := []byte("Hello\r\nWorld\nHello\rWorld") + expected := lines{ + {line: []byte("Hello"), endOfLine: []byte("\r\n")}, + {line: []byte("World"), endOfLine: []byte("\n")}, + {line: []byte("Hello"), endOfLine: []byte("\r")}, + {line: []byte("World"), endOfLine: []byte("")}, + } + result := parseLines(input) + if len(result) != len(expected) { + t.Errorf("Expected %d lines, but got %d", len(expected), len(result)) + } + + for i := range result { + if !bytes.Equal(result[i].line, expected[i].line) || !bytes.Equal(result[i].endOfLine, expected[i].endOfLine) { + t.Errorf("Line %d mismatch. Expected line: %s, endOfLine: %s, got line: %s, endOfLine: %s", + i, expected[i].line, expected[i].endOfLine, result[i].line, result[i].endOfLine) + } + } +} + +// FuzzParseLines tests the parseLines method with fuzzing +func FuzzParseLines(f *testing.F) { + f.Add([]byte("Hello\nWorld\r\nAnother\rLine")) + f.Fuzz(func(t *testing.T, input []byte) { + _ = parseLines(input) + }) } // TestSplitLine tests the splitLine method func TestSplitLine(t *testing.T) { - -} - -func foo(t *testing.T) { - tl := []struct { - n string - r SendErrReason - te bool - }{ - {"ErrGetSender/temp", ErrGetSender, true}, - {"ErrGetSender/perm", ErrGetSender, false}, - {"ErrGetRcpts/temp", ErrGetRcpts, true}, - {"ErrGetRcpts/perm", ErrGetRcpts, false}, - {"ErrSMTPMailFrom/temp", ErrSMTPMailFrom, true}, - {"ErrSMTPMailFrom/perm", ErrSMTPMailFrom, false}, - {"ErrSMTPRcptTo/temp", ErrSMTPRcptTo, true}, - {"ErrSMTPRcptTo/perm", ErrSMTPRcptTo, false}, - {"ErrSMTPData/temp", ErrSMTPData, true}, - {"ErrSMTPData/perm", ErrSMTPData, false}, - {"ErrSMTPDataClose/temp", ErrSMTPDataClose, true}, - {"ErrSMTPDataClose/perm", ErrSMTPDataClose, false}, - {"ErrSMTPReset/temp", ErrSMTPReset, true}, - {"ErrSMTPReset/perm", ErrSMTPReset, false}, - {"ErrWriteContent/temp", ErrWriteContent, true}, - {"ErrWriteContent/perm", ErrWriteContent, false}, - {"ErrConnCheck/temp", ErrConnCheck, true}, - {"ErrConnCheck/perm", ErrConnCheck, false}, - {"ErrNoUnencoded/temp", ErrNoUnencoded, true}, - {"ErrNoUnencoded/perm", ErrNoUnencoded, false}, - {"ErrAmbiguous/temp", ErrAmbiguous, true}, - {"ErrAmbiguous/perm", ErrAmbiguous, false}, - {"Unknown/temp", 9999, true}, - {"Unknown/perm", 9999, false}, + ls := lines{ + {line: []byte("Hello\r\nWorld\r\nAnotherLine"), endOfLine: []byte("")}, + } + expected := lines{ + {line: []byte("Hello"), endOfLine: []byte("\r\n")}, + {line: []byte("World"), endOfLine: []byte("\r\n")}, + {line: []byte("AnotherLine"), endOfLine: []byte("")}, } - for _, tt := range tl { - t.Run(tt.n, func(t *testing.T) { - if err := returnSendError(tt.r, tt.te); err != nil { - exp := &SendError{Reason: tt.r, isTemp: tt.te} - if !errors.Is(err, exp) { - t.Errorf("error mismatch, expected: %s (temp: %t), got: %s (temp: %t)", tt.r, tt.te, - exp.Error(), exp.isTemp) - } - if !strings.Contains(fmt.Sprintf("%s", err), tt.r.String()) { - t.Errorf("error string mismatch, expected: %s, got: %s", - tt.r.String(), fmt.Sprintf("%s", err)) - } - } - }) + result := ls.splitLine([]byte("\r\n")) + if len(result) != len(expected) { + t.Errorf("Expected %d lines, but got %d", len(expected), len(result)) + } + + for i := range result { + if !bytes.Equal(result[i].line, expected[i].line) || !bytes.Equal(result[i].endOfLine, expected[i].endOfLine) { + t.Errorf("Line %d mismatch. Expected line: %s, endOfLine: %s, got line: %s, endOfLine: %s", + i, expected[i].line, expected[i].endOfLine, result[i].line, result[i].endOfLine) + } } } + +// FuzzSplitLine tests the parseLsplitLineines method with fuzzing +func FuzzSplitLine(f *testing.F) { + f.Add([]byte("Hello\r\nWorld"), []byte("\r\n")) + f.Fuzz(func(t *testing.T, input, sep []byte) { + ls := lines{ + {line: input, endOfLine: []byte("")}, + } + _ = ls.splitLine(sep) + }) +}