Merge branch 'refs/heads/main' into feature/145_add-eml-parser-to-generate-msg-from-eml-files

This commit is contained in:
Winni Neessen 2024-05-16 15:17:47 +02:00
commit 2e548512c6
Signed by: wneessen
GPG key ID: 385AC9889632126E
34 changed files with 1521 additions and 1094 deletions

15
.github/dependabot.yml vendored Normal file
View file

@ -0,0 +1,15 @@
# SPDX-FileCopyrightText: 2022-2023 The go-mail Authors
#
# SPDX-License-Identifier: CC0-1.0
version: 2
updates:
- package-ecosystem: github-actions
directory: /
schedule:
interval: daily
- package-ecosystem: gomod
directory: /
schedule:
interval: daily

View file

@ -27,6 +27,9 @@ env:
TEST_SMTPAUTH_USER: ${{ secrets.TEST_USER }} TEST_SMTPAUTH_USER: ${{ secrets.TEST_USER }}
TEST_SMTPAUTH_PASS: ${{ secrets.TEST_PASS }} TEST_SMTPAUTH_PASS: ${{ secrets.TEST_PASS }}
TEST_SMTPAUTH_TYPE: "LOGIN" TEST_SMTPAUTH_TYPE: "LOGIN"
permissions:
contents: read
jobs: jobs:
run: run:
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
@ -35,10 +38,15 @@ jobs:
os: [ubuntu-latest, macos-latest, windows-latest] os: [ubuntu-latest, macos-latest, windows-latest]
go: [1.18, 1.19, '1.20', '1.21', '1.22'] go: [1.18, 1.19, '1.20', '1.21', '1.22']
steps: steps:
- name: Harden Runner
uses: step-security/harden-runner@a4aa98b93cab29d9b1101a6143fb8bce00e2eac4 # v2.7.1
with:
egress-policy: audit
- name: Checkout Code - name: Checkout Code
uses: actions/checkout@master uses: actions/checkout@61b9e3751b92087fd0b06925ba6dd6314e06f089 # master
- name: Setup go - name: Setup go
uses: actions/setup-go@v3 uses: actions/setup-go@cdcb36043654635271a94b9a6d1392de5bb323a7 # v5.0.1
with: with:
go-version: ${{ matrix.go }} go-version: ${{ matrix.go }}
- name: Install sendmail - name: Install sendmail
@ -50,6 +58,6 @@ jobs:
go test -v -race --coverprofile=coverage.coverprofile --covermode=atomic ./... go test -v -race --coverprofile=coverage.coverprofile --covermode=atomic ./...
- name: Upload coverage to Codecov - name: Upload coverage to Codecov
if: success() && matrix.go == '1.22' && matrix.os == 'ubuntu-latest' if: success() && matrix.go == '1.22' && matrix.os == 'ubuntu-latest'
uses: codecov/codecov-action@v3 uses: codecov/codecov-action@6d798873df2b1b8e5846dba6fb86631229fbcb17 # v4.4.0
with: with:
token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos

View file

@ -24,6 +24,9 @@ on:
schedule: schedule:
- cron: '37 23 * * 5' - cron: '37 23 * * 5'
permissions:
contents: read
jobs: jobs:
analyze: analyze:
name: Analyze name: Analyze
@ -41,12 +44,17 @@ jobs:
# Learn more about CodeQL language support at https://git.io/codeql-language-support # Learn more about CodeQL language support at https://git.io/codeql-language-support
steps: steps:
- name: Harden Runner
uses: step-security/harden-runner@a4aa98b93cab29d9b1101a6143fb8bce00e2eac4 # v2.7.1
with:
egress-policy: audit
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v2 uses: actions/checkout@ee0669bd1cc54295c223e0bb666b733df41de1c5 # v2.7.0
# Initializes the CodeQL tools for scanning. # Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL - name: Initialize CodeQL
uses: github/codeql-action/init@v1 uses: github/codeql-action/init@b7cec7526559c32f1616476ff32d17ba4c59b2d6 # v3.25.5
with: with:
languages: ${{ matrix.language }} languages: ${{ matrix.language }}
# If you wish to specify custom queries, you can do so here or in a config file. # If you wish to specify custom queries, you can do so here or in a config file.
@ -57,7 +65,7 @@ jobs:
# Autobuild attempts to build any compiled languages (C/C++, C#, or Java). # Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
# If this step fails, then you should remove it and run the build manually (see below) # If this step fails, then you should remove it and run the build manually (see below)
- name: Autobuild - name: Autobuild
uses: github/codeql-action/autobuild@v1 uses: github/codeql-action/autobuild@b7cec7526559c32f1616476ff32d17ba4c59b2d6 # v3.25.5
# Command-line programs to run using the OS shell. # Command-line programs to run using the OS shell.
# 📚 https://git.io/JvXDl # 📚 https://git.io/JvXDl
@ -71,4 +79,4 @@ jobs:
# make release # make release
- name: Perform CodeQL Analysis - name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v1 uses: github/codeql-action/analyze@b7cec7526559c32f1616476ff32d17ba4c59b2d6 # v3.25.5

31
.github/workflows/dependency-review.yml vendored Normal file
View file

@ -0,0 +1,31 @@
# SPDX-FileCopyrightText: 2022-2023 The go-mail Authors
#
# SPDX-License-Identifier: CC0-1.0
# Dependency Review Action
#
# This Action will scan dependency manifest files that change as part of a Pull Request,
# surfacing known-vulnerable versions of the packages declared or updated in the PR.
# Once installed, if the workflow run is marked as required,
# PRs introducing known-vulnerable packages will be blocked from merging.
#
# Source repository: https://github.com/actions/dependency-review-action
name: 'Dependency Review'
on: [pull_request]
permissions:
contents: read
jobs:
dependency-review:
runs-on: ubuntu-latest
steps:
- name: Harden Runner
uses: step-security/harden-runner@a4aa98b93cab29d9b1101a6143fb8bce00e2eac4 # v2.7.1
with:
egress-policy: audit
- name: 'Checkout Repository'
uses: actions/checkout@f43a0e5ff2bd294095638e18286ca9a3d1956744 # v3.6.0
- name: 'Dependency Review'
uses: actions/dependency-review-action@0c155c5e8556a497adf53f2c18edabf945ed8e70 # v4.3.2

View file

@ -19,12 +19,17 @@ jobs:
name: lint name: lint
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/setup-go@v3 - name: Harden Runner
uses: step-security/harden-runner@a4aa98b93cab29d9b1101a6143fb8bce00e2eac4 # v2.7.1
with:
egress-policy: audit
- uses: actions/setup-go@cdcb36043654635271a94b9a6d1392de5bb323a7 # v5.0.1
with: with:
go-version: '1.22' go-version: '1.22'
- uses: actions/checkout@v3 - uses: actions/checkout@f43a0e5ff2bd294095638e18286ca9a3d1956744 # v3.6.0
- name: golangci-lint - name: golangci-lint
uses: golangci/golangci-lint-action@v3 uses: golangci/golangci-lint-action@a4f60bb28d35aeee14e6880718e0c85ff1882e64 # v6.0.1
with: with:
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version # Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
version: latest version: latest

21
.github/workflows/govulncheck.yml vendored Normal file
View file

@ -0,0 +1,21 @@
# SPDX-FileCopyrightText: 2022 Winni Neessen <winni@neessen.dev>
#
# SPDX-License-Identifier: CC0-1.0
name: Govulncheck Security Scan
on: [push, pull_request]
permissions:
contents: read
jobs:
test:
runs-on: ubuntu-latest
steps:
- name: Harden Runner
uses: step-security/harden-runner@a4aa98b93cab29d9b1101a6143fb8bce00e2eac4 # v2.7.1
with:
egress-policy: audit
- name: Run govulncheck
uses: golang/govulncheck-action@3a32958c2706f7048305d5a2e53633d7e37e97d0 # v1.0.2

View file

@ -6,10 +6,18 @@ name: REUSE Compliance Check
on: [push, pull_request] on: [push, pull_request]
permissions:
contents: read
jobs: jobs:
test: test:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v2 - name: Harden Runner
uses: step-security/harden-runner@a4aa98b93cab29d9b1101a6143fb8bce00e2eac4 # v2.7.1
with:
egress-policy: audit
- uses: actions/checkout@ee0669bd1cc54295c223e0bb666b733df41de1c5 # v2.7.0
- name: REUSE Compliance Check - name: REUSE Compliance Check
uses: fsfe/reuse-action@v1 uses: fsfe/reuse-action@a46482ca367aef4454a87620aa37c2be4b2f8106 # v3.0.0

80
.github/workflows/scorecards.yml vendored Normal file
View file

@ -0,0 +1,80 @@
# SPDX-FileCopyrightText: 2022-2023 The go-mail Authors
#
# SPDX-License-Identifier: CC0-1.0
# This workflow uses actions that are not certified by GitHub. They are provided
# by a third-party and are governed by separate terms of service, privacy
# policy, and support documentation.
name: Scorecard supply-chain security
on:
# For Branch-Protection check. Only the default branch is supported. See
# https://github.com/ossf/scorecard/blob/main/docs/checks.md#branch-protection
branch_protection_rule:
# To guarantee Maintained check is occasionally updated. See
# https://github.com/ossf/scorecard/blob/main/docs/checks.md#maintained
schedule:
- cron: '20 7 * * 2'
push:
branches: ["main"]
# Declare default permissions as read only.
permissions: read-all
jobs:
analysis:
name: Scorecard analysis
runs-on: ubuntu-latest
permissions:
# Needed to upload the results to code-scanning dashboard.
security-events: write
# Needed to publish results and get a badge (see publish_results below).
id-token: write
contents: read
actions: read
steps:
- name: Harden Runner
uses: step-security/harden-runner@a4aa98b93cab29d9b1101a6143fb8bce00e2eac4 # v2.7.1
with:
egress-policy: audit
- name: "Checkout code"
uses: actions/checkout@f43a0e5ff2bd294095638e18286ca9a3d1956744 # v3.6.0
with:
persist-credentials: false
- name: "Run analysis"
uses: ossf/scorecard-action@dc50aa9510b46c811795eb24b2f1ba02a914e534 # v2.3.3
with:
results_file: results.sarif
results_format: sarif
# (Optional) "write" PAT token. Uncomment the `repo_token` line below if:
# - you want to enable the Branch-Protection check on a *public* repository, or
# - you are installing Scorecards on a *private* repository
# To create the PAT, follow the steps in https://github.com/ossf/scorecard-action#authentication-with-pat.
# repo_token: ${{ secrets.SCORECARD_TOKEN }}
# Public repositories:
# - Publish results to OpenSSF REST API for easy access by consumers
# - Allows the repository to include the Scorecard badge.
# - See https://github.com/ossf/scorecard-action#publishing-results.
# For private repositories:
# - `publish_results` will always be set to `false`, regardless
# of the value entered here.
publish_results: true
# Upload the results as artifacts (optional). Commenting out will disable uploads of run results in SARIF
# format to the repository Actions tab.
- name: "Upload artifact"
uses: actions/upload-artifact@65462800fd760344b1a7b4382951275a0abb4808 # v4.3.3
with:
name: SARIF file
path: results.sarif
retention-days: 5
# Upload the results to GitHub's code scanning dashboard.
- name: "Upload to code-scanning"
uses: github/codeql-action/upload-sarif@b7cec7526559c32f1616476ff32d17ba4c59b2d6 # v3.25.5
with:
sarif_file: results.sarif

View file

@ -3,6 +3,10 @@
# SPDX-License-Identifier: CC0-1.0 # SPDX-License-Identifier: CC0-1.0
name: SonarQube name: SonarQube
permissions:
contents: read
on: on:
push: push:
branches: branches:
@ -22,12 +26,17 @@ jobs:
name: Build name: Build
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v2 - name: Harden Runner
uses: step-security/harden-runner@a4aa98b93cab29d9b1101a6143fb8bce00e2eac4 # v2.7.1
with:
egress-policy: audit
- uses: actions/checkout@ee0669bd1cc54295c223e0bb666b733df41de1c5 # v2.7.0
with: with:
fetch-depth: 0 fetch-depth: 0
- name: Setup Go - name: Setup Go
uses: actions/setup-go@v3 uses: actions/setup-go@cdcb36043654635271a94b9a6d1392de5bb323a7 # v5.0.1
with: with:
go-version: '1.22.x' go-version: '1.22.x'
@ -35,12 +44,12 @@ jobs:
run: | run: |
go test -v -race --coverprofile=./cov.out ./... go test -v -race --coverprofile=./cov.out ./...
- uses: sonarsource/sonarqube-scan-action@master - uses: sonarsource/sonarqube-scan-action@53c3e3207fe4b8d52e2f1ac9d6eb1d2506f626c0 # master
env: env:
SONAR_TOKEN: ${{ secrets.SONAR_TOKEN }} SONAR_TOKEN: ${{ secrets.SONAR_TOKEN }}
SONAR_HOST_URL: ${{ secrets.SONAR_HOST_URL }} SONAR_HOST_URL: ${{ secrets.SONAR_HOST_URL }}
- uses: sonarsource/sonarqube-quality-gate-action@master - uses: sonarsource/sonarqube-quality-gate-action@72f24ebf1f81eda168a979ce14b8203273b7c3ad # master
timeout-minutes: 5 timeout-minutes: 5
env: env:
SONAR_TOKEN: ${{ secrets.SONAR_TOKEN }} SONAR_TOKEN: ${{ secrets.SONAR_TOKEN }}

2
.gitignore vendored
View file

@ -55,3 +55,5 @@ com_crashlytics_export_strings.xml
crashlytics.properties crashlytics.properties
crashlytics-build.properties crashlytics-build.properties
fabric.properties fabric.properties
testdata

View file

