go-mail/msgwriter.go
Winni Neessen 263f6bb3de
Refactor SMTP client code for better readability
The variable names in the code related to the I/O of the SMTP client have been clarified for improved readability and comprehension. For example, unclear variable names like `d` and `w` have been replaced with more meaningful names like `depth` and `writer`. The same naming improvements have also been applied to function parameters. This update aims to enhance code maintenance and simplify future development processes.
2024-02-24 12:43:01 +01:00

379 lines
9.5 KiB
Go

// SPDX-FileCopyrightText: 2022-2023 The go-mail Authors
//
// SPDX-License-Identifier: MIT
package mail
import (
"bytes"
"encoding/base64"
"fmt"
"io"
"mime"
"mime/multipart"
"mime/quotedprintable"
"net/textproto"
"path/filepath"
"sort"
"strings"
)
// Line length limits and line terminators used when rendering a mail message.
const (
	// MaxHeaderLength is the maximum length of a single mail header line.
	// RFC 2047 suggests 76 characters.
	MaxHeaderLength = 76

	// MaxBodyLength is the maximum length of a single line of the mail body.
	// RFC 2045 limits encoded body lines to 76 characters.
	MaxBodyLength = 76

	// SingleNewLine is the CRLF sequence the msgWriter uses to terminate a line.
	SingleNewLine = "\r\n"

	// DoubleNewLine is two consecutive CRLF sequences. The msgWriter uses it to
	// indicate a new segment of the mail.
	DoubleNewLine = SingleNewLine + SingleNewLine
)
// msgWriter handles the I/O to the io.WriteCloser of the SMTP client.
// It renders a Msg (headers, multipart structure, parts, embeds and
// attachments) into writer and keeps track of the first failure in err.
type msgWriter struct {
// bytesWritten is the running count of bytes written. It is updated by
// Write, writeString and, at depth 0, by writeBody.
bytesWritten int64
// charset is the charset name handed to encoder when file names are encoded.
charset Charset
// depth is the number of currently open multipart levels. At depth 0 output
// goes directly to writer; otherwise it goes through partWriter.
depth int8
// encoder is the MIME word encoder applied to file names of attachments and
// embeds.
encoder mime.WordEncoder
// err holds the recorded error. While it is non-nil, Write returns an error
// and writeString does nothing.
err error
// multiPartWriter holds the multipart writer of each open level, indexed by
// depth. The array size caps the nesting at three levels.
multiPartWriter [3]*multipart.Writer
// partWriter is the writer of the most recently created MIME part (see
// newPart). writeBody writes to it when depth is greater than 0.
partWriter io.Writer
// writer is the destination the rendered message is written to.
writer io.Writer
}
// Write implements the io.Writer interface for msgWriter. The payload is
// forwarded to the underlying writer and the number of bytes it accepted is
// added to the byte counter. Once an error has been recorded nothing is
// written anymore; the recorded error is returned wrapped instead.
func (mw *msgWriter) Write(payload []byte) (int, error) {
	if prev := mw.err; prev != nil {
		return 0, fmt.Errorf("failed to write due to previous error: %w", prev)
	}
	written, err := mw.writer.Write(payload)
	mw.err = err
	mw.bytesWritten += int64(written)
	return written, err
}
// writeMsg formats the message and sends it to its io.Writer.
//
// The output is produced in this order: generic headers, preformatted
// headers, the From header, the To and Cc headers, the opening of every
// multipart level the message requires, the message parts, the embedded
// files and finally the attachments. Multipart levels are closed again in
// reverse order. Failures are recorded in mw.err by the helpers called here.
func (mw *msgWriter) writeMsg(msg *Msg) {
	msg.addDefaultHeader()
	msg.checkUserAgent()
	mw.writeGenHeader(msg)
	mw.writePreformattedGenHeader(msg)

	// Set the FROM header (or envelope FROM if FROM is empty). Looking up a
	// missing key yields a nil slice, so a length check covers all cases.
	from := msg.addrHeader[HeaderFrom]
	if len(from) == 0 {
		from = msg.addrHeader[HeaderEnvelopeFrom]
	}
	if len(from) > 0 && from[0] != nil {
		mw.writeHeader(Header(HeaderFrom), from[0].String())
	}

	// Set the rest of the address headers
	for _, to := range []AddrHeader{HeaderTo, HeaderCc} {
		addresses, ok := msg.addrHeader[to]
		if !ok {
			continue
		}
		val := make([]string, 0, len(addresses))
		for _, addr := range addresses {
			// Skip nil entries, mirroring the nil guard of the From header
			if addr == nil {
				continue
			}
			val = append(val, addr.String())
		}
		mw.writeHeader(Header(to), val...)
	}

	if msg.hasMixed() {
		mw.startMP("mixed", msg.boundary)
		mw.writeString(DoubleNewLine)
	}
	if msg.hasRelated() {
		mw.startMP("related", msg.boundary)
		mw.writeString(DoubleNewLine)
	}
	if msg.hasAlt() {
		mw.startMP(MIMEAlternative, msg.boundary)
		mw.writeString(DoubleNewLine)
	}
	if msg.hasPGPType() {
		switch msg.pgptype {
		case PGPEncrypt:
			mw.startMP(`encrypted; protocol="application/pgp-encrypted"`,
				msg.boundary)
		case PGPSignature:
			// No trailing semicolon here: startMP appends ";" itself before
			// the boundary parameter, so a trailing one would produce an
			// empty ";;" parameter in the Content-Type header.
			mw.startMP(`signed; protocol="application/pgp-signature"`,
				msg.boundary)
		default:
		}
		mw.writeString(DoubleNewLine)
	}

	for _, part := range msg.parts {
		if !part.del {
			mw.writePart(part, msg.charset)
		}
	}
	if msg.hasAlt() {
		mw.stopMP()
	}

	// Add embeds
	mw.addFiles(msg.embeds, false)
	if msg.hasRelated() {
		mw.stopMP()
	}

	// Add attachments
	mw.addFiles(msg.attachments, true)
	if msg.hasMixed() {
		mw.stopMP()
	}
}
// writeGenHeader writes out all generic headers to the msgWriter. The header
// names are sorted first so that the headers are emitted in a stable,
// alphabetical order.
func (mw *msgWriter) writeGenHeader(msg *Msg) {
	names := make([]string, len(msg.genHeader))
	idx := 0
	for name := range msg.genHeader {
		names[idx] = string(name)
		idx++
	}
	sort.Strings(names)
	for i := range names {
		header := Header(names[i])
		mw.writeHeader(header, msg.genHeader[header]...)
	}
}
// writePreformattedGenHeader writes out all preformatted generic headers to
// the msgWriter. The values are written as-is, without any folding or
// encoding. The header names are sorted first so that the output order is
// deterministic, in line with writeGenHeader.
func (mw *msgWriter) writePreformattedGenHeader(msg *Msg) {
	keys := make([]string, 0, len(msg.preformHeader))
	for key := range msg.preformHeader {
		keys = append(keys, string(key))
	}
	sort.Strings(keys)
	for _, key := range keys {
		header := Header(key)
		mw.writeString(fmt.Sprintf("%s: %s%s", header, msg.preformHeader[header], SingleNewLine))
	}
}
// startMP writes a multipart beginning.
//
// A new multipart writer is created for the current depth. If boundary is
// not empty it is used as the multipart boundary, otherwise the randomly
// generated boundary of the multipart writer is kept. At depth 0 the
// multipart Content-Type is written as a message header (without a trailing
// line break); at deeper levels it becomes the header of a new part of the
// enclosing multipart. On success depth is increased by one.
func (mw *msgWriter) startMP(mimeType MIMEType, boundary string) {
	// Refuse to nest deeper than the writer array allows instead of panicking
	// with an index out of range.
	if int(mw.depth) >= len(mw.multiPartWriter) {
		if mw.err == nil {
			mw.err = fmt.Errorf("multipart nesting exceeds the supported depth of %d",
				len(mw.multiPartWriter))
		}
		return
	}

	multiPartWriter := multipart.NewWriter(mw)
	if boundary != "" {
		// Only record the error if none is pending. A successful SetBoundary
		// must not reset an earlier failure, since that would re-enable writes.
		if err := multiPartWriter.SetBoundary(boundary); err != nil && mw.err == nil {
			mw.err = err
		}
	}

	// SetBoundary accepts characters that are not allowed in an unquoted
	// header parameter value, so such boundaries have to be quoted.
	boundaryParam := multiPartWriter.Boundary()
	if strings.ContainsAny(boundaryParam, `()<>@,;:\"/[]?= `) {
		boundaryParam = `"` + boundaryParam + `"`
	}
	contentType := fmt.Sprintf("multipart/%s;\r\n boundary=%s", mimeType, boundaryParam)

	mw.multiPartWriter[mw.depth] = multiPartWriter
	if mw.depth == 0 {
		mw.writeString(fmt.Sprintf("%s: %s", HeaderContentType, contentType))
	}
	if mw.depth > 0 {
		mw.newPart(map[string][]string{"Content-Type": {contentType}})
	}
	mw.depth++
}
// stopMP closes the innermost open multipart and steps one level up. The
// result of closing the multipart writer is stored as the writer's error
// state. Nothing happens if no multipart is open.
func (mw *msgWriter) stopMP() {
	if mw.depth <= 0 {
		return
	}
	level := mw.depth - 1
	mw.err = mw.multiPartWriter[level].Close()
	mw.depth = level
}
// addFiles adds the attachments/embeds file content to the mail body.
//
// For every file the headers Content-Type, Content-Transfer-Encoding,
// Content-Description (only if the file has a description),
// Content-Disposition and, for embeds, Content-ID are set on the file unless
// they are already present. Headers and body are then written either directly
// (depth 0) or as a new part of the current multipart (depth > 0).
//
// isAttachment selects the "attachment" disposition; otherwise the file is
// treated as an embed with the "inline" disposition.
func (mw *msgWriter) addFiles(files []*File, isAttachment bool) {
	for _, file := range files {
		// Base64 is used unless the file requests another encoding below
		encoding := EncodingB64
		if _, ok := file.getHeader(HeaderContentType); !ok {
			mimeType := mime.TypeByExtension(filepath.Ext(file.Name))
			if mimeType == "" {
				mimeType = "application/octet-stream"
			}
			// An explicitly set content type wins over the detected one
			if file.ContentType != "" {
				mimeType = string(file.ContentType)
			}
			file.setHeader(HeaderContentType, fmt.Sprintf(`%s; name="%s"`, mimeType,
				mw.encoder.Encode(mw.charset.String(), file.Name)))
		}

		if _, ok := file.getHeader(HeaderContentTransferEnc); !ok {
			if file.Enc != "" {
				encoding = file.Enc
			}
			file.setHeader(HeaderContentTransferEnc, string(encoding))
		}

		if file.Desc != "" {
			if _, ok := file.getHeader(HeaderContentDescription); !ok {
				file.setHeader(HeaderContentDescription, file.Desc)
			}
		}

		if _, ok := file.getHeader(HeaderContentDisposition); !ok {
			disposition := "inline"
			if isAttachment {
				disposition = "attachment"
			}
			file.setHeader(HeaderContentDisposition, fmt.Sprintf(`%s; filename="%s"`,
				disposition, mw.encoder.Encode(mw.charset.String(), file.Name)))
		}

		if !isAttachment {
			if _, ok := file.getHeader(HeaderContentID); !ok {
				file.setHeader(HeaderContentID, fmt.Sprintf("<%s>", file.Name))
			}
		}

		if mw.depth == 0 {
			// Write the headers in sorted order so that the output does not
			// depend on the random iteration order of the header map.
			keys := make([]string, 0, len(file.Header))
			for key := range file.Header {
				keys = append(keys, key)
			}
			sort.Strings(keys)
			for _, key := range keys {
				mw.writeHeader(Header(key), file.Header[key]...)
			}
			mw.writeString(SingleNewLine)
		}
		if mw.depth > 0 {
			mw.newPart(file.Header)
		}

		// Skip the body if writing the headers already failed
		if mw.err == nil {
			mw.writeBody(file.Writer, encoding)
		}
	}
}
// newPart creates a new MIME part with the given header in the innermost open
// multipart. The writer of that part becomes the msgWriter's part writer and
// the result of the creation becomes its error state. It must only be called
// while a multipart is open (depth > 0).
func (mw *msgWriter) newPart(header map[string][]string) {
	current := mw.multiPartWriter[mw.depth-1]
	part, err := current.CreatePart(header)
	mw.partWriter = part
	mw.err = err
}
// writePart writes the corresponding part to the Msg body.
//
// The charset of the part is used for the Content-Type header; if the part
// has none, the given message charset is used instead. At depth 0 the part
// headers are written directly as message headers followed by an empty line,
// at deeper levels they become the header of a new MIME part. A non-empty
// part description is written as Content-Description header in both cases.
func (mw *msgWriter) writePart(part *Part, charset Charset) {
	partCharset := part.cset
	if partCharset.String() == "" {
		partCharset = charset
	}
	contentType := fmt.Sprintf("%s; charset=%s", part.ctype, partCharset)
	contentTransferEnc := part.enc.String()

	if mw.depth == 0 {
		// Keep the description for non-multipart messages as well, so that
		// it is not silently dropped
		if part.desc != "" {
			mw.writeHeader(HeaderContentDescription, part.desc)
		}
		mw.writeHeader(HeaderContentType, contentType)
		mw.writeHeader(HeaderContentTransferEnc, contentTransferEnc)
		mw.writeString(SingleNewLine)
	}
	if mw.depth > 0 {
		mimeHeader := textproto.MIMEHeader{}
		if part.desc != "" {
			mimeHeader.Add(string(HeaderContentDescription), part.desc)
		}
		mimeHeader.Add(string(HeaderContentType), contentType)
		mimeHeader.Add(string(HeaderContentTransferEnc), contentTransferEnc)
		mw.newPart(mimeHeader)
	}
	mw.writeBody(part.w, part.enc)
}
// writeString writes a string into the msgWriter's io.Writer interface and
// adds the number of bytes written to the byte counter. It does nothing if
// an error has been recorded before.
func (mw *msgWriter) writeString(s string) {
	if mw.err == nil {
		written, err := io.WriteString(mw.writer, s)
		mw.err = err
		mw.bytesWritten += int64(written)
	}
}
// writeHeader writes a header into the msgWriter's io.Writer.
//
// The values are joined with ", " and the resulting line is folded at space
// characters so that a line stays within MaxHeaderLength where possible.
// Continuation lines start with a single space. A header without any values
// is not written at all.
func (mw *msgWriter) writeHeader(key Header, values ...string) {
	// Nothing is emitted for a header without values. This guard makes that
	// explicit instead of filling a buffer that is thrown away afterwards.
	if len(values) == 0 {
		return
	}

	var buf strings.Builder
	// Characters left on the current line; two are reserved for the CRLF
	remaining := MaxHeaderLength - 2
	buf.WriteString(string(key))
	remaining -= len(key)
	buf.WriteString(": ")
	remaining -= 2

	tokens := strings.Split(strings.Join(values, ", "), " ")
	for i, token := range tokens {
		// Fold before the token if it does not fit on the current line
		if remaining-len(token) <= 1 {
			buf.WriteString(SingleNewLine + " ")
			remaining = MaxHeaderLength - 3
		}
		buf.WriteString(token)
		if i < len(tokens)-1 {
			buf.WriteString(" ")
			remaining--
		}
		remaining -= len(token)
	}

	// Folding leaves the separating space in front of the line break; drop it
	line := strings.ReplaceAll(buf.String(), " "+SingleNewLine, SingleNewLine)
	mw.writeString(line)
	mw.writeString(SingleNewLine)
}
// writeBody writes the content produced by f to the current destination
// using the provided Encoding.
//
// f is called with an io.Writer and is expected to write the raw content to
// it. The content is encoded into an intermediate buffer and then copied to
// the destination, which is the msgWriter's writer at depth 0 and the current
// part writer otherwise. EncodingB64 produces line-wrapped Base64, NoEncoding
// copies the content unchanged, and every other encoding (including
// EncodingQP) produces quoted-printable output. The first error that occurs
// is recorded in mw.err.
func (mw *msgWriter) writeBody(f func(io.Writer) (int64, error), e Encoding) {
	var w io.Writer
	var ew io.WriteCloser
	var n int64
	var err error
	if mw.depth == 0 {
		w = mw.writer
	}
	if mw.depth > 0 {
		w = mw.partWriter
	}

	wbuf := bytes.Buffer{}
	lb := Base64LineBreaker{}
	lb.out = &wbuf

	switch e {
	case EncodingB64:
		ew = base64.NewEncoder(base64.StdEncoding, &lb)
	case NoEncoding:
		_, err = f(&wbuf)
		if err != nil {
			mw.err = fmt.Errorf("bodyWriter function: %w", err)
		}
		n, err = io.Copy(w, &wbuf)
		if err != nil && mw.err == nil {
			mw.err = fmt.Errorf("bodyWriter io.Copy: %w", err)
		}
		if mw.depth == 0 {
			mw.bytesWritten += n
		}
		return
	default:
		// EncodingQP and any other encoding are written as quoted-printable.
		// The encoder has to target the buffer, like the other encodings do,
		// so that the bytes are included in the byte count below.
		ew = quotedprintable.NewWriter(&wbuf)
	}

	_, err = f(ew)
	if err != nil {
		mw.err = fmt.Errorf("bodyWriter function: %w", err)
	}
	err = ew.Close()
	if err != nil && mw.err == nil {
		mw.err = fmt.Errorf("bodyWriter close encoded writer: %w", err)
	}
	err = lb.Close()
	if err != nil && mw.err == nil {
		mw.err = fmt.Errorf("bodyWriter close linebreaker: %w", err)
	}
	n, err = io.Copy(w, &wbuf)
	if err != nil && mw.err == nil {
		mw.err = fmt.Errorf("bodyWriter io.Copy: %w", err)
	}

	// Only count the bytes here at depth 0. At deeper levels the part writer
	// forwards to mw.Write, which already counts them, so adding them here
	// would count them twice.
	if mw.depth == 0 {
		mw.bytesWritten += n
	}
}