@ -12,6 +12,8 @@ SPDX-License-Identifier: CC0-1.0
[![Mentioned in Awesome Go](https://awesome.re/mentioned-badge-flat.svg)](https://github.com/avelino/awesome-go) [![Mentioned in Awesome Go](https://awesome.re/mentioned-badge-flat.svg)](https://github.com/avelino/awesome-go)
[![#go-mail on Discord](https://img.shields.io/badge/Discord-%23go%E2%80%93mail-blue.svg)](https://discord.gg/ysQXkaccXk) [![#go-mail on Discord](https://img.shields.io/badge/Discord-%23go%E2%80%93mail-blue.svg)](https://discord.gg/ysQXkaccXk)
[![REUSE status](https://api.reuse.software/badge/github.com/wneessen/go-mail)](https://api.reuse.software/info/github.com/wneessen/go-mail) [![REUSE status](https://api.reuse.software/badge/github.com/wneessen/go-mail)](https://api.reuse.software/info/github.com/wneessen/go-mail)
[![OpenSSF Best Practices](https://www.bestpractices.dev/projects/8701/badge)](https://www.bestpractices.dev/projects/8701)
[![OpenSSF Scorecard](https://api.securityscorecards.dev/projects/github.com/wneessen/go-mail/badge)](https://securityscorecards.dev/viewer/?uri=github.com/wneessen/go-mail)
<a href="https://ko-fi.com/D1D24V9IX"><img src="https://uploads-ssl.webflow.com/5c14e387dab576fe667689cf/5cbed8a4ae2b88347c06c923_BuyMeACoffee_blue.png" height="20" alt="buy ma a coffee"></a> <a href="https://ko-fi.com/D1D24V9IX"><img src="https://uploads-ssl.webflow.com/5c14e387dab576fe667689cf/5cbed8a4ae2b88347c06c923_BuyMeACoffee_blue.png" height="20" alt="buy ma a coffee"></a>
<p align="center"><img src="./assets/gopher2.svg" width="250" alt="go-mail logo"/></p> <p align="center"><img src="./assets/gopher2.svg" width="250" alt="go-mail logo"/></p>
@ -83,6 +85,13 @@ alter a given mail message to their needs without relying on `go-mail` to suppor
To get our users started with message middleware, we've created a collection of useful middlewares. It can be To get our users started with message middleware, we've created a collection of useful middlewares. It can be
found in a seperate repository: [go-mail-middlware](https://github.com/wneessen/go-mail-middleware). found in a seperate repository: [go-mail-middlware](https://github.com/wneessen/go-mail-middleware).
## Merch
Thanks to our wonderful friends at [HelloTux](https://www.hellotux.com) we can offer great go-mail merchandising. All merch articles are embroidery
to provide the best and most long-lasting quality possible.
If you want to support the open source community and represent your favourite Go mail library with some cool drip, check out our merch shop at:
[https://www.hellotux.com/go-mail](https://www.hellotux.com/go-mail).
## Examples ## Examples
We provide example code in both our GoDocs as well as on our official Website (see [Documentation](#documentation)). For a quick start into go-mail We provide example code in both our GoDocs as well as on our official Website (see [Documentation](#documentation)). For a quick start into go-mail

View file

@ -20,39 +20,39 @@ type Base64LineBreaker struct {
out io.Writer out io.Writer
} }
var nl = []byte(SingleNewLine) var newlineBytes = []byte(SingleNewLine)
// Write writes the data stream and inserts a SingleNewLine when the maximum // Write writes the data stream and inserts a SingleNewLine when the maximum
// line length is reached // line length is reached
func (l *Base64LineBreaker) Write(b []byte) (n int, err error) { func (l *Base64LineBreaker) Write(data []byte) (numBytes int, err error) {
if l.out == nil { if l.out == nil {
err = fmt.Errorf(ErrNoOutWriter) err = fmt.Errorf(ErrNoOutWriter)
return return
} }
if l.used+len(b) < MaxBodyLength { if l.used+len(data) < MaxBodyLength {
copy(l.line[l.used:], b) copy(l.line[l.used:], data)
l.used += len(b) l.used += len(data)
return len(b), nil return len(data), nil
} }
n, err = l.out.Write(l.line[0:l.used]) numBytes, err = l.out.Write(l.line[0:l.used])
if err != nil { if err != nil {
return return
} }
excess := MaxBodyLength - l.used excess := MaxBodyLength - l.used
l.used = 0 l.used = 0
n, err = l.out.Write(b[0:excess]) numBytes, err = l.out.Write(data[0:excess])
if err != nil { if err != nil {
return return
} }
n, err = l.out.Write(nl) numBytes, err = l.out.Write(newlineBytes)
if err != nil { if err != nil {
return return
} }
return l.Write(b[excess:]) return l.Write(data[excess:])
} }
// Close closes the Base64LineBreaker and writes any access data that is still // Close closes the Base64LineBreaker and writes any access data that is still
@ -63,7 +63,7 @@ func (l *Base64LineBreaker) Close() (err error) {
if err != nil { if err != nil {
return return
} }
_, err = l.out.Write(nl) _, err = l.out.Write(newlineBytes)
} }
return return

View file

@ -5,8 +5,10 @@
package mail package mail
import ( import (
"bufio"
"bytes" "bytes"
"encoding/base64" "encoding/base64"
"errors"
"fmt" "fmt"
"io" "io"
"os" "os"
@ -382,6 +384,11 @@ LjI4MiIgc3R5bGU9ImZpbGw6I2ZmYjI1YztzdHJva2U6IzAwMDtzdHJva2Utd2lkdGg6NC45NXB4
OyIvPjwvZz48L3N2Zz4= OyIvPjwvZz48L3N2Zz4=
` `
var (
errMockDefault = errors.New("mock write error")
errMockNewline = errors.New("mock newline error")
)
// TestBase64LineBreaker tests the Write and Close methods of the Base64LineBreaker // TestBase64LineBreaker tests the Write and Close methods of the Base64LineBreaker
func TestBase64LineBreaker(t *testing.T) { func TestBase64LineBreaker(t *testing.T) {
l, err := os.Open("assets/gopher2.svg") l, err := os.Open("assets/gopher2.svg")
@ -436,6 +443,47 @@ func TestBase64LineBreakerFailures(t *testing.T) {
} }
} }
func TestBase64LineBreaker_WriteAndClose(t *testing.T) {
tests := []struct {
name string
data []byte
writer io.Writer
}{
{
name: "Write data within MaxBodyLength",
data: []byte("testdata"),
writer: &mockWriterExcess{writeError: errMockDefault},
},
{
name: "Write data exceeds MaxBodyLength",
data: []byte("verylongtestdataverylongtestdataverylongtestdata" +
"verylongtestdataverylongtestdataverylongtestdata"),
writer: &mockWriterExcess{writeError: errMockDefault},
},
{
name: "Write data exceeds MaxBodyLength with newline",
data: []byte("verylongtestdataverylongtestdataverylongtestdata" +
"verylongtestdataverylongtestdataverylongtestdata"),
writer: &mockWriterNewline{writeError: errMockDefault},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
blr := &Base64LineBreaker{out: tt.writer}
_, err := blr.Write(tt.data)
if err != nil && !errors.Is(err, errMockDefault) && !errors.Is(err, errMockNewline) {
t.Errorf("Unexpected error while writing: %v", err)
}
err = blr.Close()
if err != nil && !errors.Is(err, errMockDefault) && !errors.Is(err, errMockNewline) {
t.Errorf("Unexpected error while closing: %v", err)
}
})
}
}
// removeNewLines removes any newline characters from the given data // removeNewLines removes any newline characters from the given data
func removeNewLines(data []byte) []byte { func removeNewLines(data []byte) []byte {
result := make([]byte, len(data)) result := make([]byte, len(data))
@ -461,3 +509,49 @@ func (e errorWriter) Write([]byte) (int, error) {
func (e errorWriter) Close() error { func (e errorWriter) Close() error {
return fmt.Errorf("supposed to always fail") return fmt.Errorf("supposed to always fail")
} }
type mockWriterExcess struct {
writeError error
}
type mockWriterNewline struct {
writeError error
}
func (w *mockWriterExcess) Write(p []byte) (n int, err error) {
switch len(p) {
case 0:
return 0, nil
case 2:
return 2, nil
default:
return len(p), errMockDefault
}
}
func (w *mockWriterNewline) Write(p []byte) (n int, err error) {
switch len(p) {
case 0:
return 0, nil
case 2:
return 2, errMockNewline
default:
return len(p), nil
}
}
func FuzzBase64LineBreaker_Write(f *testing.F) {
f.Add([]byte("abc"))
f.Add([]byte("def"))
f.Add([]uint8{0o0, 0o1, 0o2, 30, 255})
buf := bytes.Buffer{}
bw := bufio.NewWriter(&buf)
f.Fuzz(func(t *testing.T, data []byte) {
b := &Base64LineBreaker{out: bw}
if _, err := b.Write(data); err != nil {
t.Errorf("failed to write to B64LineBreaker: %s", err)
}
if err := b.Close(); err != nil {
t.Errorf("failed to close B64LineBreaker: %s", err)
}
})
}

343
client.go
View file

@ -87,11 +87,11 @@ type DialContextFunc func(ctx context.Context, network, address string) (net.Con
// Client is the SMTP client struct // Client is the SMTP client struct
type Client struct { type Client struct {
// co is the net.Conn that the smtp.Client is based on // connection is the net.Conn that the smtp.Client is based on
co net.Conn connection net.Conn
// Timeout for the SMTP server connection // Timeout for the SMTP server connection
cto time.Duration connTimeout time.Duration
// dsn indicates that we want to use DSN for the Client // dsn indicates that we want to use DSN for the Client
dsn bool dsn bool
@ -102,8 +102,8 @@ type Client struct {
// dsnrntype defines the DSNRcptNotifyOption in case DSN is enabled // dsnrntype defines the DSNRcptNotifyOption in case DSN is enabled
dsnrntype []string dsnrntype []string
// enc indicates if a Client connection is encrypted or not // isEncrypted indicates if a Client connection is encrypted or not
enc bool isEncrypted bool
// noNoop indicates the Noop is to be skipped // noNoop indicates the Noop is to be skipped
noNoop bool noNoop bool
@ -121,17 +121,17 @@ type Client struct {
port int port int
fallbackPort int fallbackPort int
// sa is a pointer to smtp.Auth // smtpAuth is a pointer to smtp.Auth
sa smtp.Auth smtpAuth smtp.Auth
// satype represents the authentication type for SMTP AUTH // smtpAuthType represents the authentication type for SMTP AUTH
satype SMTPAuthType smtpAuthType SMTPAuthType
// sc is the smtp.Client that is set up when using the Dial*() methods // smtpClient is the smtp.Client that is set up when using the Dial*() methods
sc *smtp.Client smtpClient *smtp.Client
// Use SSL for the connection // Use SSL for the connection
ssl bool useSSL bool
// tlspolicy sets the client to use the provided TLSPolicy for the STARTTLS protocol // tlspolicy sets the client to use the provided TLSPolicy for the STARTTLS protocol
tlspolicy TLSPolicy tlspolicy TLSPolicy
@ -142,11 +142,11 @@ type Client struct {
// user is the SMTP AUTH username // user is the SMTP AUTH username
user string user string
// dl enables the debug logging on the SMTP client // useDebugLog enables the debug logging on the SMTP client
dl bool useDebugLog bool
// l is a logger that implements the log.Logger interface // logger is a logger that implements the log.Logger interface
l log.Logger logger log.Logger
// dialContextFunc is a custom DialContext function to dial target SMTP server // dialContextFunc is a custom DialContext function to dial target SMTP server
dialContextFunc DialContextFunc dialContextFunc DialContextFunc
@ -198,12 +198,12 @@ var (
) )
// NewClient returns a new Session client object // NewClient returns a new Session client object
func NewClient(h string, o ...Option) (*Client, error) { func NewClient(host string, opts ...Option) (*Client, error) {
c := &Client{ c := &Client{
cto: DefaultTimeout, connTimeout: DefaultTimeout,
host: h, host: host,
port: DefaultPort, port: DefaultPort,
tlsconfig: &tls.Config{ServerName: h, MinVersion: DefaultTLSMinVersion}, tlsconfig: &tls.Config{ServerName: host, MinVersion: DefaultTLSMinVersion},
tlspolicy: DefaultTLSPolicy, tlspolicy: DefaultTLSPolicy,
} }
@ -213,11 +213,11 @@ func NewClient(h string, o ...Option) (*Client, error) {
} }
// Override defaults with optionally provided Option functions // Override defaults with optionally provided Option functions
for _, co := range o { for _, opt := range opts {
if co == nil { if opt == nil {
continue continue
} }
if err := co(c); err != nil { if err := opt(c); err != nil {
return c, fmt.Errorf("failed to apply option: %w", err) return c, fmt.Errorf("failed to apply option: %w", err)
} }
} }
@ -231,45 +231,48 @@ func NewClient(h string, o ...Option) (*Client, error) {
} }
// WithPort overrides the default connection port // WithPort overrides the default connection port
func WithPort(p int) Option { func WithPort(port int) Option {
return func(c *Client) error { return func(c *Client) error {
if p < 1 || p > 65535 { if port < 1 || port > 65535 {
return ErrInvalidPort return ErrInvalidPort
} }
c.port = p c.port = port
return nil return nil
} }
} }
// WithTimeout overrides the default connection timeout // WithTimeout overrides the default connection timeout
func WithTimeout(t time.Duration) Option { func WithTimeout(timeout time.Duration) Option {
return func(c *Client) error { return func(c *Client) error {
if t <= 0 { if timeout <= 0 {
return ErrInvalidTimeout return ErrInvalidTimeout
} }
c.cto = t c.connTimeout = timeout
return nil return nil
} }
} }
// WithSSL tells the client to use a SSL/TLS connection // WithSSL tells the client to use a SSL/TLS connection
//
// Deprecated: use WithSSLPort instead.
func WithSSL() Option { func WithSSL() Option {
return func(c *Client) error { return func(c *Client) error {
c.ssl = true c.useSSL = true
return nil return nil
} }
} }
// WithSSLPort tells the client to use a SSL/TLS connection. // WithSSLPort tells the Client wether or not to use SSL and fallback.
// It automatically sets the port to 465. // The correct port is automatically set.
// //
// When the SSL connection fails and fallback is set to true, // Port 465 is used when SSL set (true).
// Port 25 is used when SSL is unset (false).
// When the SSL connection fails and fb is set to true,
// the client will attempt to connect on port 25 using plaintext. // the client will attempt to connect on port 25 using plaintext.
func WithSSLPort(fb bool) Option { //
// Note: If a different port has already been set otherwise, the port-choosing
// and fallback automatism will be skipped.
func WithSSLPort(fallback bool) Option {
return func(c *Client) error { return func(c *Client) error {
c.SetSSLPort(true, fb) c.SetSSLPort(true, fallback)
return nil return nil
} }
} }
@ -278,36 +281,37 @@ func WithSSLPort(fb bool) Option {
// to StdErr // to StdErr
func WithDebugLog() Option { func WithDebugLog() Option {
return func(c *Client) error { return func(c *Client) error {
c.dl = true c.useDebugLog = true
return nil return nil
} }
} }
// WithLogger overrides the default log.Logger that is used for debug logging // WithLogger overrides the default log.Logger that is used for debug logging
func WithLogger(l log.Logger) Option { func WithLogger(logger log.Logger) Option {
return func(c *Client) error { return func(c *Client) error {
c.l = l c.logger = logger
return nil return nil
} }
} }
// WithHELO tells the client to use the provided string as HELO/EHLO greeting host // WithHELO tells the client to use the provided string as HELO/EHLO greeting host
func WithHELO(h string) Option { func WithHELO(helo string) Option {
return func(c *Client) error { return func(c *Client) error {
if h == "" { if helo == "" {
return ErrInvalidHELO return ErrInvalidHELO
} }
c.helo = h c.helo = helo
return nil return nil
} }
} }
// WithTLSPolicy tells the client to use the provided TLSPolicy // WithTLSPolicy tells the client to use the provided TLSPolicy
// //
// Deprecated: use WithTLSPortPolicy instead. // Note: To follow best-practices for SMTP TLS connections, it is recommended
func WithTLSPolicy(p TLSPolicy) Option { // to use WithTLSPortPolicy instead.
func WithTLSPolicy(policy TLSPolicy) Option {
return func(c *Client) error { return func(c *Client) error {
c.tlspolicy = p c.tlspolicy = policy
return nil return nil
} }
} }
@ -319,52 +323,55 @@ func WithTLSPolicy(p TLSPolicy) Option {
// If the connection fails with TLSOpportunistic, // If the connection fails with TLSOpportunistic,
// a plaintext connection is attempted on port 25 as a fallback. // a plaintext connection is attempted on port 25 as a fallback.
// NoTLS will allways use port 25. // NoTLS will allways use port 25.
func WithTLSPortPolicy(p TLSPolicy) Option { //
// Note: If a different port has already been set otherwise, the port-choosing
// and fallback automatism will be skipped.
func WithTLSPortPolicy(policy TLSPolicy) Option {
return func(c *Client) error { return func(c *Client) error {
c.SetTLSPortPolicy(p) c.SetTLSPortPolicy(policy)
return nil return nil
} }
} }
// WithTLSConfig tells the client to use the provided *tls.Config // WithTLSConfig tells the client to use the provided *tls.Config
func WithTLSConfig(co *tls.Config) Option { func WithTLSConfig(tlsconfig *tls.Config) Option {
return func(c *Client) error { return func(c *Client) error {
if co == nil { if tlsconfig == nil {
return ErrInvalidTLSConfig return ErrInvalidTLSConfig
} }
c.tlsconfig = co c.tlsconfig = tlsconfig
return nil return nil
} }
} }
// WithSMTPAuth tells the client to use the provided SMTPAuthType for authentication // WithSMTPAuth tells the client to use the provided SMTPAuthType for authentication
func WithSMTPAuth(t SMTPAuthType) Option { func WithSMTPAuth(authtype SMTPAuthType) Option {
return func(c *Client) error { return func(c *Client) error {
c.satype = t c.smtpAuthType = authtype
return nil return nil
} }
} }
// WithSMTPAuthCustom tells the client to use the provided smtp.Auth for SMTP authentication // WithSMTPAuthCustom tells the client to use the provided smtp.Auth for SMTP authentication
func WithSMTPAuthCustom(a smtp.Auth) Option { func WithSMTPAuthCustom(smtpAuth smtp.Auth) Option {
return func(c *Client) error { return func(c *Client) error {
c.sa = a c.smtpAuth = smtpAuth
return nil return nil
} }
} }
// WithUsername tells the client to use the provided string as username for authentication // WithUsername tells the client to use the provided string as username for authentication
func WithUsername(u string) Option { func WithUsername(username string) Option {
return func(c *Client) error { return func(c *Client) error {
c.user = u c.user = username
return nil return nil
} }
} }
// WithPassword tells the client to use the provided string as password/secret for authentication // WithPassword tells the client to use the provided string as password/secret for authentication
func WithPassword(p string) Option { func WithPassword(password string) Option {
return func(c *Client) error { return func(c *Client) error {
c.pass = p c.pass = password
return nil return nil
} }
} }
@ -386,9 +393,9 @@ func WithDSN() Option {
// as described in the RFC 1891 and set the MAIL FROM Return option type to the // as described in the RFC 1891 and set the MAIL FROM Return option type to the
// given DSNMailReturnOption // given DSNMailReturnOption
// See: https://www.rfc-editor.org/rfc/rfc1891 // See: https://www.rfc-editor.org/rfc/rfc1891
func WithDSNMailReturnType(mro DSNMailReturnOption) Option { func WithDSNMailReturnType(option DSNMailReturnOption) Option {
return func(c *Client) error { return func(c *Client) error {
switch mro { switch option {
case DSNMailReturnHeadersOnly: case DSNMailReturnHeadersOnly:
case DSNMailReturnFull: case DSNMailReturnFull:
default: default:
@ -396,7 +403,7 @@ func WithDSNMailReturnType(mro DSNMailReturnOption) Option {
} }
c.dsn = true c.dsn = true
c.dsnmrtype = mro c.dsnmrtype = option
return nil return nil
} }
} }
@ -404,13 +411,13 @@ func WithDSNMailReturnType(mro DSNMailReturnOption) Option {
// WithDSNRcptNotifyType enables the Client to request DSNs as described in the RFC 1891 // WithDSNRcptNotifyType enables the Client to request DSNs as described in the RFC 1891
// and sets the RCPT TO notify options to the given list of DSNRcptNotifyOption // and sets the RCPT TO notify options to the given list of DSNRcptNotifyOption
// See: https://www.rfc-editor.org/rfc/rfc1891 // See: https://www.rfc-editor.org/rfc/rfc1891
func WithDSNRcptNotifyType(rno ...DSNRcptNotifyOption) Option { func WithDSNRcptNotifyType(opts ...DSNRcptNotifyOption) Option {
return func(c *Client) error { return func(c *Client) error {
var rnol []string var rcptOpts []string
var ns, nns bool var ns, nns bool
if len(rno) > 0 { if len(opts) > 0 {
for _, crno := range rno { for _, opt := range opts {
switch crno { switch opt {
case DSNRcptNotifyNever: case DSNRcptNotifyNever:
ns = true ns = true
case DSNRcptNotifySuccess: case DSNRcptNotifySuccess:
@ -422,7 +429,7 @@ func WithDSNRcptNotifyType(rno ...DSNRcptNotifyOption) Option {
default: default:
return ErrInvalidDSNRcptNotifyOption return ErrInvalidDSNRcptNotifyOption
} }
rnol = append(rnol, string(crno)) rcptOpts = append(rcptOpts, string(opt))
} }
} }
if ns && nns { if ns && nns {
@ -430,7 +437,7 @@ func WithDSNRcptNotifyType(rno ...DSNRcptNotifyOption) Option {
} }
c.dsn = true c.dsn = true
c.dsnrntype = rnol c.dsnrntype = rcptOpts
return nil return nil
} }
} }
@ -445,9 +452,9 @@ func WithoutNoop() Option {
} }
// WithDialContextFunc overrides the default DialContext for connecting SMTP server // WithDialContextFunc overrides the default DialContext for connecting SMTP server
func WithDialContextFunc(f DialContextFunc) Option { func WithDialContextFunc(dialCtxFunc DialContextFunc) Option {
return func(c *Client) error { return func(c *Client) error {
c.dialContextFunc = f c.dialContextFunc = dialCtxFunc
return nil return nil
} }
} }
@ -463,8 +470,11 @@ func (c *Client) ServerAddr() string {
} }
// SetTLSPolicy overrides the current TLSPolicy with the given TLSPolicy value // SetTLSPolicy overrides the current TLSPolicy with the given TLSPolicy value
func (c *Client) SetTLSPolicy(p TLSPolicy) { //
c.tlspolicy = p // Note: To follow best-practices for SMTP TLS connections, it is recommended
// to use SetTLSPortPolicy instead.
func (c *Client) SetTLSPolicy(policy TLSPolicy) {
c.tlspolicy = policy
} }
// SetTLSPortPolicy overrides the current TLSPolicy with the given TLSPolicy // SetTLSPortPolicy overrides the current TLSPolicy with the given TLSPolicy
@ -474,22 +484,27 @@ func (c *Client) SetTLSPolicy(p TLSPolicy) {
// If the connection fails with TLSOpportunistic, a plaintext connection is // If the connection fails with TLSOpportunistic, a plaintext connection is
// attempted on port 25 as a fallback. // attempted on port 25 as a fallback.
// NoTLS will allways use port 25. // NoTLS will allways use port 25.
func (c *Client) SetTLSPortPolicy(p TLSPolicy) { //
// Note: If a different port has already been set otherwise, the port-choosing
// and fallback automatism will be skipped.
func (c *Client) SetTLSPortPolicy(policy TLSPolicy) {
if c.port == DefaultPort {
c.port = DefaultPortTLS c.port = DefaultPortTLS
if p == TLSOpportunistic { if policy == TLSOpportunistic {
c.fallbackPort = DefaultPort c.fallbackPort = DefaultPort
} }
if p == NoTLS { if policy == NoTLS {
c.port = DefaultPort c.port = DefaultPort
} }
}
c.tlspolicy = p c.tlspolicy = policy
} }
// SetSSL tells the Client wether to use SSL or not // SetSSL tells the Client wether to use SSL or not
func (c *Client) SetSSL(s bool) { func (c *Client) SetSSL(ssl bool) {
c.ssl = s c.useSSL = ssl
} }
// SetSSLPort tells the Client wether or not to use SSL and fallback. // SetSSLPort tells the Client wether or not to use SSL and fallback.
@ -499,124 +514,128 @@ func (c *Client) SetSSL(s bool) {
// Port 25 is used when SSL is unset (false). // Port 25 is used when SSL is unset (false).
// When the SSL connection fails and fb is set to true, // When the SSL connection fails and fb is set to true,
// the client will attempt to connect on port 25 using plaintext. // the client will attempt to connect on port 25 using plaintext.
func (c *Client) SetSSLPort(ssl bool, fb bool) { //
c.port = DefaultPort // Note: If a different port has already been set otherwise, the port-choosing
// and fallback automatism will be skipped.
func (c *Client) SetSSLPort(ssl bool, fallback bool) {
if c.port == DefaultPort {
if ssl { if ssl {
c.port = DefaultPortSSL c.port = DefaultPortSSL
} }
c.fallbackPort = 0 c.fallbackPort = 0
if fb { if fallback {
c.fallbackPort = DefaultPort c.fallbackPort = DefaultPort
} }
}
c.ssl = ssl c.useSSL = ssl
} }
// SetDebugLog tells the Client whether debug logging is enabled or not // SetDebugLog tells the Client whether debug logging is enabled or not
func (c *Client) SetDebugLog(v bool) { func (c *Client) SetDebugLog(val bool) {
c.dl = v c.useDebugLog = val
if c.sc != nil { if c.smtpClient != nil {
c.sc.SetDebugLog(v) c.smtpClient.SetDebugLog(val)
} }
} }
// SetLogger tells the Client which log.Logger to use // SetLogger tells the Client which log.Logger to use
func (c *Client) SetLogger(l log.Logger) { func (c *Client) SetLogger(logger log.Logger) {
c.l = l c.logger = logger
if c.sc != nil { if c.smtpClient != nil {
c.sc.SetLogger(l) c.smtpClient.SetLogger(logger)
} }
} }
// SetTLSConfig overrides the current *tls.Config with the given *tls.Config value // SetTLSConfig overrides the current *tls.Config with the given *tls.Config value
func (c *Client) SetTLSConfig(co *tls.Config) error { func (c *Client) SetTLSConfig(tlsconfig *tls.Config) error {
if co == nil { if tlsconfig == nil {
return ErrInvalidTLSConfig return ErrInvalidTLSConfig
} }
c.tlsconfig = co c.tlsconfig = tlsconfig
return nil return nil
} }
// SetUsername overrides the current username string with the given value // SetUsername overrides the current username string with the given value
func (c *Client) SetUsername(u string) { func (c *Client) SetUsername(username string) {
c.user = u c.user = username
} }
// SetPassword overrides the current password string with the given value // SetPassword overrides the current password string with the given value
func (c *Client) SetPassword(p string) { func (c *Client) SetPassword(password string) {
c.pass = p c.pass = password
} }
// SetSMTPAuth overrides the current SMTP AUTH type setting with the given value // SetSMTPAuth overrides the current SMTP AUTH type setting with the given value
func (c *Client) SetSMTPAuth(a SMTPAuthType) { func (c *Client) SetSMTPAuth(authtype SMTPAuthType) {
c.satype = a c.smtpAuthType = authtype
} }
// SetSMTPAuthCustom overrides the current SMTP AUTH setting with the given custom smtp.Auth // SetSMTPAuthCustom overrides the current SMTP AUTH setting with the given custom smtp.Auth
func (c *Client) SetSMTPAuthCustom(sa smtp.Auth) { func (c *Client) SetSMTPAuthCustom(smtpAuth smtp.Auth) {
c.sa = sa c.smtpAuth = smtpAuth
} }
// setDefaultHelo retrieves the current hostname and sets it as HELO/EHLO hostname // setDefaultHelo retrieves the current hostname and sets it as HELO/EHLO hostname
func (c *Client) setDefaultHelo() error { func (c *Client) setDefaultHelo() error {
hn, err := os.Hostname() hostname, err := os.Hostname()
if err != nil { if err != nil {
return fmt.Errorf("failed cto read local hostname: %w", err) return fmt.Errorf("failed to read local hostname: %w", err)
} }
c.helo = hn c.helo = hostname
return nil return nil
} }
// DialWithContext establishes a connection cto the SMTP server with a given context.Context // DialWithContext establishes a connection to the SMTP server with a given context.Context
func (c *Client) DialWithContext(pc context.Context) error { func (c *Client) DialWithContext(dialCtx context.Context) error {
ctx, cfn := context.WithDeadline(pc, time.Now().Add(c.cto)) ctx, cancel := context.WithDeadline(dialCtx, time.Now().Add(c.connTimeout))
defer cfn() defer cancel()
if c.dialContextFunc == nil { if c.dialContextFunc == nil {
nd := net.Dialer{} netDialer := net.Dialer{}
c.dialContextFunc = nd.DialContext c.dialContextFunc = netDialer.DialContext
if c.ssl { if c.useSSL {
td := tls.Dialer{NetDialer: &nd, Config: c.tlsconfig} tlsDialer := tls.Dialer{NetDialer: &netDialer, Config: c.tlsconfig}
c.enc = true c.isEncrypted = true
c.dialContextFunc = td.DialContext c.dialContextFunc = tlsDialer.DialContext
} }
} }
var err error var err error
c.co, err = c.dialContextFunc(ctx, "tcp", c.ServerAddr()) c.connection, err = c.dialContextFunc(ctx, "tcp", c.ServerAddr())
if err != nil && c.fallbackPort != 0 { if err != nil && c.fallbackPort != 0 {
// TODO: should we somehow log or append the previous error? // TODO: should we somehow log or append the previous error?
c.co, err = c.dialContextFunc(ctx, "tcp", c.serverFallbackAddr()) c.connection, err = c.dialContextFunc(ctx, "tcp", c.serverFallbackAddr())
} }
if err != nil { if err != nil {
return err return err
} }
sc, err := smtp.NewClient(c.co, c.host) client, err := smtp.NewClient(c.connection, c.host)
if err != nil { if err != nil {
return err return err
} }
if sc == nil { if client == nil {
return fmt.Errorf("SMTP client is nil") return fmt.Errorf("SMTP client is nil")
} }
c.sc = sc c.smtpClient = client
if c.l != nil { if c.logger != nil {
c.sc.SetLogger(c.l) c.smtpClient.SetLogger(c.logger)
} }
if c.dl { if c.useDebugLog {
c.sc.SetDebugLog(true) c.smtpClient.SetDebugLog(true)
} }
if err := c.sc.Hello(c.helo); err != nil { if err = c.smtpClient.Hello(c.helo); err != nil {
return err return err
} }
if err := c.tls(); err != nil { if err = c.tls(); err != nil {
return err return err
} }
if err := c.auth(); err != nil { if err = c.auth(); err != nil {
return err return err
} }
@ -628,7 +647,7 @@ func (c *Client) Close() error {
if err := c.checkConn(); err != nil { if err := c.checkConn(); err != nil {
return err return err
} }
if err := c.sc.Quit(); err != nil { if err := c.smtpClient.Quit(); err != nil {
return fmt.Errorf("failed to close SMTP client: %w", err) return fmt.Errorf("failed to close SMTP client: %w", err)
} }
@ -640,7 +659,7 @@ func (c *Client) Reset() error {
if err := c.checkConn(); err != nil { if err := c.checkConn(); err != nil {
return err return err
} }
if err := c.sc.Reset(); err != nil { if err := c.smtpClient.Reset(); err != nil {
return fmt.Errorf("failed to send RSET to SMTP client: %w", err) return fmt.Errorf("failed to send RSET to SMTP client: %w", err)
} }
@ -649,18 +668,18 @@ func (c *Client) Reset() error {
// DialAndSend establishes a connection to the SMTP server with a // DialAndSend establishes a connection to the SMTP server with a
// default context.Background and sends the mail // default context.Background and sends the mail
func (c *Client) DialAndSend(ml ...*Msg) error { func (c *Client) DialAndSend(messages ...*Msg) error {
ctx := context.Background() ctx := context.Background()
return c.DialAndSendWithContext(ctx, ml...) return c.DialAndSendWithContext(ctx, messages...)
} }
// DialAndSendWithContext establishes a connection to the SMTP server with a // DialAndSendWithContext establishes a connection to the SMTP server with a
// custom context and sends the mail // custom context and sends the mail
func (c *Client) DialAndSendWithContext(ctx context.Context, ml ...*Msg) error { func (c *Client) DialAndSendWithContext(ctx context.Context, messages ...*Msg) error {
if err := c.DialWithContext(ctx); err != nil { if err := c.DialWithContext(ctx); err != nil {
return fmt.Errorf("dial failed: %w", err) return fmt.Errorf("dial failed: %w", err)
} }
if err := c.Send(ml...); err != nil { if err := c.Send(messages...); err != nil {
return fmt.Errorf("send failed: %w", err) return fmt.Errorf("send failed: %w", err)
} }
if err := c.Close(); err != nil { if err := c.Close(); err != nil {
@ -672,17 +691,17 @@ func (c *Client) DialAndSendWithContext(ctx context.Context, ml ...*Msg) error {
// checkConn makes sure that a required server connection is available and extends the // checkConn makes sure that a required server connection is available and extends the
// connection deadline // connection deadline
func (c *Client) checkConn() error { func (c *Client) checkConn() error {
if c.co == nil { if c.connection == nil {
return ErrNoActiveConnection return ErrNoActiveConnection
} }
if !c.noNoop { if !c.noNoop {
if err := c.sc.Noop(); err != nil { if err := c.smtpClient.Noop(); err != nil {
return ErrNoActiveConnection return ErrNoActiveConnection
} }
} }
if err := c.co.SetDeadline(time.Now().Add(c.cto)); err != nil { if err := c.connection.SetDeadline(time.Now().Add(c.connTimeout)); err != nil {
return ErrDeadlineExtendFailed return ErrDeadlineExtendFailed
} }
return nil return nil
@ -696,30 +715,30 @@ func (c *Client) serverFallbackAddr() string {
// tls tries to make sure that the STARTTLS requirements are satisfied // tls tries to make sure that the STARTTLS requirements are satisfied
func (c *Client) tls() error { func (c *Client) tls() error {
if c.co == nil { if c.connection == nil {
return ErrNoActiveConnection return ErrNoActiveConnection
} }
if !c.ssl && c.tlspolicy != NoTLS { if !c.useSSL && c.tlspolicy != NoTLS {
est := false hasStartTLS := false
st, _ := c.sc.Extension("STARTTLS") extension, _ := c.smtpClient.Extension("STARTTLS")
if c.tlspolicy == TLSMandatory { if c.tlspolicy == TLSMandatory {
est = true hasStartTLS = true
if !st { if !extension {
return fmt.Errorf("STARTTLS mode set to: %q, but target host does not support STARTTLS", return fmt.Errorf("STARTTLS mode set to: %q, but target host does not support STARTTLS",
c.tlspolicy) c.tlspolicy)
} }
} }
if c.tlspolicy == TLSOpportunistic { if c.tlspolicy == TLSOpportunistic {
if st { if extension {
est = true hasStartTLS = true
} }
} }
if est { if hasStartTLS {
if err := c.sc.StartTLS(c.tlsconfig); err != nil { if err := c.smtpClient.StartTLS(c.tlsconfig); err != nil {
return err return err
} }
} }
_, c.enc = c.sc.TLSConnectionState() _, c.isEncrypted = c.smtpClient.TLSConnectionState()
} }
return nil return nil
} }
@ -729,40 +748,40 @@ func (c *Client) auth() error {
if err := c.checkConn(); err != nil { if err := c.checkConn(); err != nil {
return fmt.Errorf("failed to authenticate: %w", err) return fmt.Errorf("failed to authenticate: %w", err)
} }
if c.sa == nil && c.satype != "" { if c.smtpAuth == nil && c.smtpAuthType != "" {
sa, sat := c.sc.Extension("AUTH") hasSMTPAuth, smtpAuthType := c.smtpClient.Extension("AUTH")
if !sa { if !hasSMTPAuth {
return fmt.Errorf("server does not support SMTP AUTH") return fmt.Errorf("server does not support SMTP AUTH")
} }
switch c.satype { switch c.smtpAuthType {
case SMTPAuthPlain: case SMTPAuthPlain:
if !strings.Contains(sat, string(SMTPAuthPlain)) { if !strings.Contains(smtpAuthType, string(SMTPAuthPlain)) {
return ErrPlainAuthNotSupported return ErrPlainAuthNotSupported
} }
c.sa = smtp.PlainAuth("", c.user, c.pass, c.host) c.smtpAuth = smtp.PlainAuth("", c.user, c.pass, c.host)
case SMTPAuthLogin: case SMTPAuthLogin:
if !strings.Contains(sat, string(SMTPAuthLogin)) { if !strings.Contains(smtpAuthType, string(SMTPAuthLogin)) {
return ErrLoginAuthNotSupported return ErrLoginAuthNotSupported
} }
c.sa = smtp.LoginAuth(c.user, c.pass, c.host) c.smtpAuth = smtp.LoginAuth(c.user, c.pass, c.host)
case SMTPAuthCramMD5: case SMTPAuthCramMD5:
if !strings.Contains(sat, string(SMTPAuthCramMD5)) { if !strings.Contains(smtpAuthType, string(SMTPAuthCramMD5)) {
return ErrCramMD5AuthNotSupported return ErrCramMD5AuthNotSupported
} }
c.sa = smtp.CRAMMD5Auth(c.user, c.pass) c.smtpAuth = smtp.CRAMMD5Auth(c.user, c.pass)
case SMTPAuthXOAUTH2: case SMTPAuthXOAUTH2:
if !strings.Contains(sat, string(SMTPAuthXOAUTH2)) { if !strings.Contains(smtpAuthType, string(SMTPAuthXOAUTH2)) {
return ErrXOauth2AuthNotSupported return ErrXOauth2AuthNotSupported
} }
c.sa = smtp.XOAuth2Auth(c.user, c.pass) c.smtpAuth = smtp.XOAuth2Auth(c.user, c.pass)
default: default:
return fmt.Errorf("unsupported SMTP AUTH type %q", c.satype) return fmt.Errorf("unsupported SMTP AUTH type %q", c.smtpAuthType)
} }
} }
if c.sa != nil { if c.smtpAuth != nil {
if err := c.sc.Auth(c.sa); err != nil { if err := c.smtpClient.Auth(c.smtpAuth); err != nil {
return fmt.Errorf("SMTP AUTH failed: %w", err) return fmt.Errorf("SMTP AUTH failed: %w", err)
} }
} }

View file

@ -10,123 +10,123 @@ package mail
import "strings" import "strings"
// Send sends out the mail message // Send sends out the mail message
func (c *Client) Send(ml ...*Msg) error { func (c *Client) Send(messages ...*Msg) error {
if cerr := c.checkConn(); cerr != nil { if cerr := c.checkConn(); cerr != nil {
return &SendError{Reason: ErrConnCheck, errlist: []error{cerr}, isTemp: isTempError(cerr)} return &SendError{Reason: ErrConnCheck, errlist: []error{cerr}, isTemp: isTempError(cerr)}
} }
var errs []*SendError var errs []*SendError
for _, m := range ml { for _, message := range messages {
m.sendError = nil message.sendError = nil
if m.encoding == NoEncoding { if message.encoding == NoEncoding {
if ok, _ := c.sc.Extension("8BITMIME"); !ok { if ok, _ := c.smtpClient.Extension("8BITMIME"); !ok {
se := &SendError{Reason: ErrNoUnencoded, isTemp: false} sendErr := &SendError{Reason: ErrNoUnencoded, isTemp: false}
m.sendError = se message.sendError = sendErr
errs = append(errs, se) errs = append(errs, sendErr)
continue continue
} }
} }
f, err := m.GetSender(false) from, err := message.GetSender(false)
if err != nil { if err != nil {
se := &SendError{Reason: ErrGetSender, errlist: []error{err}, isTemp: isTempError(err)} sendErr := &SendError{Reason: ErrGetSender, errlist: []error{err}, isTemp: isTempError(err)}
m.sendError = se message.sendError = sendErr
errs = append(errs, se) errs = append(errs, sendErr)
continue continue
} }
rl, err := m.GetRecipients() rcpts, err := message.GetRecipients()
if err != nil { if err != nil {
se := &SendError{Reason: ErrGetRcpts, errlist: []error{err}, isTemp: isTempError(err)} sendErr := &SendError{Reason: ErrGetRcpts, errlist: []error{err}, isTemp: isTempError(err)}
m.sendError = se message.sendError = sendErr
errs = append(errs, se) errs = append(errs, sendErr)
continue continue
} }
if c.dsn { if c.dsn {
if c.dsnmrtype != "" { if c.dsnmrtype != "" {
c.sc.SetDSNMailReturnOption(string(c.dsnmrtype)) c.smtpClient.SetDSNMailReturnOption(string(c.dsnmrtype))
} }
} }
if err := c.sc.Mail(f); err != nil { if err = c.smtpClient.Mail(from); err != nil {
se := &SendError{Reason: ErrSMTPMailFrom, errlist: []error{err}, isTemp: isTempError(err)} sendErr := &SendError{Reason: ErrSMTPMailFrom, errlist: []error{err}, isTemp: isTempError(err)}
if reserr := c.sc.Reset(); reserr != nil { if resetSendErr := c.smtpClient.Reset(); resetSendErr != nil {
se.errlist = append(se.errlist, reserr) sendErr.errlist = append(sendErr.errlist, resetSendErr)
} }
m.sendError = se message.sendError = sendErr
errs = append(errs, se) errs = append(errs, sendErr)
continue continue
} }
failed := false failed := false
rse := &SendError{} rcptSendErr := &SendError{}
rse.errlist = make([]error, 0) rcptSendErr.errlist = make([]error, 0)
rse.rcpt = make([]string, 0) rcptSendErr.rcpt = make([]string, 0)
rno := strings.Join(c.dsnrntype, ",") rcptNotifyOpt := strings.Join(c.dsnrntype, ",")
c.sc.SetDSNRcptNotifyOption(rno) c.smtpClient.SetDSNRcptNotifyOption(rcptNotifyOpt)
for _, r := range rl { for _, rcpt := range rcpts {
if err := c.sc.Rcpt(r); err != nil { if err = c.smtpClient.Rcpt(rcpt); err != nil {
rse.Reason = ErrSMTPRcptTo rcptSendErr.Reason = ErrSMTPRcptTo
rse.errlist = append(rse.errlist, err) rcptSendErr.errlist = append(rcptSendErr.errlist, err)
rse.rcpt = append(rse.rcpt, r) rcptSendErr.rcpt = append(rcptSendErr.rcpt, rcpt)
rse.isTemp = isTempError(err) rcptSendErr.isTemp = isTempError(err)
failed = true failed = true
} }
} }
if failed { if failed {
if reserr := c.sc.Reset(); reserr != nil { if resetSendErr := c.smtpClient.Reset(); resetSendErr != nil {
rse.errlist = append(rse.errlist, err) rcptSendErr.errlist = append(rcptSendErr.errlist, err)
} }
m.sendError = rse message.sendError = rcptSendErr
errs = append(errs, rse) errs = append(errs, rcptSendErr)
continue continue
} }
w, err := c.sc.Data() writer, err := c.smtpClient.Data()
if err != nil { if err != nil {
se := &SendError{Reason: ErrSMTPData, errlist: []error{err}, isTemp: isTempError(err)} sendErr := &SendError{Reason: ErrSMTPData, errlist: []error{err}, isTemp: isTempError(err)}
m.sendError = se message.sendError = sendErr
errs = append(errs, se) errs = append(errs, sendErr)
continue continue
} }
_, err = m.WriteTo(w) _, err = message.WriteTo(writer)
if err != nil { if err != nil {
se := &SendError{Reason: ErrWriteContent, errlist: []error{err}, isTemp: isTempError(err)} sendErr := &SendError{Reason: ErrWriteContent, errlist: []error{err}, isTemp: isTempError(err)}
m.sendError = se message.sendError = sendErr
errs = append(errs, se) errs = append(errs, sendErr)
continue continue
} }
m.isDelivered = true message.isDelivered = true
if err := w.Close(); err != nil { if err = writer.Close(); err != nil {
se := &SendError{Reason: ErrSMTPDataClose, errlist: []error{err}, isTemp: isTempError(err)} sendErr := &SendError{Reason: ErrSMTPDataClose, errlist: []error{err}, isTemp: isTempError(err)}
m.sendError = se message.sendError = sendErr
errs = append(errs, se) errs = append(errs, sendErr)
continue continue
} }
if err := c.Reset(); err != nil { if err = c.Reset(); err != nil {
se := &SendError{Reason: ErrSMTPReset, errlist: []error{err}, isTemp: isTempError(err)} sendErr := &SendError{Reason: ErrSMTPReset, errlist: []error{err}, isTemp: isTempError(err)}
m.sendError = se message.sendError = sendErr
errs = append(errs, se) errs = append(errs, sendErr)
continue continue
} }
if err := c.checkConn(); err != nil { if err = c.checkConn(); err != nil {
se := &SendError{Reason: ErrConnCheck, errlist: []error{err}, isTemp: isTempError(err)} sendErr := &SendError{Reason: ErrConnCheck, errlist: []error{err}, isTemp: isTempError(err)}
m.sendError = se message.sendError = sendErr
errs = append(errs, se) errs = append(errs, sendErr)
continue continue
} }
} }
if len(errs) > 0 { if len(errs) > 0 {
if len(errs) > 1 { if len(errs) > 1 {
re := &SendError{Reason: ErrAmbiguous} returnErr := &SendError{Reason: ErrAmbiguous}
for i := range errs { for i := range errs {
re.errlist = append(re.errlist, errs[i].errlist...) returnErr.errlist = append(returnErr.errlist, errs[i].errlist...)
re.rcpt = append(re.rcpt, errs[i].rcpt...) returnErr.rcpt = append(returnErr.rcpt, errs[i].rcpt...)
} }
// We assume that the isTemp flag from the last error we received should be the // We assume that the isTemp flag from the last error we received should be the
// indicator for the returned isTemp flag as well // indicator for the returned isTemp flag as well
re.isTemp = errs[len(errs)-1].isTemp returnErr.isTemp = errs[len(errs)-1].isTemp
return re return returnErr
} }
return errs[0] return errs[0]
} }

View file

@ -13,97 +13,97 @@ import (
) )
// Send sends out the mail message // Send sends out the mail message
func (c *Client) Send(ml ...*Msg) (rerr error) { func (c *Client) Send(messages ...*Msg) (returnErr error) {
if err := c.checkConn(); err != nil { if err := c.checkConn(); err != nil {
rerr = &SendError{Reason: ErrConnCheck, errlist: []error{err}, isTemp: isTempError(err)} returnErr = &SendError{Reason: ErrConnCheck, errlist: []error{err}, isTemp: isTempError(err)}
return return
} }
for _, m := range ml { for _, message := range messages {
m.sendError = nil message.sendError = nil
if m.encoding == NoEncoding { if message.encoding == NoEncoding {
if ok, _ := c.sc.Extension("8BITMIME"); !ok { if ok, _ := c.smtpClient.Extension("8BITMIME"); !ok {
m.sendError = &SendError{Reason: ErrNoUnencoded, isTemp: false} message.sendError = &SendError{Reason: ErrNoUnencoded, isTemp: false}
rerr = errors.Join(rerr, m.sendError) returnErr = errors.Join(returnErr, message.sendError)
continue continue
} }
} }
f, err := m.GetSender(false) from, err := message.GetSender(false)
if err != nil { if err != nil {
m.sendError = &SendError{Reason: ErrGetSender, errlist: []error{err}, isTemp: isTempError(err)} message.sendError = &SendError{Reason: ErrGetSender, errlist: []error{err}, isTemp: isTempError(err)}
rerr = errors.Join(rerr, m.sendError) returnErr = errors.Join(returnErr, message.sendError)
continue continue
} }
rl, err := m.GetRecipients() rcpts, err := message.GetRecipients()
if err != nil { if err != nil {
m.sendError = &SendError{Reason: ErrGetRcpts, errlist: []error{err}, isTemp: isTempError(err)} message.sendError = &SendError{Reason: ErrGetRcpts, errlist: []error{err}, isTemp: isTempError(err)}
rerr = errors.Join(rerr, m.sendError) returnErr = errors.Join(returnErr, message.sendError)
continue continue
} }
if c.dsn { if c.dsn {
if c.dsnmrtype != "" { if c.dsnmrtype != "" {
c.sc.SetDSNMailReturnOption(string(c.dsnmrtype)) c.smtpClient.SetDSNMailReturnOption(string(c.dsnmrtype))
} }
} }
if err := c.sc.Mail(f); err != nil { if err = c.smtpClient.Mail(from); err != nil {
m.sendError = &SendError{Reason: ErrSMTPMailFrom, errlist: []error{err}, isTemp: isTempError(err)} message.sendError = &SendError{Reason: ErrSMTPMailFrom, errlist: []error{err}, isTemp: isTempError(err)}
rerr = errors.Join(rerr, m.sendError) returnErr = errors.Join(returnErr, message.sendError)
if reserr := c.sc.Reset(); reserr != nil { if resetSendErr := c.smtpClient.Reset(); resetSendErr != nil {
rerr = errors.Join(rerr, reserr) returnErr = errors.Join(returnErr, resetSendErr)
} }
continue continue
} }
failed := false failed := false
rse := &SendError{} rcptSendErr := &SendError{}
rse.errlist = make([]error, 0) rcptSendErr.errlist = make([]error, 0)
rse.rcpt = make([]string, 0) rcptSendErr.rcpt = make([]string, 0)
rno := strings.Join(c.dsnrntype, ",") rcptNotifyOpt := strings.Join(c.dsnrntype, ",")
c.sc.SetDSNRcptNotifyOption(rno) c.smtpClient.SetDSNRcptNotifyOption(rcptNotifyOpt)
for _, r := range rl { for _, rcpt := range rcpts {
if err := c.sc.Rcpt(r); err != nil { if err = c.smtpClient.Rcpt(rcpt); err != nil {
rse.Reason = ErrSMTPRcptTo rcptSendErr.Reason = ErrSMTPRcptTo
rse.errlist = append(rse.errlist, err) rcptSendErr.errlist = append(rcptSendErr.errlist, err)
rse.rcpt = append(rse.rcpt, r) rcptSendErr.rcpt = append(rcptSendErr.rcpt, rcpt)
rse.isTemp = isTempError(err) rcptSendErr.isTemp = isTempError(err)
failed = true failed = true
} }
} }
if failed { if failed {
if reserr := c.sc.Reset(); reserr != nil { if resetSendErr := c.smtpClient.Reset(); resetSendErr != nil {
rerr = errors.Join(rerr, reserr) returnErr = errors.Join(returnErr, resetSendErr)
} }
m.sendError = rse message.sendError = rcptSendErr
rerr = errors.Join(rerr, m.sendError) returnErr = errors.Join(returnErr, message.sendError)
continue continue
} }
w, err := c.sc.Data() writer, err := c.smtpClient.Data()
if err != nil { if err != nil {
m.sendError = &SendError{Reason: ErrSMTPData, errlist: []error{err}, isTemp: isTempError(err)} message.sendError = &SendError{Reason: ErrSMTPData, errlist: []error{err}, isTemp: isTempError(err)}
rerr = errors.Join(rerr, m.sendError) returnErr = errors.Join(returnErr, message.sendError)
continue continue
} }
_, err = m.WriteTo(w) _, err = message.WriteTo(writer)
if err != nil { if err != nil {
m.sendError = &SendError{Reason: ErrWriteContent, errlist: []error{err}, isTemp: isTempError(err)} message.sendError = &SendError{Reason: ErrWriteContent, errlist: []error{err}, isTemp: isTempError(err)}
rerr = errors.Join(rerr, m.sendError) returnErr = errors.Join(returnErr, message.sendError)
continue continue
} }
m.isDelivered = true message.isDelivered = true
if err := w.Close(); err != nil { if err = writer.Close(); err != nil {
m.sendError = &SendError{Reason: ErrSMTPDataClose, errlist: []error{err}, isTemp: isTempError(err)} message.sendError = &SendError{Reason: ErrSMTPDataClose, errlist: []error{err}, isTemp: isTempError(err)}
rerr = errors.Join(rerr, m.sendError) returnErr = errors.Join(returnErr, message.sendError)
continue continue
} }
if err := c.Reset(); err != nil { if err = c.Reset(); err != nil {
m.sendError = &SendError{Reason: ErrSMTPReset, errlist: []error{err}, isTemp: isTempError(err)} message.sendError = &SendError{Reason: ErrSMTPReset, errlist: []error{err}, isTemp: isTempError(err)}
rerr = errors.Join(rerr, m.sendError) returnErr = errors.Join(returnErr, message.sendError)
continue continue
} }
if err := c.checkConn(); err != nil { if err = c.checkConn(); err != nil {
m.sendError = &SendError{Reason: ErrConnCheck, errlist: []error{err}, isTemp: isTempError(err)} message.sendError = &SendError{Reason: ErrConnCheck, errlist: []error{err}, isTemp: isTempError(err)}
rerr = errors.Join(rerr, m.sendError) returnErr = errors.Join(returnErr, message.sendError)
} }
} }

View file

@ -49,9 +49,9 @@ func TestNewClient(t *testing.T) {
if c.host != tt.host { if c.host != tt.host {
t.Errorf("failed to create new client. Host expected: %s, got: %s", host, c.host) t.Errorf("failed to create new client. Host expected: %s, got: %s", host, c.host)
} }
if c.cto != DefaultTimeout { if c.connTimeout != DefaultTimeout {
t.Errorf("failed to create new client. Timeout expected: %s, got: %s", DefaultTimeout.String(), t.Errorf("failed to create new client. Timeout expected: %s, got: %s", DefaultTimeout.String(),
c.cto.String()) c.connTimeout.String())
} }
if c.port != DefaultPort { if c.port != DefaultPort {
t.Errorf("failed to create new client. Port expected: %d, got: %d", DefaultPort, c.port) t.Errorf("failed to create new client. Port expected: %d, got: %d", DefaultPort, c.port)
@ -205,8 +205,8 @@ func TestWithTimeout(t *testing.T) {
t.Errorf("failed to create new client: %s", err) t.Errorf("failed to create new client: %s", err)
return return
} }
if c.cto != tt.want { if c.connTimeout != tt.want {
t.Errorf("failed to set custom timeout. Want: %d, got: %d", tt.want, c.cto) t.Errorf("failed to set custom timeout. Want: %d, got: %d", tt.want, c.connTimeout)
} }
}) })
} }
@ -345,8 +345,8 @@ func TestSetSSL(t *testing.T) {
return return
} }
c.SetSSL(tt.value) c.SetSSL(tt.value)
if c.ssl != tt.value { if c.useSSL != tt.value {
t.Errorf("failed to set SSL setting. Got: %t, want: %t", c.ssl, tt.value) t.Errorf("failed to set SSL setting. Got: %t, want: %t", c.useSSL, tt.value)
} }
}) })
} }
@ -374,8 +374,8 @@ func TestClient_SetSSLPort(t *testing.T) {
return return
} }
c.SetSSLPort(tt.value, tt.fb) c.SetSSLPort(tt.value, tt.fb)
if c.ssl != tt.value { if c.useSSL != tt.value {
t.Errorf("failed to set SSL setting. Got: %t, want: %t", c.ssl, tt.value) t.Errorf("failed to set SSL setting. Got: %t, want: %t", c.useSSL, tt.value)
} }
if c.port != tt.port { if c.port != tt.port {
t.Errorf("failed to set SSLPort, wanted port: %d, got: %d", c.port, tt.port) t.Errorf("failed to set SSLPort, wanted port: %d, got: %d", c.port, tt.port)
@ -460,8 +460,8 @@ func TestSetSMTPAuth(t *testing.T) {
return return
} }
c.SetSMTPAuth(tt.value) c.SetSMTPAuth(tt.value)
if string(c.satype) != tt.want { if string(c.smtpAuthType) != tt.want {
t.Errorf("failed to set SMTP auth type. Expected %s, got: %s", tt.want, string(c.satype)) t.Errorf("failed to set SMTP auth type. Expected %s, got: %s", tt.want, string(c.smtpAuthType))
} }
}) })
} }
@ -590,10 +590,10 @@ func TestSetSMTPAuthCustom(t *testing.T) {
return return
} }
c.SetSMTPAuthCustom(tt.value) c.SetSMTPAuthCustom(tt.value)
if c.sa == nil { if c.smtpAuth == nil {
t.Errorf("failed to set custom SMTP auth method. SMTP Auth method is empty") t.Errorf("failed to set custom SMTP auth method. SMTP Auth method is empty")
} }
p, _, err := c.sa.Start(&si) p, _, err := c.smtpAuth.Start(&si)
if err != nil { if err != nil {
t.Errorf("SMTP Auth Start() method returned error: %s", err) t.Errorf("SMTP Auth Start() method returned error: %s", err)
} }
@ -615,10 +615,10 @@ func TestClient_DialWithContext(t *testing.T) {
t.Errorf("failed to dial with context: %s", err) t.Errorf("failed to dial with context: %s", err)
return return
} }
if c.co == nil { if c.connection == nil {
t.Errorf("DialWithContext didn't fail but no connection found.") t.Errorf("DialWithContext didn't fail but no connection found.")
} }
if c.sc == nil { if c.smtpClient == nil {
t.Errorf("DialWithContext didn't fail but no SMTP client found.") t.Errorf("DialWithContext didn't fail but no SMTP client found.")
} }
if err := c.Close(); err != nil { if err := c.Close(); err != nil {
@ -640,10 +640,10 @@ func TestClient_DialWithContext_Fallback(t *testing.T) {
t.Errorf("failed to dial with context: %s", err) t.Errorf("failed to dial with context: %s", err)
return return
} }
if c.co == nil { if c.connection == nil {
t.Errorf("DialWithContext didn't fail but no connection found.") t.Errorf("DialWithContext didn't fail but no connection found.")
} }
if c.sc == nil { if c.smtpClient == nil {
t.Errorf("DialWithContext didn't fail but no SMTP client found.") t.Errorf("DialWithContext didn't fail but no SMTP client found.")
} }
if err := c.Close(); err != nil { if err := c.Close(); err != nil {
@ -670,10 +670,10 @@ func TestClient_DialWithContext_Debug(t *testing.T) {
t.Errorf("failed to dial with context: %s", err) t.Errorf("failed to dial with context: %s", err)
return return
} }
if c.co == nil { if c.connection == nil {
t.Errorf("DialWithContext didn't fail but no connection found.") t.Errorf("DialWithContext didn't fail but no connection found.")
} }
if c.sc == nil { if c.smtpClient == nil {
t.Errorf("DialWithContext didn't fail but no SMTP client found.") t.Errorf("DialWithContext didn't fail but no SMTP client found.")
} }
c.SetDebugLog(true) c.SetDebugLog(true)
@ -694,10 +694,10 @@ func TestClient_DialWithContext_Debug_custom(t *testing.T) {
t.Errorf("failed to dial with context: %s", err) t.Errorf("failed to dial with context: %s", err)
return return
} }
if c.co == nil { if c.connection == nil {
t.Errorf("DialWithContext didn't fail but no connection found.") t.Errorf("DialWithContext didn't fail but no connection found.")
} }
if c.sc == nil { if c.smtpClient == nil {
t.Errorf("DialWithContext didn't fail but no SMTP client found.") t.Errorf("DialWithContext didn't fail but no SMTP client found.")
} }
c.SetDebugLog(true) c.SetDebugLog(true)
@ -714,7 +714,7 @@ func TestClient_DialWithContextInvalidHost(t *testing.T) {
if err != nil { if err != nil {
t.Skipf("failed to create test client: %s. Skipping tests", err) t.Skipf("failed to create test client: %s. Skipping tests", err)
} }
c.co = nil c.connection = nil
c.host = "invalid.addr" c.host = "invalid.addr"
ctx := context.Background() ctx := context.Background()
if err := c.DialWithContext(ctx); err == nil { if err := c.DialWithContext(ctx); err == nil {
@ -730,7 +730,7 @@ func TestClient_DialWithContextInvalidHELO(t *testing.T) {
if err != nil { if err != nil {
t.Skipf("failed to create test client: %s. Skipping tests", err) t.Skipf("failed to create test client: %s. Skipping tests", err)
} }
c.co = nil c.connection = nil
c.helo = "" c.helo = ""
ctx := context.Background() ctx := context.Background()
if err := c.DialWithContext(ctx); err == nil { if err := c.DialWithContext(ctx); err == nil {
@ -762,7 +762,7 @@ func TestClient_checkConn(t *testing.T) {
if err != nil { if err != nil {
t.Skipf("failed to create test client: %s. Skipping tests", err) t.Skipf("failed to create test client: %s. Skipping tests", err)
} }
c.co = nil c.connection = nil
if err := c.checkConn(); err == nil { if err := c.checkConn(); err == nil {
t.Errorf("connCheck() should fail but succeeded") t.Errorf("connCheck() should fail but succeeded")
} }
@ -799,10 +799,10 @@ func TestClient_DialWithContextOptions(t *testing.T) {
return return
} }
if !tt.sf { if !tt.sf {
if c.co == nil && !tt.sf { if c.connection == nil && !tt.sf {
t.Errorf("DialWithContext didn't fail but no connection found.") t.Errorf("DialWithContext didn't fail but no connection found.")
} }
if c.sc == nil && !tt.sf { if c.smtpClient == nil && !tt.sf {
t.Errorf("DialWithContext didn't fail but no SMTP client found.") t.Errorf("DialWithContext didn't fail but no SMTP client found.")
} }
if err := c.Reset(); err != nil { if err := c.Reset(); err != nil {
@ -1002,16 +1002,16 @@ func TestClient_DialSendCloseBroken(t *testing.T) {
return return
} }
if tt.closestart { if tt.closestart {
_ = c.sc.Close() _ = c.smtpClient.Close()
_ = c.co.Close() _ = c.connection.Close()
} }
if err := c.Send(m); err != nil && !tt.sf { if err := c.Send(m); err != nil && !tt.sf {
t.Errorf("Send() failed: %s", err) t.Errorf("Send() failed: %s", err)
return return
} }
if tt.closeearly { if tt.closeearly {
_ = c.sc.Close() _ = c.smtpClient.Close()
_ = c.co.Close() _ = c.connection.Close()
} }
if err := c.Close(); err != nil && !tt.sf { if err := c.Close(); err != nil && !tt.sf {
t.Errorf("Close() failed: %s", err) t.Errorf("Close() failed: %s", err)
@ -1062,16 +1062,16 @@ func TestClient_DialSendCloseBrokenWithDSN(t *testing.T) {
return return
} }
if tt.closestart { if tt.closestart {
_ = c.sc.Close() _ = c.smtpClient.Close()
_ = c.co.Close() _ = c.connection.Close()
} }
if err := c.Send(m); err != nil && !tt.sf { if err := c.Send(m); err != nil && !tt.sf {
t.Errorf("Send() failed: %s", err) t.Errorf("Send() failed: %s", err)
return return
} }
if tt.closeearly { if tt.closeearly {
_ = c.sc.Close() _ = c.smtpClient.Close()
_ = c.co.Close() _ = c.connection.Close()
} }
if err := c.Close(); err != nil && !tt.sf { if err := c.Close(); err != nil && !tt.sf {
t.Errorf("Close() failed: %s", err) t.Errorf("Close() failed: %s", err)

26
file.go
View file

@ -23,17 +23,17 @@ type File struct {
} }
// WithFileName sets the filename of the File // WithFileName sets the filename of the File
func WithFileName(n string) FileOption { func WithFileName(name string) FileOption {
return func(f *File) { return func(f *File) {
f.Name = n f.Name = name
} }
} }
// WithFileDescription sets an optional file description of the File that will be // WithFileDescription sets an optional file description of the File that will be
// added as Content-Description part // added as Content-Description part
func WithFileDescription(d string) FileOption { func WithFileDescription(description string) FileOption {
return func(f *File) { return func(f *File) {
f.Desc = d f.Desc = description
} }
} }
@ -41,12 +41,12 @@ func WithFileDescription(d string) FileOption {
// Base64 encoding but there might be exceptions, where this might come handy. // Base64 encoding but there might be exceptions, where this might come handy.
// Please note that quoted-printable should never be used for attachments/embeds. If this // Please note that quoted-printable should never be used for attachments/embeds. If this
// is provided as argument, the function will automatically override back to Base64 // is provided as argument, the function will automatically override back to Base64
func WithFileEncoding(e Encoding) FileOption { func WithFileEncoding(encoding Encoding) FileOption {
return func(f *File) { return func(f *File) {
if e == EncodingQP { if encoding == EncodingQP {
return return
} }
f.Enc = e f.Enc = encoding
} }
} }
@ -56,19 +56,19 @@ func WithFileEncoding(e Encoding) FileOption {
// could not be guessed. In some cases, however, it might be needed to force // could not be guessed. In some cases, however, it might be needed to force
// this to a specific type. For such situations this override method can // this to a specific type. For such situations this override method can
// be used // be used
func WithFileContentType(t ContentType) FileOption { func WithFileContentType(contentType ContentType) FileOption {
return func(f *File) { return func(f *File) {
f.ContentType = t f.ContentType = contentType
} }
} }
// setHeader sets header fields to a File // setHeader sets header fields to a File
func (f *File) setHeader(h Header, v string) { func (f *File) setHeader(header Header, value string) {
f.Header.Set(string(h), v) f.Header.Set(string(header), value)
} }
// getHeader return header fields of a File // getHeader return header fields of a File
func (f *File) getHeader(h Header) (string, bool) { func (f *File) getHeader(header Header) (string, bool) {
v := f.Header.Get(string(h)) v := f.Header.Get(string(header))
return v, v != "" return v, v != ""
} }

View file

@ -15,68 +15,68 @@ import (
// JSONlog is the default structured JSON logger that satisfies the Logger interface // JSONlog is the default structured JSON logger that satisfies the Logger interface
type JSONlog struct { type JSONlog struct {
l Level level Level
log *slog.Logger log *slog.Logger
} }
// NewJSON returns a new JSONlog type that satisfies the Logger interface // NewJSON returns a new JSONlog type that satisfies the Logger interface
func NewJSON(o io.Writer, l Level) *JSONlog { func NewJSON(output io.Writer, level Level) *JSONlog {
lo := slog.HandlerOptions{} logOpts := slog.HandlerOptions{}
switch l { switch level {
case LevelDebug: case LevelDebug:
lo.Level = slog.LevelDebug logOpts.Level = slog.LevelDebug
case LevelInfo: case LevelInfo:
lo.Level = slog.LevelInfo logOpts.Level = slog.LevelInfo
case LevelWarn: case LevelWarn:
lo.Level = slog.LevelWarn logOpts.Level = slog.LevelWarn
case LevelError: case LevelError:
lo.Level = slog.LevelError logOpts.Level = slog.LevelError
default: default:
lo.Level = slog.LevelDebug logOpts.Level = slog.LevelDebug
} }
lh := slog.NewJSONHandler(o, &lo) logHandler := slog.NewJSONHandler(output, &logOpts)
return &JSONlog{ return &JSONlog{
l: l, level: level,
log: slog.New(lh), log: slog.New(logHandler),
} }
} }
// Debugf logs a debug message via the structured JSON logger // Debugf logs a debug message via the structured JSON logger
func (l *JSONlog) Debugf(lo Log) { func (l *JSONlog) Debugf(log Log) {
if l.l >= LevelDebug { if l.level >= LevelDebug {
l.log.WithGroup(DirString).With( l.log.WithGroup(DirString).With(
slog.String(DirFromString, lo.directionFrom()), slog.String(DirFromString, log.directionFrom()),
slog.String(DirToString, lo.directionTo()), slog.String(DirToString, log.directionTo()),
).Debug(fmt.Sprintf(lo.Format, lo.Messages...)) ).Debug(fmt.Sprintf(log.Format, log.Messages...))
} }
} }
// Infof logs a info message via the structured JSON logger // Infof logs a info message via the structured JSON logger
func (l *JSONlog) Infof(lo Log) { func (l *JSONlog) Infof(log Log) {
if l.l >= LevelInfo { if l.level >= LevelInfo {
l.log.WithGroup(DirString).With( l.log.WithGroup(DirString).With(
slog.String(DirFromString, lo.directionFrom()), slog.String(DirFromString, log.directionFrom()),
slog.String(DirToString, lo.directionTo()), slog.String(DirToString, log.directionTo()),
).Info(fmt.Sprintf(lo.Format, lo.Messages...)) ).Info(fmt.Sprintf(log.Format, log.Messages...))
} }
} }
// Warnf logs a warn message via the structured JSON logger // Warnf logs a warn message via the structured JSON logger
func (l *JSONlog) Warnf(lo Log) { func (l *JSONlog) Warnf(log Log) {
if l.l >= LevelWarn { if l.level >= LevelWarn {
l.log.WithGroup(DirString).With( l.log.WithGroup(DirString).With(
slog.String(DirFromString, lo.directionFrom()), slog.String(DirFromString, log.directionFrom()),
slog.String(DirToString, lo.directionTo()), slog.String(DirToString, log.directionTo()),
).Warn(fmt.Sprintf(lo.Format, lo.Messages...)) ).Warn(fmt.Sprintf(log.Format, log.Messages...))
} }
} }
// Errorf logs a warn message via the structured JSON logger // Errorf logs a warn message via the structured JSON logger
func (l *JSONlog) Errorf(lo Log) { func (l *JSONlog) Errorf(log Log) {
if l.l >= LevelError { if l.level >= LevelError {
l.log.WithGroup(DirString).With( l.log.WithGroup(DirString).With(
slog.String(DirFromString, lo.directionFrom()), slog.String(DirFromString, log.directionFrom()),
slog.String(DirToString, lo.directionTo()), slog.String(DirToString, log.directionTo()),
).Error(fmt.Sprintf(lo.Format, lo.Messages...)) ).Error(fmt.Sprintf(log.Format, log.Messages...))
} }
} }

View file

@ -30,8 +30,8 @@ type jsonDir struct {
func TestNewJSON(t *testing.T) { func TestNewJSON(t *testing.T) {
var b bytes.Buffer var b bytes.Buffer
l := NewJSON(&b, LevelDebug) l := NewJSON(&b, LevelDebug)
if l.l != LevelDebug { if l.level != LevelDebug {
t.Error("Expected level to be LevelDebug, got ", l.l) t.Error("Expected level to be LevelDebug, got ", l.level)
} }
if l.log == nil { if l.log == nil {
t.Error("logger not initialized") t.Error("logger not initialized")
@ -81,7 +81,7 @@ func TestJSONDebugf(t *testing.T) {
} }
b.Reset() b.Reset()
l.l = LevelInfo l.level = LevelInfo
l.Debugf(Log{Direction: DirServerToClient, Format: "test %s", Messages: []interface{}{"foo"}}) l.Debugf(Log{Direction: DirServerToClient, Format: "test %s", Messages: []interface{}{"foo"}})
if b.String() != "" { if b.String() != "" {
t.Error("Debug message was not expected to be logged") t.Error("Debug message was not expected to be logged")
@ -131,7 +131,7 @@ func TestJSONDebugf_WithDefault(t *testing.T) {
} }
b.Reset() b.Reset()
l.l = LevelInfo l.level = LevelInfo
l.Debugf(Log{Direction: DirServerToClient, Format: "test %s", Messages: []interface{}{"foo"}}) l.Debugf(Log{Direction: DirServerToClient, Format: "test %s", Messages: []interface{}{"foo"}})
if b.String() != "" { if b.String() != "" {
t.Error("Debug message was not expected to be logged") t.Error("Debug message was not expected to be logged")
@ -181,7 +181,7 @@ func TestJSONInfof(t *testing.T) {
} }
b.Reset() b.Reset()
l.l = LevelWarn l.level = LevelWarn
l.Infof(Log{Direction: DirServerToClient, Format: "test %s", Messages: []interface{}{"foo"}}) l.Infof(Log{Direction: DirServerToClient, Format: "test %s", Messages: []interface{}{"foo"}})
if b.String() != "" { if b.String() != "" {
t.Error("Info message was not expected to be logged") t.Error("Info message was not expected to be logged")
@ -231,7 +231,7 @@ func TestJSONWarnf(t *testing.T) {
} }
b.Reset() b.Reset()
l.l = LevelError l.level = LevelError
l.Warnf(Log{Direction: DirServerToClient, Format: "test %s", Messages: []interface{}{"foo"}}) l.Warnf(Log{Direction: DirServerToClient, Format: "test %s", Messages: []interface{}{"foo"}})
if b.String() != "" { if b.String() != "" {
t.Error("Warn message was not expected to be logged") t.Error("Warn message was not expected to be logged")
@ -281,7 +281,7 @@ func TestJSONErrorf(t *testing.T) {
} }
b.Reset() b.Reset()
l.l = -99 l.level = -99
l.Errorf(Log{Direction: DirServerToClient, Format: "test %s", Messages: []interface{}{"foo"}}) l.Errorf(Log{Direction: DirServerToClient, Format: "test %s", Messages: []interface{}{"foo"}})
if b.String() != "" { if b.String() != "" {
t.Error("Error message was not expected to be logged") t.Error("Error message was not expected to be logged")

View file

@ -12,7 +12,7 @@ import (
// Stdlog is the default logger that satisfies the Logger interface // Stdlog is the default logger that satisfies the Logger interface
type Stdlog struct { type Stdlog struct {
l Level level Level
err *log.Logger err *log.Logger
warn *log.Logger warn *log.Logger
info *log.Logger info *log.Logger
@ -24,45 +24,45 @@ type Stdlog struct {
const CallDepth = 2 const CallDepth = 2
// New returns a new Stdlog type that satisfies the Logger interface // New returns a new Stdlog type that satisfies the Logger interface
func New(o io.Writer, l Level) *Stdlog { func New(output io.Writer, level Level) *Stdlog {
lf := log.Lmsgprefix | log.LstdFlags lf := log.Lmsgprefix | log.LstdFlags
return &Stdlog{ return &Stdlog{
l: l, level: level,
err: log.New(o, "ERROR: ", lf), err: log.New(output, "ERROR: ", lf),
warn: log.New(o, " WARN: ", lf), warn: log.New(output, " WARN: ", lf),
info: log.New(o, " INFO: ", lf), info: log.New(output, " INFO: ", lf),
debug: log.New(o, "DEBUG: ", lf), debug: log.New(output, "DEBUG: ", lf),
} }
} }
// Debugf performs a Printf() on the debug logger // Debugf performs a Printf() on the debug logger
func (l *Stdlog) Debugf(lo Log) { func (l *Stdlog) Debugf(log Log) {
if l.l >= LevelDebug { if l.level >= LevelDebug {
f := fmt.Sprintf("%s %s", lo.directionPrefix(), lo.Format) format := fmt.Sprintf("%s %s", log.directionPrefix(), log.Format)
_ = l.debug.Output(CallDepth, fmt.Sprintf(f, lo.Messages...)) _ = l.debug.Output(CallDepth, fmt.Sprintf(format, log.Messages...))
} }
} }
// Infof performs a Printf() on the info logger // Infof performs a Printf() on the info logger
func (l *Stdlog) Infof(lo Log) { func (l *Stdlog) Infof(log Log) {
if l.l >= LevelInfo { if l.level >= LevelInfo {
f := fmt.Sprintf("%s %s", lo.directionPrefix(), lo.Format) format := fmt.Sprintf("%s %s", log.directionPrefix(), log.Format)
_ = l.info.Output(CallDepth, fmt.Sprintf(f, lo.Messages...)) _ = l.info.Output(CallDepth, fmt.Sprintf(format, log.Messages...))
} }
} }
// Warnf performs a Printf() on the warn logger // Warnf performs a Printf() on the warn logger
func (l *Stdlog) Warnf(lo Log) { func (l *Stdlog) Warnf(log Log) {
if l.l >= LevelWarn { if l.level >= LevelWarn {
f := fmt.Sprintf("%s %s", lo.directionPrefix(), lo.Format) format := fmt.Sprintf("%s %s", log.directionPrefix(), log.Format)
_ = l.warn.Output(CallDepth, fmt.Sprintf(f, lo.Messages...)) _ = l.warn.Output(CallDepth, fmt.Sprintf(format, log.Messages...))
} }
} }
// Errorf performs a Printf() on the error logger // Errorf performs a Printf() on the error logger
func (l *Stdlog) Errorf(lo Log) { func (l *Stdlog) Errorf(log Log) {
if l.l >= LevelError { if l.level >= LevelError {
f := fmt.Sprintf("%s %s", lo.directionPrefix(), lo.Format) format := fmt.Sprintf("%s %s", log.directionPrefix(), log.Format)
_ = l.err.Output(CallDepth, fmt.Sprintf(f, lo.Messages...)) _ = l.err.Output(CallDepth, fmt.Sprintf(format, log.Messages...))
} }
} }

View file

@ -13,8 +13,8 @@ import (
func TestNew(t *testing.T) { func TestNew(t *testing.T) {
var b bytes.Buffer var b bytes.Buffer
l := New(&b, LevelDebug) l := New(&b, LevelDebug)
if l.l != LevelDebug { if l.level != LevelDebug {
t.Error("Expected level to be LevelDebug, got ", l.l) t.Error("Expected level to be LevelDebug, got ", l.level)
} }
if l.err == nil || l.warn == nil || l.info == nil || l.debug == nil { if l.err == nil || l.warn == nil || l.info == nil || l.debug == nil {
t.Error("Loggers not initialized") t.Error("Loggers not initialized")
@ -37,7 +37,7 @@ func TestDebugf(t *testing.T) {
} }
b.Reset() b.Reset()
l.l = LevelInfo l.level = LevelInfo
l.Debugf(Log{Direction: DirServerToClient, Format: "test %s", Messages: []interface{}{"foo"}}) l.Debugf(Log{Direction: DirServerToClient, Format: "test %s", Messages: []interface{}{"foo"}})
if b.String() != "" { if b.String() != "" {
t.Error("Debug message was not expected to be logged") t.Error("Debug message was not expected to be logged")
@ -60,7 +60,7 @@ func TestInfof(t *testing.T) {
} }
b.Reset() b.Reset()
l.l = LevelWarn l.level = LevelWarn
l.Infof(Log{Direction: DirServerToClient, Format: "test %s", Messages: []interface{}{"foo"}}) l.Infof(Log{Direction: DirServerToClient, Format: "test %s", Messages: []interface{}{"foo"}})
if b.String() != "" { if b.String() != "" {
t.Error("Info message was not expected to be logged") t.Error("Info message was not expected to be logged")
@ -83,7 +83,7 @@ func TestWarnf(t *testing.T) {
} }
b.Reset() b.Reset()
l.l = LevelError l.level = LevelError
l.Warnf(Log{Direction: DirServerToClient, Format: "test %s", Messages: []interface{}{"foo"}}) l.Warnf(Log{Direction: DirServerToClient, Format: "test %s", Messages: []interface{}{"foo"}})
if b.String() != "" { if b.String() != "" {
t.Error("Warn message was not expected to be logged") t.Error("Warn message was not expected to be logged")
@ -106,7 +106,7 @@ func TestErrorf(t *testing.T) {
} }
b.Reset() b.Reset()
l.l = LevelError - 1 l.level = LevelError - 1
l.Errorf(Log{Direction: DirServerToClient, Format: "test %s", Messages: []interface{}{"foo"}}) l.Errorf(Log{Direction: DirServerToClient, Format: "test %s", Messages: []interface{}{"foo"}})
if b.String() != "" { if b.String() != "" {
t.Error("Error message was not expected to be logged") t.Error("Error message was not expected to be logged")

764
msg.go

File diff suppressed because it is too large Load diff

View file

@ -1251,7 +1251,7 @@ func TestMsg_SetBodyString(t *testing.T) {
} }
part := m.parts[0] part := m.parts[0]
res := bytes.Buffer{} res := bytes.Buffer{}
if _, err := part.w(&res); err != nil && !tt.sf { if _, err := part.writeFunc(&res); err != nil && !tt.sf {
t.Errorf("WriteFunc of part failed: %s", err) t.Errorf("WriteFunc of part failed: %s", err)
} }
if res.String() != tt.want { if res.String() != tt.want {
@ -1286,7 +1286,7 @@ func TestMsg_AddAlternativeString(t *testing.T) {
} }
apart := m.parts[1] apart := m.parts[1]
res := bytes.Buffer{} res := bytes.Buffer{}
if _, err := apart.w(&res); err != nil && !tt.sf { if _, err := apart.writeFunc(&res); err != nil && !tt.sf {
t.Errorf("WriteFunc of part failed: %s", err) t.Errorf("WriteFunc of part failed: %s", err)
} }
if res.String() != tt.want { if res.String() != tt.want {
@ -3161,3 +3161,89 @@ func TestMsg_BccFromString(t *testing.T) {
}) })
} }
} }
// TestMsg_checkUserAgent tests the checkUserAgent method of the Msg
func TestMsg_checkUserAgent(t *testing.T) {
tests := []struct {
name string
noDefaultUserAgent bool
genHeader map[Header][]string
wantUserAgent string
sf bool
}{
{
name: "check default user agent",
noDefaultUserAgent: false,
wantUserAgent: fmt.Sprintf("go-mail v%s // https://github.com/wneessen/go-mail", VERSION),
sf: false,
},
{
name: "check no default user agent",
noDefaultUserAgent: true,
wantUserAgent: "",
sf: true,
},
{
name: "check if ua and xm is already set",
noDefaultUserAgent: false,
genHeader: map[Header][]string{
HeaderUserAgent: {"custom UA"},
HeaderXMailer: {"custom XM"},
},
wantUserAgent: "custom UA",
sf: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
msg := &Msg{
noDefaultUserAgent: tt.noDefaultUserAgent,
genHeader: tt.genHeader,
}
msg.checkUserAgent()
gotUserAgent := ""
if val, ok := msg.genHeader[HeaderUserAgent]; ok {
gotUserAgent = val[0] // Assuming the first one is the needed value
}
if gotUserAgent != tt.wantUserAgent && !tt.sf {
t.Errorf("UserAgent got = %v, want = %v", gotUserAgent, tt.wantUserAgent)
}
})
}
}
// TestNewMsgWithMIMEVersion tests WithMIMEVersion and Msg.SetMIMEVersion
func TestNewMsgWithNoDefaultUserAgent(t *testing.T) {
m := NewMsg(WithNoDefaultUserAgent())
if m.noDefaultUserAgent != true {
t.Errorf("WithNoDefaultUserAgent() failed. Expected: %t, got: %t", true, false)
}
}
// Fuzzing tests
func FuzzMsg_Subject(f *testing.F) {
f.Add("Testsubject")
f.Add("")
f.Add("This is a longer test subject.")
f.Add("Let's add some umlauts: üäöß")
f.Add("Or even emojis: ☝️💪👍")
f.Fuzz(func(t *testing.T, data string) {
m := NewMsg()
m.Subject(data)
m.Reset()
})
}
func FuzzMsg_From(f *testing.F) {
f.Add("Toni Tester <toni@tester.com>")
f.Add("<tester@example.com>")
f.Add("mail@server.com")
f.Fuzz(func(t *testing.T, data string) {
m := NewMsg()
if err := m.From(data); err != nil &&
!strings.Contains(err.Error(), "failed to parse mail address") {
t.Errorf("failed set set FROM address: %s", err)
}
m.Reset()
})
}

View file

@ -35,237 +35,241 @@ const DoubleNewLine = "\r\n\r\n"
// msgWriter handles the I/O to the io.WriteCloser of the SMTP client // msgWriter handles the I/O to the io.WriteCloser of the SMTP client
type msgWriter struct { type msgWriter struct {
c Charset bytesWritten int64
d int8 charset Charset
en mime.WordEncoder depth int8
encoder mime.WordEncoder
err error err error
mpw [3]*multipart.Writer multiPartWriter [3]*multipart.Writer
n int64 partWriter io.Writer
pw io.Writer writer io.Writer
w io.Writer
} }
// Write implements the io.Writer interface for msgWriter // Write implements the io.Writer interface for msgWriter
func (mw *msgWriter) Write(p []byte) (int, error) { func (mw *msgWriter) Write(payload []byte) (int, error) {
if mw.err != nil { if mw.err != nil {
return 0, fmt.Errorf("failed to write due to previous error: %w", mw.err) return 0, fmt.Errorf("failed to write due to previous error: %w", mw.err)
} }
var n int var n int
n, mw.err = mw.w.Write(p) n, mw.err = mw.writer.Write(payload)
mw.n += int64(n) mw.bytesWritten += int64(n)
return n, mw.err return n, mw.err
} }
// writeMsg formats the message and sends it to its io.Writer // writeMsg formats the message and sends it to its io.Writer
func (mw *msgWriter) writeMsg(m *Msg) { func (mw *msgWriter) writeMsg(msg *Msg) {
m.addDefaultHeader() msg.addDefaultHeader()
m.checkUserAgent() msg.checkUserAgent()
mw.writeGenHeader(m) mw.writeGenHeader(msg)
mw.writePreformattedGenHeader(m) mw.writePreformattedGenHeader(msg)
// Set the FROM header (or envelope FROM if FROM is empty) // Set the FROM header (or envelope FROM if FROM is empty)
hf := true hasFrom := true
f, ok := m.addrHeader[HeaderFrom] from, ok := msg.addrHeader[HeaderFrom]
if !ok || (len(f) == 0 || f == nil) { if !ok || (len(from) == 0 || from == nil) {
f, ok = m.addrHeader[HeaderEnvelopeFrom] from, ok = msg.addrHeader[HeaderEnvelopeFrom]
if !ok || (len(f) == 0 || f == nil) { if !ok || (len(from) == 0 || from == nil) {
hf = false hasFrom = false
} }
} }
if hf && (len(f) > 0 && f[0] != nil) { if hasFrom && (len(from) > 0 && from[0] != nil) {
mw.writeHeader(Header(HeaderFrom), f[0].String()) mw.writeHeader(Header(HeaderFrom), from[0].String())
} }
// Set the rest of the address headers // Set the rest of the address headers
for _, t := range []AddrHeader{HeaderTo, HeaderCc} { for _, to := range []AddrHeader{HeaderTo, HeaderCc} {
if al, ok := m.addrHeader[t]; ok { if addresses, ok := msg.addrHeader[to]; ok {
var v []string var val []string
for _, a := range al { for _, addr := range addresses {
v = append(v, a.String()) val = append(val, addr.String())
} }
mw.writeHeader(Header(t), v...) mw.writeHeader(Header(to), val...)
} }
} }
if m.hasMixed() { if msg.hasMixed() {
mw.startMP("mixed", m.boundary) mw.startMP("mixed", msg.boundary)
mw.writeString(DoubleNewLine) mw.writeString(DoubleNewLine)
} }
if m.hasRelated() { if msg.hasRelated() {
mw.startMP("related", m.boundary) mw.startMP("related", msg.boundary)
mw.writeString(DoubleNewLine) mw.writeString(DoubleNewLine)
} }
if m.hasAlt() { if msg.hasAlt() {
mw.startMP(MIMEAlternative, m.boundary) mw.startMP(MIMEAlternative, msg.boundary)
mw.writeString(DoubleNewLine) mw.writeString(DoubleNewLine)
} }
if m.hasPGPType() { if msg.hasPGPType() {
switch m.pgptype { switch msg.pgptype {
case PGPEncrypt: case PGPEncrypt:
mw.startMP(`encrypted; protocol="application/pgp-encrypted"`, m.boundary) mw.startMP(`encrypted; protocol="application/pgp-encrypted"`,
msg.boundary)
case PGPSignature: case PGPSignature:
mw.startMP(`signed; protocol="application/pgp-signature";`, m.boundary) mw.startMP(`signed; protocol="application/pgp-signature";`,
msg.boundary)
default:
} }
mw.writeString(DoubleNewLine) mw.writeString(DoubleNewLine)
} }
for _, p := range m.parts { for _, part := range msg.parts {
if !p.del { if !part.isDeleted {
mw.writePart(p, m.charset) mw.writePart(part, msg.charset)
} }
} }
if m.hasAlt() { if msg.hasAlt() {
mw.stopMP() mw.stopMP()
} }
// Add embeds // Add embeds
mw.addFiles(m.embeds, false) mw.addFiles(msg.embeds, false)
if m.hasRelated() { if msg.hasRelated() {
mw.stopMP() mw.stopMP()
} }
// Add attachments // Add attachments
mw.addFiles(m.attachments, true) mw.addFiles(msg.attachments, true)
if m.hasMixed() { if msg.hasMixed() {
mw.stopMP() mw.stopMP()
} }
} }
// writeGenHeader writes out all generic headers to the msgWriter // writeGenHeader writes out all generic headers to the msgWriter
func (mw *msgWriter) writeGenHeader(m *Msg) { func (mw *msgWriter) writeGenHeader(msg *Msg) {
gk := make([]string, 0, len(m.genHeader)) keys := make([]string, 0, len(msg.genHeader))
for k := range m.genHeader { for key := range msg.genHeader {
gk = append(gk, string(k)) keys = append(keys, string(key))
} }
sort.Strings(gk) sort.Strings(keys)
for _, k := range gk { for _, key := range keys {
mw.writeHeader(Header(k), m.genHeader[Header(k)]...) mw.writeHeader(Header(key), msg.genHeader[Header(key)]...)
} }
} }
// writePreformatedHeader writes out all preformated generic headers to the msgWriter // writePreformatedHeader writes out all preformated generic headers to the msgWriter
func (mw *msgWriter) writePreformattedGenHeader(m *Msg) { func (mw *msgWriter) writePreformattedGenHeader(msg *Msg) {
for k, v := range m.preformHeader { for key, val := range msg.preformHeader {
mw.writeString(fmt.Sprintf("%s: %s%s", k, v, SingleNewLine)) mw.writeString(fmt.Sprintf("%s: %s%s", key, val, SingleNewLine))
} }
} }
// startMP writes a multipart beginning // startMP writes a multipart beginning
func (mw *msgWriter) startMP(mt MIMEType, b string) { func (mw *msgWriter) startMP(mimeType MIMEType, boundary string) {
mp := multipart.NewWriter(mw) multiPartWriter := multipart.NewWriter(mw)
if b != "" { if boundary != "" {
mw.err = mp.SetBoundary(b) mw.err = multiPartWriter.SetBoundary(boundary)
} }
ct := fmt.Sprintf("multipart/%s;\r\n boundary=%s", mt, mp.Boundary()) contentType := fmt.Sprintf("multipart/%s;\r\n boundary=%s", mimeType,
mw.mpw[mw.d] = mp multiPartWriter.Boundary())
mw.multiPartWriter[mw.depth] = multiPartWriter
if mw.d == 0 { if mw.depth == 0 {
mw.writeString(fmt.Sprintf("%s: %s", HeaderContentType, ct)) mw.writeString(fmt.Sprintf("%s: %s", HeaderContentType, contentType))
} }
if mw.d > 0 { if mw.depth > 0 {
mw.newPart(map[string][]string{"Content-Type": {ct}}) mw.newPart(map[string][]string{"Content-Type": {contentType}})
} }
mw.d++ mw.depth++
} }
// stopMP closes the multipart // stopMP closes the multipart
func (mw *msgWriter) stopMP() { func (mw *msgWriter) stopMP() {
if mw.d > 0 { if mw.depth > 0 {
mw.err = mw.mpw[mw.d-1].Close() mw.err = mw.multiPartWriter[mw.depth-1].Close()
mw.d-- mw.depth--
} }
} }
// addFiles adds the attachments/embeds file content to the mail body // addFiles adds the attachments/embeds file content to the mail body
func (mw *msgWriter) addFiles(fl []*File, a bool) { func (mw *msgWriter) addFiles(files []*File, isAttachment bool) {
for _, f := range fl { for _, file := range files {
e := EncodingB64 encoding := EncodingB64
if _, ok := f.getHeader(HeaderContentType); !ok { if _, ok := file.getHeader(HeaderContentType); !ok {
mt := mime.TypeByExtension(filepath.Ext(f.Name)) mimeType := mime.TypeByExtension(filepath.Ext(file.Name))
if mt == "" { if mimeType == "" {
mt = "application/octet-stream" mimeType = "application/octet-stream"
} }
if f.ContentType != "" { if file.ContentType != "" {
mt = string(f.ContentType) mimeType = string(file.ContentType)
} }
f.setHeader(HeaderContentType, fmt.Sprintf(`%s; name="%s"`, mt, file.setHeader(HeaderContentType, fmt.Sprintf(`%s; name="%s"`, mimeType,
mw.en.Encode(mw.c.String(), f.Name))) mw.encoder.Encode(mw.charset.String(), file.Name)))
} }
if _, ok := f.getHeader(HeaderContentTransferEnc); !ok { if _, ok := file.getHeader(HeaderContentTransferEnc); !ok {
if f.Enc != "" { if file.Enc != "" {
e = f.Enc encoding = file.Enc
} }
f.setHeader(HeaderContentTransferEnc, string(e)) file.setHeader(HeaderContentTransferEnc, string(encoding))
} }
if f.Desc != "" { if file.Desc != "" {
if _, ok := f.getHeader(HeaderContentDescription); !ok { if _, ok := file.getHeader(HeaderContentDescription); !ok {
f.setHeader(HeaderContentDescription, f.Desc) file.setHeader(HeaderContentDescription, file.Desc)
} }
} }
if _, ok := f.getHeader(HeaderContentDisposition); !ok { if _, ok := file.getHeader(HeaderContentDisposition); !ok {
d := "inline" disposition := "inline"
if a { if isAttachment {
d = "attachment" disposition = "attachment"
} }
f.setHeader(HeaderContentDisposition, fmt.Sprintf(`%s; filename="%s"`, d, file.setHeader(HeaderContentDisposition, fmt.Sprintf(`%s; filename="%s"`,
mw.en.Encode(mw.c.String(), f.Name))) disposition, mw.encoder.Encode(mw.charset.String(), file.Name)))
} }
if !a { if !isAttachment {
if _, ok := f.getHeader(HeaderContentID); !ok { if _, ok := file.getHeader(HeaderContentID); !ok {
f.setHeader(HeaderContentID, fmt.Sprintf("<%s>", f.Name)) file.setHeader(HeaderContentID, fmt.Sprintf("<%s>", file.Name))
} }
} }
if mw.d == 0 { if mw.depth == 0 {
for h, v := range f.Header { for header, val := range file.Header {
mw.writeHeader(Header(h), v...) mw.writeHeader(Header(header), val...)
} }
mw.writeString(SingleNewLine) mw.writeString(SingleNewLine)
} }
if mw.d > 0 { if mw.depth > 0 {
mw.newPart(f.Header) mw.newPart(file.Header)
} }
if mw.err == nil { if mw.err == nil {
mw.writeBody(f.Writer, e) mw.writeBody(file.Writer, encoding)
} }
} }
} }
// newPart creates a new MIME multipart io.Writer and sets the partwriter to it // newPart creates a new MIME multipart io.Writer and sets the partwriter to it
func (mw *msgWriter) newPart(h map[string][]string) { func (mw *msgWriter) newPart(header map[string][]string) {
mw.pw, mw.err = mw.mpw[mw.d-1].CreatePart(h) mw.partWriter, mw.err = mw.multiPartWriter[mw.depth-1].CreatePart(header)
} }
// writePart writes the corresponding part to the Msg body // writePart writes the corresponding part to the Msg body
func (mw *msgWriter) writePart(p *Part, cs Charset) { func (mw *msgWriter) writePart(part *Part, charset Charset) {
pcs := p.cset partCharset := part.charset
if pcs.String() == "" { if partCharset.String() == "" {
pcs = cs partCharset = charset
} }
ct := fmt.Sprintf("%s; charset=%s", p.ctype, pcs) contentType := fmt.Sprintf("%s; charset=%s", part.contentType, partCharset)
cte := p.enc.String() contentTransferEnc := part.encoding.String()
if mw.d == 0 { if mw.depth == 0 {
mw.writeHeader(HeaderContentType, ct) mw.writeHeader(HeaderContentType, contentType)
mw.writeHeader(HeaderContentTransferEnc, cte) mw.writeHeader(HeaderContentTransferEnc, contentTransferEnc)
mw.writeString(SingleNewLine) mw.writeString(SingleNewLine)
} }
if mw.d > 0 { if mw.depth > 0 {
mh := textproto.MIMEHeader{} mimeHeader := textproto.MIMEHeader{}
if p.desc != "" { if part.description != "" {
mh.Add(string(HeaderContentDescription), p.desc) mimeHeader.Add(string(HeaderContentDescription), part.description)
} }
mh.Add(string(HeaderContentType), ct) mimeHeader.Add(string(HeaderContentType), contentType)
mh.Add(string(HeaderContentTransferEnc), cte) mimeHeader.Add(string(HeaderContentTransferEnc), contentTransferEnc)
mw.newPart(mh) mw.newPart(mimeHeader)
} }
mw.writeBody(p.w, p.enc) mw.writeBody(part.writeFunc, part.encoding)
} }
// writeString writes a string into the msgWriter's io.Writer interface // writeString writes a string into the msgWriter's io.Writer interface
@ -274,102 +278,103 @@ func (mw *msgWriter) writeString(s string) {
return return
} }
var n int var n int
n, mw.err = io.WriteString(mw.w, s) n, mw.err = io.WriteString(mw.writer, s)
mw.n += int64(n) mw.bytesWritten += int64(n)
} }
// writeHeader writes a header into the msgWriter's io.Writer // writeHeader writes a header into the msgWriter's io.Writer
func (mw *msgWriter) writeHeader(k Header, vl ...string) { func (mw *msgWriter) writeHeader(key Header, values ...string) {
wbuf := bytes.Buffer{} buffer := strings.Builder{}
cl := MaxHeaderLength - 2 charLength := MaxHeaderLength - 2
wbuf.WriteString(string(k)) buffer.WriteString(string(key))
cl -= len(k) charLength -= len(key)
if len(vl) == 0 { if len(values) == 0 {
wbuf.WriteString(":\r\n") buffer.WriteString(":\r\n")
return return
} }
wbuf.WriteString(": ") buffer.WriteString(": ")
cl -= 2 charLength -= 2
fs := strings.Join(vl, ", ") fullValueStr := strings.Join(values, ", ")
sfs := strings.Split(fs, " ") words := strings.Split(fullValueStr, " ")
for i, v := range sfs { for i, val := range words {
if cl-len(v) <= 1 { if charLength-len(val) <= 1 {
wbuf.WriteString(fmt.Sprintf("%s ", SingleNewLine)) buffer.WriteString(fmt.Sprintf("%s ", SingleNewLine))
cl = MaxHeaderLength - 3 charLength = MaxHeaderLength - 3
} }
wbuf.WriteString(v) buffer.WriteString(val)
if i < len(sfs)-1 { if i < len(words)-1 {
wbuf.WriteString(" ") buffer.WriteString(" ")
cl -= 1 charLength -= 1
} }
cl -= len(v) charLength -= len(val)
} }
bufs := wbuf.String() bufferString := buffer.String()
bufs = strings.ReplaceAll(bufs, fmt.Sprintf(" %s", SingleNewLine), SingleNewLine) bufferString = strings.ReplaceAll(bufferString, fmt.Sprintf(" %s", SingleNewLine),
mw.writeString(bufs) SingleNewLine)
mw.writeString(bufferString)
mw.writeString("\r\n") mw.writeString("\r\n")
} }
// writeBody writes an io.Reader into an io.Writer using provided Encoding // writeBody writes an io.Reader into an io.Writer using provided Encoding
func (mw *msgWriter) writeBody(f func(io.Writer) (int64, error), e Encoding) { func (mw *msgWriter) writeBody(writeFunc func(io.Writer) (int64, error), encoding Encoding) {
var w io.Writer var writer io.Writer
var ew io.WriteCloser var encodedWriter io.WriteCloser
var n int64 var n int64
var err error var err error
if mw.d == 0 { if mw.depth == 0 {
w = mw.w writer = mw.writer
} }
if mw.d > 0 { if mw.depth > 0 {
w = mw.pw writer = mw.partWriter
} }
wbuf := bytes.Buffer{} writeBuffer := bytes.Buffer{}
lb := Base64LineBreaker{} lineBreaker := Base64LineBreaker{}
lb.out = &wbuf lineBreaker.out = &writeBuffer
switch e { switch encoding {
case EncodingQP: case EncodingQP:
ew = quotedprintable.NewWriter(&wbuf) encodedWriter = quotedprintable.NewWriter(&writeBuffer)
case EncodingB64: case EncodingB64:
ew = base64.NewEncoder(base64.StdEncoding, &lb) encodedWriter = base64.NewEncoder(base64.StdEncoding, &lineBreaker)
case NoEncoding: case NoEncoding:
_, err = f(&wbuf) _, err = writeFunc(&writeBuffer)
if err != nil { if err != nil {
mw.err = fmt.Errorf("bodyWriter function: %w", err) mw.err = fmt.Errorf("bodyWriter function: %w", err)
} }
n, err = io.Copy(w, &wbuf) n, err = io.Copy(writer, &writeBuffer)
if err != nil && mw.err == nil { if err != nil && mw.err == nil {
mw.err = fmt.Errorf("bodyWriter io.Copy: %w", err) mw.err = fmt.Errorf("bodyWriter io.Copy: %w", err)
} }
if mw.d == 0 { if mw.depth == 0 {
mw.n += n mw.bytesWritten += n
} }
return return
default: default:
ew = quotedprintable.NewWriter(w) encodedWriter = quotedprintable.NewWriter(writer)
} }
_, err = f(ew) _, err = writeFunc(encodedWriter)
if err != nil { if err != nil {
mw.err = fmt.Errorf("bodyWriter function: %w", err) mw.err = fmt.Errorf("bodyWriter function: %w", err)
} }
err = ew.Close() err = encodedWriter.Close()
if err != nil && mw.err == nil { if err != nil && mw.err == nil {
mw.err = fmt.Errorf("bodyWriter close encoded writer: %w", err) mw.err = fmt.Errorf("bodyWriter close encoded writer: %w", err)
} }
err = lb.Close() err = lineBreaker.Close()
if err != nil && mw.err == nil { if err != nil && mw.err == nil {
mw.err = fmt.Errorf("bodyWriter close linebreaker: %w", err) mw.err = fmt.Errorf("bodyWriter close linebreaker: %w", err)
} }
n, err = io.Copy(w, &wbuf) n, err = io.Copy(writer, &writeBuffer)
if err != nil && mw.err == nil { if err != nil && mw.err == nil {
mw.err = fmt.Errorf("bodyWriter io.Copy: %w", err) mw.err = fmt.Errorf("bodyWriter io.Copy: %w", err)
} }
// Since the part writer uses the WriteTo() method, we don't need to add the // Since the part writer uses the WriteTo() method, we don't need to add the
// bytes twice // bytes twice
if mw.d == 0 { if mw.depth == 0 {
mw.n += n mw.bytesWritten += n
} }
} }

View file

@ -28,7 +28,7 @@ func (bw *brokenWriter) Write([]byte) (int, error) {
// TestMsgWriter_Write tests the WriteTo() method of the msgWriter // TestMsgWriter_Write tests the WriteTo() method of the msgWriter
func TestMsgWriter_Write(t *testing.T) { func TestMsgWriter_Write(t *testing.T) {
bw := &brokenWriter{} bw := &brokenWriter{}
mw := &msgWriter{w: bw, c: CharsetUTF8, en: mime.QEncoding} mw := &msgWriter{writer: bw, charset: CharsetUTF8, encoder: mime.QEncoding}
_, err := mw.Write([]byte("test")) _, err := mw.Write([]byte("test"))
if err == nil { if err == nil {
t.Errorf("msgWriter WriteTo() with brokenWriter should fail, but didn't") t.Errorf("msgWriter WriteTo() with brokenWriter should fail, but didn't")
@ -55,7 +55,7 @@ func TestMsgWriter_writeMsg(t *testing.T) {
m.SetBodyString(TypeTextPlain, "This is the body") m.SetBodyString(TypeTextPlain, "This is the body")
m.AddAlternativeString(TypeTextHTML, "This is the alternative body") m.AddAlternativeString(TypeTextHTML, "This is the alternative body")
buf := bytes.Buffer{} buf := bytes.Buffer{}
mw := &msgWriter{w: &buf, c: CharsetUTF8, en: mime.QEncoding} mw := &msgWriter{writer: &buf, charset: CharsetUTF8, encoder: mime.QEncoding}
mw.writeMsg(m) mw.writeMsg(m)
ms := buf.String() ms := buf.String()
@ -134,7 +134,7 @@ func TestMsgWriter_writeMsg_PGP(t *testing.T) {
m.Subject("This is a subject") m.Subject("This is a subject")
m.SetBodyString(TypeTextPlain, "This is the body") m.SetBodyString(TypeTextPlain, "This is the body")
buf := bytes.Buffer{} buf := bytes.Buffer{}
mw := &msgWriter{w: &buf, c: CharsetUTF8, en: mime.QEncoding} mw := &msgWriter{writer: &buf, charset: CharsetUTF8, encoder: mime.QEncoding}
mw.writeMsg(m) mw.writeMsg(m)
ms := buf.String() ms := buf.String()
if !strings.Contains(ms, `encrypted; protocol="application/pgp-encrypted"`) { if !strings.Contains(ms, `encrypted; protocol="application/pgp-encrypted"`) {
@ -147,7 +147,7 @@ func TestMsgWriter_writeMsg_PGP(t *testing.T) {
m.Subject("This is a subject") m.Subject("This is a subject")
m.SetBodyString(TypeTextPlain, "This is the body") m.SetBodyString(TypeTextPlain, "This is the body")
buf = bytes.Buffer{} buf = bytes.Buffer{}
mw = &msgWriter{w: &buf, c: CharsetUTF8, en: mime.QEncoding} mw = &msgWriter{writer: &buf, charset: CharsetUTF8, encoder: mime.QEncoding}
mw.writeMsg(m) mw.writeMsg(m)
ms = buf.String() ms = buf.String()
if !strings.Contains(ms, `signed; protocol="application/pgp-signature"`) { if !strings.Contains(ms, `signed; protocol="application/pgp-signature"`) {

66
part.go
View file

@ -14,18 +14,18 @@ type PartOption func(*Part)
// Part is a part of the Msg // Part is a part of the Msg
type Part struct { type Part struct {
ctype ContentType contentType ContentType
cset Charset charset Charset
desc string description string
enc Encoding encoding Encoding
del bool isDeleted bool
w func(io.Writer) (int64, error) writeFunc func(io.Writer) (int64, error)
} }
// GetContent executes the WriteFunc of the Part and returns the content as byte slice // GetContent executes the WriteFunc of the Part and returns the content as byte slice
func (p *Part) GetContent() ([]byte, error) { func (p *Part) GetContent() ([]byte, error) {
var b bytes.Buffer var b bytes.Buffer
if _, err := p.w(&b); err != nil { if _, err := p.writeFunc(&b); err != nil {
return nil, err return nil, err
} }
return b.Bytes(), nil return b.Bytes(), nil
@ -33,83 +33,83 @@ func (p *Part) GetContent() ([]byte, error) {
// GetCharset returns the currently set Charset of the Part // GetCharset returns the currently set Charset of the Part
func (p *Part) GetCharset() Charset { func (p *Part) GetCharset() Charset {
return p.cset return p.charset
} }
// GetContentType returns the currently set ContentType of the Part // GetContentType returns the currently set ContentType of the Part
func (p *Part) GetContentType() ContentType { func (p *Part) GetContentType() ContentType {
return p.ctype return p.contentType
} }
// GetEncoding returns the currently set Encoding of the Part // GetEncoding returns the currently set Encoding of the Part
func (p *Part) GetEncoding() Encoding { func (p *Part) GetEncoding() Encoding {
return p.enc return p.encoding
} }
// GetWriteFunc returns the currently set WriterFunc of the Part // GetWriteFunc returns the currently set WriterFunc of the Part
func (p *Part) GetWriteFunc() func(io.Writer) (int64, error) { func (p *Part) GetWriteFunc() func(io.Writer) (int64, error) {
return p.w return p.writeFunc
} }
// GetDescription returns the currently set Content-Description of the Part // GetDescription returns the currently set Content-Description of the Part
func (p *Part) GetDescription() string { func (p *Part) GetDescription() string {
return p.desc return p.description
} }
// SetContent overrides the content of the Part with the given string // SetContent overrides the content of the Part with the given string
func (p *Part) SetContent(c string) { func (p *Part) SetContent(content string) {
buf := bytes.NewBufferString(c) buffer := bytes.NewBufferString(content)
p.w = writeFuncFromBuffer(buf) p.writeFunc = writeFuncFromBuffer(buffer)
} }
// SetContentType overrides the ContentType of the Part // SetContentType overrides the ContentType of the Part
func (p *Part) SetContentType(c ContentType) { func (p *Part) SetContentType(contentType ContentType) {
p.ctype = c p.contentType = contentType
} }
// SetCharset overrides the Charset of the Part // SetCharset overrides the Charset of the Part
func (p *Part) SetCharset(c Charset) { func (p *Part) SetCharset(charset Charset) {
p.cset = c p.charset = charset
} }
// SetEncoding creates a new mime.WordEncoder based on the encoding setting of the message // SetEncoding creates a new mime.WordEncoder based on the encoding setting of the message
func (p *Part) SetEncoding(e Encoding) { func (p *Part) SetEncoding(encoding Encoding) {
p.enc = e p.encoding = encoding
} }
// SetDescription overrides the Content-Description of the Part // SetDescription overrides the Content-Description of the Part
func (p *Part) SetDescription(d string) { func (p *Part) SetDescription(description string) {
p.desc = d p.description = description
} }
// SetWriteFunc overrides the WriteFunc of the Part // SetWriteFunc overrides the WriteFunc of the Part
func (p *Part) SetWriteFunc(w func(io.Writer) (int64, error)) { func (p *Part) SetWriteFunc(writeFunc func(io.Writer) (int64, error)) {
p.w = w p.writeFunc = writeFunc
} }
// Delete removes the current part from the parts list of the Msg by setting the // Delete removes the current part from the parts list of the Msg by setting the
// del flag to true. The msgWriter will skip it then // isDeleted flag to true. The msgWriter will skip it then
func (p *Part) Delete() { func (p *Part) Delete() {
p.del = true p.isDeleted = true
} }
// WithPartCharset overrides the default Part charset // WithPartCharset overrides the default Part charset
func WithPartCharset(c Charset) PartOption { func WithPartCharset(charset Charset) PartOption {
return func(p *Part) { return func(p *Part) {
p.cset = c p.charset = charset
} }
} }
// WithPartEncoding overrides the default Part encoding // WithPartEncoding overrides the default Part encoding
func WithPartEncoding(e Encoding) PartOption { func WithPartEncoding(encoding Encoding) PartOption {
return func(p *Part) { return func(p *Part) {
p.enc = e p.encoding = encoding
} }
} }
// WithPartContentDescription overrides the default Part Content-Description // WithPartContentDescription overrides the default Part Content-Description
func WithPartContentDescription(d string) PartOption { func WithPartContentDescription(description string) PartOption {
return func(p *Part) { return func(p *Part) {
p.desc = d p.description = description
} }
} }

View file

@ -31,15 +31,15 @@ func TestPartEncoding(t *testing.T) {
t.Errorf("newPart() WithPartEncoding() failed: no part returned") t.Errorf("newPart() WithPartEncoding() failed: no part returned")
return return
} }
if part.enc.String() != tt.want { if part.encoding.String() != tt.want {
t.Errorf("newPart() WithPartEncoding() failed: expected encoding: %s, got: %s", tt.want, t.Errorf("newPart() WithPartEncoding() failed: expected encoding: %s, got: %s", tt.want,
part.enc.String()) part.encoding.String())
} }
part.enc = "" part.encoding = ""
part.SetEncoding(tt.enc) part.SetEncoding(tt.enc)
if part.enc.String() != tt.want { if part.encoding.String() != tt.want {
t.Errorf("newPart() SetEncoding() failed: expected encoding: %s, got: %s", tt.want, t.Errorf("newPart() SetEncoding() failed: expected encoding: %s, got: %s", tt.want,
part.enc.String()) part.encoding.String())
} }
}) })
} }
@ -64,9 +64,9 @@ func TestWithPartCharset(t *testing.T) {
t.Errorf("newPart() WithPartCharset() failed: no part returned") t.Errorf("newPart() WithPartCharset() failed: no part returned")
return return
} }
if part.cset.String() != tt.want { if part.charset.String() != tt.want {
t.Errorf("newPart() WithPartCharset() failed: expected charset: %s, got: %s", t.Errorf("newPart() WithPartCharset() failed: expected charset: %s, got: %s",
tt.want, part.cset.String()) tt.want, part.charset.String())
} }
}) })
} }
@ -89,14 +89,14 @@ func TestPart_WithPartContentDescription(t *testing.T) {
t.Errorf("newPart() WithPartContentDescription() failed: no part returned") t.Errorf("newPart() WithPartContentDescription() failed: no part returned")
return return
} }
if part.desc != tt.desc { if part.description != tt.desc {
t.Errorf("newPart() WithPartContentDescription() failed: expected: %s, got: %s", tt.desc, t.Errorf("newPart() WithPartContentDescription() failed: expected: %s, got: %s", tt.desc,
part.desc) part.description)
} }
part.desc = "" part.description = ""
part.SetDescription(tt.desc) part.SetDescription(tt.desc)
if part.desc != tt.desc { if part.description != tt.desc {
t.Errorf("newPart() SetDescription() failed: expected: %s, got: %s", tt.desc, part.desc) t.Errorf("newPart() SetDescription() failed: expected: %s, got: %s", tt.desc, part.description)
} }
}) })
} }
@ -236,7 +236,7 @@ func TestPart_GetContentBroken(t *testing.T) {
t.Errorf("failed: %s", err) t.Errorf("failed: %s", err)
return return
} }
pl[0].w = func(io.Writer) (int64, error) { pl[0].writeFunc = func(io.Writer) (int64, error) {
return 0, fmt.Errorf("broken") return 0, fmt.Errorf("broken")
} }
_, err = pl[0].GetContent() _, err = pl[0].GetContent()
@ -314,8 +314,8 @@ func TestPart_SetDescription(t *testing.T) {
t.Errorf("Part.GetDescription failed. Expected empty description but got: %s", pd) t.Errorf("Part.GetDescription failed. Expected empty description but got: %s", pd)
} }
pl[0].SetDescription(d) pl[0].SetDescription(d)
if pl[0].desc != d { if pl[0].description != d {
t.Errorf("Part.SetDescription failed. Expected desc to be: %s, got: %s", d, pd) t.Errorf("Part.SetDescription failed. Expected description to be: %s, got: %s", d, pd)
} }
pd = pl[0].GetDescription() pd = pl[0].GetDescription()
if pd != d { if pd != d {
@ -334,8 +334,8 @@ func TestPart_Delete(t *testing.T) {
return return
} }
pl[0].Delete() pl[0].Delete()
if !pl[0].del { if !pl[0].isDeleted {
t.Errorf("Delete failed. Expected: %t, got: %t", true, pl[0].del) t.Errorf("Delete failed. Expected: %t, got: %t", true, pl[0].isDeleted)
} }
} }

View file

@ -22,54 +22,53 @@ const (
letterIdxMax = 63 / letterIdxBits // # of letter indices fitting in 63 bits letterIdxMax = 63 / letterIdxBits // # of letter indices fitting in 63 bits
) )
// randomStringSecure returns a random, n long string of characters. The character set is based // randomStringSecure returns a random, string of length characters. This method uses the
// on the s (special chars) and h (human readable) boolean arguments. This method uses the
// crypto/random package and therfore is cryptographically secure // crypto/random package and therfore is cryptographically secure
func randomStringSecure(n int) (string, error) { func randomStringSecure(length int) (string, error) {
rs := strings.Builder{} randString := strings.Builder{}
rs.Grow(n) randString.Grow(length)
crl := len(cr) charRangeLength := len(cr)
rp := make([]byte, 8) randPool := make([]byte, 8)
_, err := rand.Read(rp) _, err := rand.Read(randPool)
if err != nil { if err != nil {
return rs.String(), err return randString.String(), err
} }
for i, c, r := n-1, binary.BigEndian.Uint64(rp), letterIdxMax; i >= 0; { for idx, char, rest := length-1, binary.BigEndian.Uint64(randPool), letterIdxMax; idx >= 0; {
if r == 0 { if rest == 0 {
_, err := rand.Read(rp) _, err = rand.Read(randPool)
if err != nil { if err != nil {
return rs.String(), err return randString.String(), err
} }
c, r = binary.BigEndian.Uint64(rp), letterIdxMax char, rest = binary.BigEndian.Uint64(randPool), letterIdxMax
} }
if idx := int(c & letterIdxMask); idx < crl { if i := int(char & letterIdxMask); i < charRangeLength {
rs.WriteByte(cr[idx]) randString.WriteByte(cr[i])
i-- idx--
} }
c >>= letterIdxBits char >>= letterIdxBits
r-- rest--
} }
return rs.String(), nil return randString.String(), nil
} }
// randNum returns a random number with a maximum value of n // randNum returns a random number with a maximum value of length
func randNum(n int) (int, error) { func randNum(length int) (int, error) {
if n <= 0 { if length <= 0 {
return 0, fmt.Errorf("provided number is <= 0: %d", n) return 0, fmt.Errorf("provided number is <= 0: %d", length)
} }
mbi := big.NewInt(int64(n)) length64 := big.NewInt(int64(length))
if !mbi.IsUint64() { if !length64.IsUint64() {
return 0, fmt.Errorf("big.NewInt() generation returned negative value: %d", mbi) return 0, fmt.Errorf("big.NewInt() generation returned negative value: %d", length64)
} }
rn64, err := rand.Int(rand.Reader, mbi) randNum64, err := rand.Int(rand.Reader, length64)
if err != nil { if err != nil {
return 0, err return 0, err
} }
rn := int(rn64.Int64()) randomNum := int(randNum64.Int64())
if rn < 0 { if randomNum < 0 {
return 0, fmt.Errorf("generated random number does not fit as int64: %d", rn64) return 0, fmt.Errorf("generated random number does not fit as int64: %d", randNum64)
} }
return rn, nil return randomNum, nil
} }

View file

@ -10,8 +10,8 @@ import (
// Reader is a type that implements the io.Reader interface for a Msg // Reader is a type that implements the io.Reader interface for a Msg
type Reader struct { type Reader struct {
buf []byte // contents are the bytes buf[off : len(buf)] buffer []byte // contents are the bytes buffer[offset : len(buffer)]
off int // read at &buf[off], write at &buf[len(buf)] offset int // read at &buffer[offset], write at &buffer[len(buffer)]
err error // initialization error err error // initialization error
} }
@ -21,28 +21,28 @@ func (r *Reader) Error() error {
} }
// Read reads the length of p of the Msg buffer to satisfy the io.Reader interface // Read reads the length of p of the Msg buffer to satisfy the io.Reader interface
func (r *Reader) Read(p []byte) (n int, err error) { func (r *Reader) Read(payload []byte) (n int, err error) {
if r.err != nil { if r.err != nil {
return 0, r.err return 0, r.err
} }
if r.empty() || r.buf == nil { if r.empty() || r.buffer == nil {
r.Reset() r.Reset()
if len(p) == 0 { if len(payload) == 0 {
return 0, nil return 0, nil
} }
return 0, io.EOF return 0, io.EOF
} }
n = copy(p, r.buf[r.off:]) n = copy(payload, r.buffer[r.offset:])
r.off += n r.offset += n
return n, err return n, err
} }
// Reset resets the Reader buffer to be empty, but it retains the underlying storage // Reset resets the Reader buffer to be empty, but it retains the underlying storage
// for use by future writes. // for use by future writes.
func (r *Reader) Reset() { func (r *Reader) Reset() {
r.buf = r.buf[:0] r.buffer = r.buffer[:0]
r.off = 0 r.offset = 0
} }
// empty reports whether the unread portion of the Reader buffer is empty. // empty reports whether the unread portion of the Reader buffer is empty.
func (r *Reader) empty() bool { return len(r.buf) <= r.off } func (r *Reader) empty() bool { return len(r.buffer) <= r.offset }

View file

@ -65,7 +65,7 @@ func TestReader_Read_error(t *testing.T) {
// TestReader_Read_empty tests the Reader.Read method with an empty buffer // TestReader_Read_empty tests the Reader.Read method with an empty buffer
func TestReader_Read_empty(t *testing.T) { func TestReader_Read_empty(t *testing.T) {
r := Reader{buf: []byte{}} r := Reader{buffer: []byte{}}
p := make([]byte, 1) p := make([]byte, 1)
p[0] = 'a' p[0] = 'a'
_, err := r.Read(p) _, err := r.Read(p)
@ -76,7 +76,7 @@ func TestReader_Read_empty(t *testing.T) {
// TestReader_Read_nil tests the Reader.Read method with a nil buffer // TestReader_Read_nil tests the Reader.Read method with a nil buffer
func TestReader_Read_nil(t *testing.T) { func TestReader_Read_nil(t *testing.T) {
r := Reader{buf: nil, off: -10} r := Reader{buffer: nil, offset: -10}
p := make([]byte, 0) p := make([]byte, 0)
_, err := r.Read(p) _, err := r.Read(p)
if err != nil && !errors.Is(err, io.EOF) { if err != nil && !errors.Is(err, io.EOF) {

View file

@ -71,34 +71,34 @@ func (e *SendError) Error() string {
return "unknown reason" return "unknown reason"
} }
var em strings.Builder var errMessage strings.Builder
em.WriteString(e.Reason.String()) errMessage.WriteString(e.Reason.String())
if len(e.errlist) > 0 { if len(e.errlist) > 0 {
em.WriteRune(':') errMessage.WriteRune(':')
for i := range e.errlist { for i := range e.errlist {
em.WriteRune(' ') errMessage.WriteRune(' ')
em.WriteString(e.errlist[i].Error()) errMessage.WriteString(e.errlist[i].Error())
if i != len(e.errlist)-1 { if i != len(e.errlist)-1 {
em.WriteString(", ") errMessage.WriteString(", ")
} }
} }
} }
if len(e.rcpt) > 0 { if len(e.rcpt) > 0 {
em.WriteString(", affected recipient(s): ") errMessage.WriteString(", affected recipient(s): ")
for i := range e.rcpt { for i := range e.rcpt {
em.WriteString(e.rcpt[i]) errMessage.WriteString(e.rcpt[i])
if i != len(e.rcpt)-1 { if i != len(e.rcpt)-1 {
em.WriteString(", ") errMessage.WriteString(", ")
} }
} }
} }
return em.String() return errMessage.String()
} }
// Is implements the errors.Is functionality and compares the SendErrReason // Is implements the errors.Is functionality and compares the SendErrReason
func (e *SendError) Is(et error) bool { func (e *SendError) Is(errType error) bool {
var t *SendError var t *SendError
if errors.As(et, &t) && t != nil { if errors.As(errType, &t) && t != nil {
return e.Reason == t.Reason && e.isTemp == t.isTemp return e.Reason == t.Reason && e.isTemp == t.isTemp
} }
return false return false
@ -143,6 +143,6 @@ func (r SendErrReason) String() string {
// isTempError checks the given SMTP error and returns true if the given error is of temporary nature // isTempError checks the given SMTP error and returns true if the given error is of temporary nature
// and should be retried // and should be retried
func isTempError(e error) bool { func isTempError(err error) bool {
return e.Error()[0] == '4' return err.Error()[0] == '4'
} }

View file

@ -777,9 +777,9 @@ func TestClient_SetLogger(t *testing.T) {
if c.logger == nil { if c.logger == nil {
t.Errorf("Expected Logger to be set but received nil") t.Errorf("Expected Logger to be set but received nil")
} }
c.logger.Debugf(log.Log{Direction: log.DirServerToClient, Format: "", Messages: []interface{}{"test"}}) c.logger.Debugf(log.Log{Direction: log.DirServerToClient, Format: "%s", Messages: []interface{}{"test"}})
c.SetLogger(nil) c.SetLogger(nil)
c.logger.Debugf(log.Log{Direction: log.DirServerToClient, Format: "", Messages: []interface{}{"test"}}) c.logger.Debugf(log.Log{Direction: log.DirServerToClient, Format: "%s", Messages: []interface{}{"test"}})
} }
var newClientServer = `220 hello world var newClientServer = `220 hello world

8
tls.go
View file

@ -8,17 +8,17 @@ package mail
type TLSPolicy int type TLSPolicy int
const ( const (
// TLSMandatory requires that the connection cto the server is // TLSMandatory requires that the connection to the server is
// encrypting using STARTTLS. If the server does not support STARTTLS // encrypting using STARTTLS. If the server does not support STARTTLS
// the connection will be terminated with an error // the connection will be terminated with an error
TLSMandatory TLSPolicy = iota TLSMandatory TLSPolicy = iota
// TLSOpportunistic tries cto establish an encrypted connection via the // TLSOpportunistic tries to establish an encrypted connection via the
// STARTTLS protocol. If the server does not support this, it will fall // STARTTLS protocol. If the server does not support this, it will fall
// back cto non-encrypted plaintext transmission // back to non-encrypted plaintext transmission
TLSOpportunistic TLSOpportunistic
// NoTLS forces the transaction cto be not encrypted // NoTLS forces the transaction to be not encrypted
NoTLS NoTLS
) )