diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml index e1400c1..8212de2 100644 --- a/.github/FUNDING.yml +++ b/.github/FUNDING.yml @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: 2022 Winni Neessen # -# SPDX-License-Identifier: CC0-1.0 +# SPDX-License-Identifier: MIT github: wneessen ko_fi: winni diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index aa1cd05..2117d21 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: 2022 Winni Neessen # -# SPDX-License-Identifier: CC0-1.0 +# SPDX-License-Identifier: MIT name: Bug Report description: Create a report to help us improve diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index f6ea6d2..f086a05 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: 2022 Winni Neessen # -# SPDX-License-Identifier: CC0-1.0 +# SPDX-License-Identifier: MIT blank_issues_enabled: false contact_links: diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml index d2dae5e..e44cbd5 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.yml +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: 2022 Winni Neessen # -# SPDX-License-Identifier: CC0-1.0 +# SPDX-License-Identifier: MIT name: Feature request description: Suggest an idea for this project diff --git a/.github/dependabot.yml b/.github/dependabot.yml index c29d473..5c38fa1 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: 2022-2023 The go-mail Authors # -# SPDX-License-Identifier: CC0-1.0 +# SPDX-License-Identifier: MIT version: 2 updates: diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 37e0363..3c8911b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -33,12 +33,14 @@ jobs: PERFORM_ONLINE_TEST: ${{ vars.PERFORM_ONLINE_TEST }} PERFORM_UNIX_OPEN_WRITE_TESTS: ${{ vars.PERFORM_UNIX_OPEN_WRITE_TESTS }} PERFORM_SENDMAIL_TESTS: ${{ vars.PERFORM_SENDMAIL_TESTS }} + TEST_BASEPORT: ${{ vars.TEST_BASEPORT }} + TEST_BASEPORT_SMTP: ${{ vars.TEST_BASEPORT_SMTP }} TEST_HOST: ${{ secrets.TEST_HOST }} TEST_USER: ${{ secrets.TEST_USER }} TEST_PASS: ${{ secrets.TEST_PASS }} steps: - name: Harden Runner - uses: step-security/harden-runner@91182cccc01eb5e619899d80e4e971d6181294a7 # v2.10.1 + uses: step-security/harden-runner@0080882f6c36860b6ba35c610c98ce87d4e2f26f # v2.10.2 with: egress-policy: audit - name: Checkout Code @@ -50,14 +52,14 @@ jobs: check-latest: true - name: Install sendmail run: | - sudo apt-get -y update >/dev/null && sudo apt-get -y upgrade >/dev/null && sudo DEBIAN_FRONTEND=noninteractive apt-get -y install nullmailer >/dev/null && which sendmail + sudo apt-get -y update && sudo DEBIAN_FRONTEND=noninteractive apt-get -y install nullmailer && which sendmail - name: Run go test if: success() run: | go test -race -shuffle=on --coverprofile=coverage.coverprofile --covermode=atomic ./... - name: Upload coverage to Codecov if: success() - uses: codecov/codecov-action@b9fd7d16f6d7d1b5d2bec1a2887e65ceed900238 # v4.6.0 + uses: codecov/codecov-action@7f8b4b4bde536c465e797be725718b88c5d95e0e # v5.1.1 with: token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos lint: @@ -71,7 +73,7 @@ jobs: go: ['1.23'] steps: - name: Harden Runner - uses: step-security/harden-runner@91182cccc01eb5e619899d80e4e971d6181294a7 # v2.10.1 + uses: step-security/harden-runner@0080882f6c36860b6ba35c610c98ce87d4e2f26f # v2.10.2 with: egress-policy: audit - name: Setup go @@ -93,13 +95,13 @@ jobs: cancel-in-progress: true steps: - name: Harden Runner - uses: step-security/harden-runner@91182cccc01eb5e619899d80e4e971d6181294a7 # v2.10.1 + uses: step-security/harden-runner@0080882f6c36860b6ba35c610c98ce87d4e2f26f # v2.10.2 with: egress-policy: audit - name: Checkout Code uses: actions/checkout@61b9e3751b92087fd0b06925ba6dd6314e06f089 # master - name: 'Dependency Review' - uses: actions/dependency-review-action@4081bf99e2866ebe428fc0477b69eb4fcda7220a # v4.4.0 + uses: actions/dependency-review-action@3b139cfc5fae8b618d3eae3675e383bb1769c019 # v4.5.0 with: base-ref: ${{ github.event.pull_request.base.sha || 'main' }} head-ref: ${{ github.event.pull_request.head.sha || github.ref }} @@ -111,7 +113,7 @@ jobs: cancel-in-progress: true steps: - name: Harden Runner - uses: step-security/harden-runner@91182cccc01eb5e619899d80e4e971d6181294a7 # v2.10.1 + uses: step-security/harden-runner@0080882f6c36860b6ba35c610c98ce87d4e2f26f # v2.10.2 with: egress-policy: audit - name: Run govulncheck @@ -126,9 +128,12 @@ jobs: matrix: os: [ubuntu-latest, macos-latest, windows-latest] go: ['1.19', '1.20', '1.21', '1.22', '1.23'] + env: + TEST_BASEPORT: ${{ vars.TEST_BASEPORT }} + TEST_BASEPORT_SMTP: ${{ vars.TEST_BASEPORT_SMTP }} steps: - name: Harden Runner - uses: step-security/harden-runner@91182cccc01eb5e619899d80e4e971d6181294a7 # v2.10.1 + uses: step-security/harden-runner@0080882f6c36860b6ba35c610c98ce87d4e2f26f # v2.10.2 with: egress-policy: audit - name: Checkout Code @@ -149,11 +154,14 @@ jobs: strategy: matrix: osver: ['14.1', '14.0', 13.4'] + env: + TEST_BASEPORT: ${{ vars.TEST_BASEPORT }} + TEST_BASEPORT_SMTP: ${{ vars.TEST_BASEPORT_SMTP }} steps: - name: Checkout Code uses: actions/checkout@61b9e3751b92087fd0b06925ba6dd6314e06f089 # master - name: Run go test on FreeBSD - uses: vmactions/freebsd-vm@v1 + uses: vmactions/freebsd-vm@debf37ca7b7fa40e19c542ef7ba30d6054a706a4 # v1.1.5 with: usesh: true copyback: false @@ -170,13 +178,13 @@ jobs: cancel-in-progress: true steps: - name: Harden Runner - uses: step-security/harden-runner@91182cccc01eb5e619899d80e4e971d6181294a7 # v2.10.1 + uses: step-security/harden-runner@0080882f6c36860b6ba35c610c98ce87d4e2f26f # v2.10.2 with: egress-policy: audit - name: Checkout Code uses: actions/checkout@61b9e3751b92087fd0b06925ba6dd6314e06f089 # master - name: REUSE Compliance Check - uses: fsfe/reuse-action@3ae3c6bdf1257ab19397fab11fd3312144692083 # v4.0.0 + uses: fsfe/reuse-action@bb774aa972c2a89ff34781233d275075cbddf542 # v5.0.0 sonarqube: name: Test with SonarQube review (${{ matrix.os }} / ${{ matrix.go }}) runs-on: ${{ matrix.os }} @@ -189,12 +197,14 @@ jobs: go: ['1.23'] env: PERFORM_ONLINE_TEST: ${{ vars.PERFORM_ONLINE_TEST }} + TEST_BASEPORT: ${{ vars.TEST_BASEPORT }} + TEST_BASEPORT_SMTP: ${{ vars.TEST_BASEPORT_SMTP }} TEST_HOST: ${{ secrets.TEST_HOST }} TEST_USER: ${{ secrets.TEST_USER }} TEST_PASS: ${{ secrets.TEST_PASS }} steps: - name: Harden Runner - uses: step-security/harden-runner@91182cccc01eb5e619899d80e4e971d6181294a7 # v2.10.1 + uses: step-security/harden-runner@0080882f6c36860b6ba35c610c98ce87d4e2f26f # v2.10.2 with: egress-policy: audit - name: Checkout Code @@ -208,7 +218,7 @@ jobs: run: | go test -shuffle=on -race --coverprofile=./cov.out ./... - name: SonarQube scan - uses: sonarsource/sonarqube-scan-action@884b79409bbd464b2a59edc326a4b77dc56b2195 # master + uses: sonarsource/sonarqube-scan-action@1b442ee39ac3fa7c2acdd410208dcb2bcfaae6c4 # master if: success() env: SONAR_TOKEN: ${{ secrets.SONAR_TOKEN }} diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 01b02ec..24194c7 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: 2022 Winni Neessen # -# SPDX-License-Identifier: CC0-1.0 +# SPDX-License-Identifier: MIT # For most projects, this workflow file will not need changing; you simply need # to commit it to your repository. @@ -45,7 +45,7 @@ jobs: steps: - name: Harden Runner - uses: step-security/harden-runner@91182cccc01eb5e619899d80e4e971d6181294a7 # v2.10.1 + uses: step-security/harden-runner@0080882f6c36860b6ba35c610c98ce87d4e2f26f # v2.10.2 with: egress-policy: audit @@ -54,7 +54,7 @@ jobs: # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL - uses: github/codeql-action/init@662472033e021d55d94146f66f6058822b0b39fd # v3.27.0 + uses: github/codeql-action/init@aa578102511db1f4524ed59b8cc2bae4f6e88195 # v3.27.6 with: languages: ${{ matrix.language }} # If you wish to specify custom queries, you can do so here or in a config file. @@ -65,7 +65,7 @@ jobs: # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). # If this step fails, then you should remove it and run the build manually (see below) - name: Autobuild - uses: github/codeql-action/autobuild@662472033e021d55d94146f66f6058822b0b39fd # v3.27.0 + uses: github/codeql-action/autobuild@aa578102511db1f4524ed59b8cc2bae4f6e88195 # v3.27.6 # ℹī¸ Command-line programs to run using the OS shell. # 📚 https://git.io/JvXDl @@ -79,4 +79,4 @@ jobs: # make release - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@662472033e021d55d94146f66f6058822b0b39fd # v3.27.0 + uses: github/codeql-action/analyze@aa578102511db1f4524ed59b8cc2bae4f6e88195 # v3.27.6 diff --git a/.github/workflows/scorecards.yml b/.github/workflows/scorecards.yml index 0d5ccfd..3fe1677 100644 --- a/.github/workflows/scorecards.yml +++ b/.github/workflows/scorecards.yml @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: 2022-2023 The go-mail Authors # -# SPDX-License-Identifier: CC0-1.0 +# SPDX-License-Identifier: MIT # This workflow uses actions that are not certified by GitHub. They are provided # by a third-party and are governed by separate terms of service, privacy @@ -35,7 +35,7 @@ jobs: steps: - name: Harden Runner - uses: step-security/harden-runner@91182cccc01eb5e619899d80e4e971d6181294a7 # v2.10.1 + uses: step-security/harden-runner@0080882f6c36860b6ba35c610c98ce87d4e2f26f # v2.10.2 with: egress-policy: audit @@ -75,6 +75,6 @@ jobs: # Upload the results to GitHub's code scanning dashboard. - name: "Upload to code-scanning" - uses: github/codeql-action/upload-sarif@662472033e021d55d94146f66f6058822b0b39fd # v3.27.0 + uses: github/codeql-action/upload-sarif@aa578102511db1f4524ed59b8cc2bae4f6e88195 # v3.27.6 with: sarif_file: results.sarif diff --git a/.gitignore b/.gitignore index 5ce8347..ad05dc0 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: 2022 Winni Neessen # -# SPDX-License-Identifier: CC0-1.0 +# SPDX-License-Identifier: MIT # Binaries for programs and plugins *.exe diff --git a/.golangci.toml b/.golangci.toml index 223dc0b..5178a27 100644 --- a/.golangci.toml +++ b/.golangci.toml @@ -9,4 +9,73 @@ exclude-dirs = ["examples"] [linters] enable = ["stylecheck", "whitespace", "containedctx", "contextcheck", "decorder", - "errname", "errorlint", "gofmt", "gofumpt"] + "errname", "errorlint", "gofmt", "gofumpt", "gosec"] + +[issues] + +## An overflow is impossible here +[[issues.exclude-rules]] +linters = ["gosec"] +path = "random.go" +text = "G115:" + +## These are tests which intentionally do not need any TLS settings +[[issues.exclude-rules]] +linters = ["gosec"] +path = "client_test.go" +text = "G402:" + +## These are tests which intentionally do not need any TLS settings +[[issues.exclude-rules]] +linters = ["gosec"] +path = "smtp/smtp_test.go" +text = "G402:" + +## We do not dictate a TLS minimum version in the smtp package. go-mail +## itself does set sane defaults +[[issues.exclude-rules]] +linters = ["gosec"] +path = "smtp/smtp.go" +text = "G402:" + +## The chance that we write +2 million tests is very low, I think we can +## ignore this for the time being +[[issues.exclude-rules]] +linters = ["gosec"] +path = "client_test.go" +text = "G109:" + +## The chance that we write +2 million tests is very low, I think we can +## ignore this for the time being +[[issues.exclude-rules]] +linters = ["gosec"] +path = "smtp/smtp_test.go" +text = "G109:" + +## We inform the user about the deprecated status of CRAM-MD5 and suggest +## to use SCRAM-SHA instead +[[issues.exclude-rules]] +linters = ["gosec"] +path = "smtp/auth_cram_md5.go" +text = "G501:" + +## Yes, SHA1 is weak, but in the context of SCRAM it is still considered +## secure for specific applications. The user is information about this +## in the documentation +[[issues.exclude-rules]] +linters = ["gosec"] +path = "smtp/auth_scram.go" +text = "G505:" + +## Test code for SCRAM-SHA1. Can be ignored. +[[issues.exclude-rules]] +linters = ["gosec"] +path = "smtp/smtp_test.go" +text = "G505:" + +## These are tests which intentionally do not need any TLS settings +[[issues.exclude-rules]] +linters = ["gosec"] +path = "quicksend_test.go" +text = "G402:" + diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index be485b6..a188a88 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -1,7 +1,7 @@ # Contributor Covenant Code of Conduct diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 959134a..c8f7c00 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,7 +1,7 @@ # How to contribute diff --git a/LICENSES/CC0-1.0.txt b/LICENSES/CC0-1.0.txt deleted file mode 100644 index 0e259d4..0000000 --- a/LICENSES/CC0-1.0.txt +++ /dev/null @@ -1,121 +0,0 @@ -Creative Commons Legal Code - -CC0 1.0 Universal - - CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE - LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN - ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS - INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES - REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS - PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM - THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED - HEREUNDER. - -Statement of Purpose - -The laws of most jurisdictions throughout the world automatically confer -exclusive Copyright and Related Rights (defined below) upon the creator -and subsequent owner(s) (each and all, an "owner") of an original work of -authorship and/or a database (each, a "Work"). - -Certain owners wish to permanently relinquish those rights to a Work for -the purpose of contributing to a commons of creative, cultural and -scientific works ("Commons") that the public can reliably and without fear -of later claims of infringement build upon, modify, incorporate in other -works, reuse and redistribute as freely as possible in any form whatsoever -and for any purposes, including without limitation commercial purposes. -These owners may contribute to the Commons to promote the ideal of a free -culture and the further production of creative, cultural and scientific -works, or to gain reputation or greater distribution for their Work in -part through the use and efforts of others. - -For these and/or other purposes and motivations, and without any -expectation of additional consideration or compensation, the person -associating CC0 with a Work (the "Affirmer"), to the extent that he or she -is an owner of Copyright and Related Rights in the Work, voluntarily -elects to apply CC0 to the Work and publicly distribute the Work under its -terms, with knowledge of his or her Copyright and Related Rights in the -Work and the meaning and intended legal effect of CC0 on those rights. - -1. Copyright and Related Rights. A Work made available under CC0 may be -protected by copyright and related or neighboring rights ("Copyright and -Related Rights"). Copyright and Related Rights include, but are not -limited to, the following: - - i. the right to reproduce, adapt, distribute, perform, display, - communicate, and translate a Work; - ii. moral rights retained by the original author(s) and/or performer(s); -iii. publicity and privacy rights pertaining to a person's image or - likeness depicted in a Work; - iv. rights protecting against unfair competition in regards to a Work, - subject to the limitations in paragraph 4(a), below; - v. rights protecting the extraction, dissemination, use and reuse of data - in a Work; - vi. database rights (such as those arising under Directive 96/9/EC of the - European Parliament and of the Council of 11 March 1996 on the legal - protection of databases, and under any national implementation - thereof, including any amended or successor version of such - directive); and -vii. other similar, equivalent or corresponding rights throughout the - world based on applicable law or treaty, and any national - implementations thereof. - -2. Waiver. To the greatest extent permitted by, but not in contravention -of, applicable law, Affirmer hereby overtly, fully, permanently, -irrevocably and unconditionally waives, abandons, and surrenders all of -Affirmer's Copyright and Related Rights and associated claims and causes -of action, whether now known or unknown (including existing as well as -future claims and causes of action), in the Work (i) in all territories -worldwide, (ii) for the maximum duration provided by applicable law or -treaty (including future time extensions), (iii) in any current or future -medium and for any number of copies, and (iv) for any purpose whatsoever, -including without limitation commercial, advertising or promotional -purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each -member of the public at large and to the detriment of Affirmer's heirs and -successors, fully intending that such Waiver shall not be subject to -revocation, rescission, cancellation, termination, or any other legal or -equitable action to disrupt the quiet enjoyment of the Work by the public -as contemplated by Affirmer's express Statement of Purpose. - -3. Public License Fallback. Should any part of the Waiver for any reason -be judged legally invalid or ineffective under applicable law, then the -Waiver shall be preserved to the maximum extent permitted taking into -account Affirmer's express Statement of Purpose. In addition, to the -extent the Waiver is so judged Affirmer hereby grants to each affected -person a royalty-free, non transferable, non sublicensable, non exclusive, -irrevocable and unconditional license to exercise Affirmer's Copyright and -Related Rights in the Work (i) in all territories worldwide, (ii) for the -maximum duration provided by applicable law or treaty (including future -time extensions), (iii) in any current or future medium and for any number -of copies, and (iv) for any purpose whatsoever, including without -limitation commercial, advertising or promotional purposes (the -"License"). The License shall be deemed effective as of the date CC0 was -applied by Affirmer to the Work. Should any part of the License for any -reason be judged legally invalid or ineffective under applicable law, such -partial invalidity or ineffectiveness shall not invalidate the remainder -of the License, and in such case Affirmer hereby affirms that he or she -will not (i) exercise any of his or her remaining Copyright and Related -Rights in the Work or (ii) assert any associated claims and causes of -action with respect to the Work, in either case contrary to Affirmer's -express Statement of Purpose. - -4. Limitations and Disclaimers. - - a. No trademark or patent rights held by Affirmer are waived, abandoned, - surrendered, licensed or otherwise affected by this document. - b. Affirmer offers the Work as-is and makes no representations or - warranties of any kind concerning the Work, express, implied, - statutory or otherwise, including without limitation warranties of - title, merchantability, fitness for a particular purpose, non - infringement, or the absence of latent or other defects, accuracy, or - the present or absence of errors, whether or not discoverable, all to - the greatest extent permissible under applicable law. - c. Affirmer disclaims responsibility for clearing rights of other persons - that may apply to the Work or any use thereof, including without - limitation any person's Copyright and Related Rights in the Work. - Further, Affirmer disclaims responsibility for obtaining any necessary - consents, permissions or other rights required for any use of the - Work. - d. Affirmer understands and acknowledges that Creative Commons is not a - party to this document and has no duty or obligation with respect to - this CC0 or use of the Work. diff --git a/README.md b/README.md index 9ab4303..eebfc06 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # go-mail - Easy to use, yet comprehensive library for sending mails with Go @@ -39,14 +39,20 @@ Here are some highlights of go-mail's featureset: * [X] Very small dependency footprint (mainly Go Stdlib and Go extended packages) * [X] Modern, idiomatic Go * [X] Sane and secure defaults -* [X] Explicit SSL/TLS support -* [X] Implicit StartTLS support with different policies +* [X] Implicit SSL/TLS support +* [X] Explicit STARTTLS support with different policies * [X] Makes use of contexts for a better control flow and timeout/cancelation handling -* [X] SMTP Auth support (LOGIN, PLAIN, CRAM-MD, XOAUTH2, SCRAM-SHA-1(-PLUS), SCRAM-SHA-256(-PLUS)) +* [X] SMTP Auth support + * [X] CRAM-MD5 + * [X] LOGIN + * [X] PLAIN + * [X] SCRAM-SHA-1/SCRAM-SHA-1-PLUS + * [X] SCRAM-SHA-256/SCRAM-SHA-256-PLUS + * [X] XOAUTH2 * [X] RFC5322 compliant mail address validation * [X] Support for common mail header field generation (Message-ID, Date, Bulk-Precedence, Priority, etc.) * [X] Concurrency-safe reusing the same SMTP connection to send multiple mails -* [X] Support for attachments and inline embeds (from file system, `io.Reader` or `embed.FS`) +* [X] Support for attachments and inline embeds (from file system, `io.Reader`, `embed.FS` or `fs.FS`) * [X] Support for different encodings * [X] Middleware support for 3rd-party libraries to alter mail messages * [X] Support sending mails via a local sendmail command diff --git a/SECURITY.md b/SECURITY.md index e2059c2..b0ed1ca 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -1,7 +1,7 @@ # Security Policy diff --git a/auth.go b/auth.go index 66254ee..3c520ae 100644 --- a/auth.go +++ b/auth.go @@ -136,6 +136,21 @@ const ( // // https://datatracker.ietf.org/doc/html/rfc7677 SMTPAuthSCRAMSHA256PLUS SMTPAuthType = "SCRAM-SHA-256-PLUS" + + // SMTPAuthAutoDiscover is a mechanism that dynamically discovers all authentication mechanisms + // supported by the SMTP server and selects the strongest available one. + // + // This type simplifies authentication by automatically negotiating the most secure mechanism + // offered by the server, based on a predefined security ranking. For instance, mechanisms like + // SCRAM-SHA-256(-PLUS) or XOAUTH2 are prioritized over weaker mechanisms such as CRAM-MD5 or PLAIN. + // + // The negotiation process ensures that mechanisms requiring additional capabilities (e.g., + // SCRAM-SHA-X-PLUS with TLS channel binding) are only selected when the necessary prerequisites + // are in place, such as an active TLS-secured connection. + // + // By automating mechanism selection, SMTPAuthAutoDiscover minimizes configuration effort while + // maximizing security and compatibility with a wide range of SMTP servers. + SMTPAuthAutoDiscover SMTPAuthType = "AUTODISCOVER" ) // SMTP Auth related static errors @@ -170,12 +185,19 @@ var ( // ErrSCRAMSHA256PLUSAuthNotSupported is returned when the server does not support the "SCRAM-SHA-256-PLUS" SMTP // authentication type. ErrSCRAMSHA256PLUSAuthNotSupported = errors.New("server does not support SMTP AUTH type: SCRAM-SHA-256-PLUS") + + // ErrNoSupportedAuthDiscovered is returned when the SMTP Auth AutoDiscover process fails to identify + // any supported authentication mechanisms offered by the server. + ErrNoSupportedAuthDiscovered = errors.New("SMTP Auth autodiscover was not able to detect a supported " + + "authentication mechanism") ) // UnmarshalString satisfies the fig.StringUnmarshaler interface for the SMTPAuthType type // https://pkg.go.dev/github.com/kkyr/fig#StringUnmarshaler func (sa *SMTPAuthType) UnmarshalString(value string) error { switch strings.ToLower(value) { + case "auto", "autodiscover", "autodiscovery": + *sa = SMTPAuthAutoDiscover case "cram-md5", "crammd5", "cram": *sa = SMTPAuthCramMD5 case "custom": diff --git a/auth_test.go b/auth_test.go index a73eaca..a687af3 100644 --- a/auth_test.go +++ b/auth_test.go @@ -12,6 +12,9 @@ func TestSMTPAuthType_UnmarshalString(t *testing.T) { authString string expected SMTPAuthType }{ + {"AUTODISCOVER: auto", "auto", SMTPAuthAutoDiscover}, + {"AUTODISCOVER: autodiscover", "autodiscover", SMTPAuthAutoDiscover}, + {"AUTODISCOVER: autodiscovery", "autodiscovery", SMTPAuthAutoDiscover}, {"CRAM-MD5: cram-md5", "cram-md5", SMTPAuthCramMD5}, {"CRAM-MD5: crammd5", "crammd5", SMTPAuthCramMD5}, {"CRAM-MD5: cram", "cram", SMTPAuthCramMD5}, diff --git a/client.go b/client.go index abb90f4..028e277 100644 --- a/client.go +++ b/client.go @@ -142,9 +142,6 @@ type ( // host is the hostname of the SMTP server we are connecting to. host string - // isEncrypted indicates wether the Client connection is encrypted or not. - isEncrypted bool - // logAuthData indicates whether authentication-related data should be logged. logAuthData bool @@ -170,6 +167,9 @@ type ( // requestDSN indicates wether we want to request DSN (Delivery Status Notifications). requestDSN bool + // sendMutex is used to synchronize access to shared resources during the dial and send methods. + sendMutex sync.Mutex + // smtpAuth is the authentication type that is used to authenticate the user with SMTP server. It // satisfies the smtp.Auth interface. // @@ -803,7 +803,7 @@ func (c *Client) SetSSLPort(ssl bool, fallback bool) { } // SetDebugLog sets or overrides whether the Client is using debug logging. The debug logger will log incoming -// and outgoing communication between the Client and the server to os.Stderr. +// and outgoing communication between the client and the server to log.Logger that is defined on the Client. // // Note: The SMTP communication might include unencrypted authentication data, depending on whether you are using // SMTP authentication and the type of authentication mechanism. This could pose a data protection risk. Use @@ -812,9 +812,26 @@ func (c *Client) SetSSLPort(ssl bool, fallback bool) { // Parameters: // - val: A boolean value indicating whether to enable (true) or disable (false) debug logging. func (c *Client) SetDebugLog(val bool) { + c.SetDebugLogWithSMTPClient(c.smtpClient, val) +} + +// SetDebugLogWithSMTPClient sets or overrides whether the provided smtp.Client is using debug logging. +// The debug logger will log incoming and outgoing communication between the client and the server to +// log.Logger that is defined on the Client. +// +// Note: The SMTP communication might include unencrypted authentication data, depending on whether you are using +// SMTP authentication and the type of authentication mechanism. This could pose a data protection risk. Use +// debug logging with caution. +// +// Parameters: +// - client: A pointer to the smtp.Client that handles the connection to the server. +// - val: A boolean value indicating whether to enable (true) or disable (false) debug logging. +func (c *Client) SetDebugLogWithSMTPClient(client *smtp.Client, val bool) { + c.mutex.Lock() + defer c.mutex.Unlock() c.useDebugLog = val - if c.smtpClient != nil { - c.smtpClient.SetDebugLog(val) + if client != nil { + client.SetDebugLog(val) } } @@ -827,9 +844,24 @@ func (c *Client) SetDebugLog(val bool) { // Parameters: // - logger: A logger that satisfies the log.Logger interface to be set for the Client. func (c *Client) SetLogger(logger log.Logger) { + c.SetLoggerWithSMTPClient(c.smtpClient, logger) +} + +// SetLoggerWithSMTPClient sets or overrides the custom logger currently used by the provided smtp.Client. +// The logger must satisfy the log.Logger interface and is only utilized when debug logging is enabled on +// the provided smtp.Client. +// +// By default, log.Stdlog is used if no custom logger is provided. +// +// Parameters: +// - client: A pointer to the smtp.Client that handles the connection to the server. +// - logger: A logger that satisfies the log.Logger interface to be set for the Client. +func (c *Client) SetLoggerWithSMTPClient(client *smtp.Client, logger log.Logger) { + c.mutex.Lock() + defer c.mutex.Unlock() c.logger = logger - if c.smtpClient != nil { - c.smtpClient.SetLogger(logger) + if client != nil { + client.SetLogger(logger) } } @@ -923,67 +955,91 @@ func (c *Client) SetLogAuthData(logAuth bool) { // SMTP server. // // Parameters: -// - dialCtx: The context.Context used to control the connection timeout and cancellation. +// - ctxDial: The context.Context used to control the connection timeout and cancellation. // // Returns: // - An error if the connection to the SMTP server fails or any subsequent command fails. -func (c *Client) DialWithContext(dialCtx context.Context) error { - c.mutex.Lock() - defer c.mutex.Unlock() - - ctx, cancel := context.WithDeadline(dialCtx, time.Now().Add(c.connTimeout)) - defer cancel() - - if c.dialContextFunc == nil { - netDialer := net.Dialer{} - c.dialContextFunc = netDialer.DialContext - - if c.useSSL { - tlsDialer := tls.Dialer{NetDialer: &netDialer, Config: c.tlsconfig} - c.isEncrypted = true - c.dialContextFunc = tlsDialer.DialContext - } - } - connection, err := c.dialContextFunc(ctx, "tcp", c.ServerAddr()) - if err != nil && c.fallbackPort != 0 { - // TODO: should we somehow log or append the previous error? - connection, err = c.dialContextFunc(ctx, "tcp", c.serverFallbackAddr()) - } +func (c *Client) DialWithContext(ctxDial context.Context) error { + client, err := c.DialToSMTPClientWithContext(ctxDial) if err != nil { return err } + c.mutex.Lock() + c.smtpClient = client + c.mutex.Unlock() + return nil +} + +// DialToSMTPClientWithContext establishes and configures a smtp.Client connection using +// the provided context. +// +// This function uses the provided context to manage the connection deadline and cancellation. +// It dials the SMTP server using the Client's configured DialContextFunc or a default dialer. +// If SSL is enabled, it uses a TLS connection. After successfully connecting, it initializes +// an smtp.Client, sends the HELO/EHLO command, and optionally performs STARTTLS and SMTP AUTH +// based on the Client's configuration. Debug and authentication logging are enabled if +// configured. +// +// Parameters: +// - ctxDial: The context used to control the connection timeout and cancellation. +// +// Returns: +// - A pointer to the initialized smtp.Client. +// - An error if the connection fails, the smtp.Client cannot be created, or any subsequent commands fail. +func (c *Client) DialToSMTPClientWithContext(ctxDial context.Context) (*smtp.Client, error) { + c.mutex.RLock() + defer c.mutex.RUnlock() + + ctx, cancel := context.WithDeadline(ctxDial, time.Now().Add(c.connTimeout)) + defer cancel() + + isEncrypted := false + dialContextFunc := c.dialContextFunc + if c.dialContextFunc == nil { + netDialer := net.Dialer{} + dialContextFunc = netDialer.DialContext + if c.useSSL { + tlsDialer := tls.Dialer{NetDialer: &netDialer, Config: c.tlsconfig} + isEncrypted = true + dialContextFunc = tlsDialer.DialContext + } + } + connection, err := dialContextFunc(ctx, "tcp", c.ServerAddr()) + if err != nil && c.fallbackPort != 0 { + // TODO: should we somehow log or append the previous error? + connection, err = dialContextFunc(ctx, "tcp", c.serverFallbackAddr()) + } + if err != nil { + return nil, err + } client, err := smtp.NewClient(connection, c.host) if err != nil { - return err + return nil, err } - if client == nil { - return fmt.Errorf("SMTP client is nil") - } - c.smtpClient = client if c.logger != nil { - c.smtpClient.SetLogger(c.logger) + client.SetLogger(c.logger) } if c.useDebugLog { - c.smtpClient.SetDebugLog(true) + client.SetDebugLog(true) } if c.logAuthData { - c.smtpClient.SetLogAuthData() + client.SetLogAuthData() } - if err = c.smtpClient.Hello(c.helo); err != nil { - return err + if err = client.Hello(c.helo); err != nil { + return nil, err } - if err = c.tls(); err != nil { - return err + if err = c.tls(client, &isEncrypted); err != nil { + return nil, err } - if err = c.auth(); err != nil { - return err + if err = c.auth(client, isEncrypted); err != nil { + return nil, err } - return nil + return client, nil } // Close terminates the connection to the SMTP server, returning an error if the disconnection @@ -996,10 +1052,27 @@ func (c *Client) DialWithContext(dialCtx context.Context) error { // Returns: // - An error if the disconnection fails; otherwise, returns nil. func (c *Client) Close() error { - if !c.smtpClient.HasConnection() { + return c.CloseWithSMTPClient(c.smtpClient) +} + +// CloseWithSMTPClient terminates the connection of the provided smtp.Client to the SMTP server, +// returning an error if the disconnection fails. If the connection is already closed, this +// method is a no-op and disregards any error. +// +// This function checks if the smtp.Client connection is active. If not, it simply returns +// without any action. If the connection is active, it attempts to gracefully close the +// connection using the Quit method. +// +// Parameters: +// - client: A pointer to the smtp.Client that handles the connection to the server. +// +// Returns: +// - An error if the disconnection fails; otherwise, returns nil. +func (c *Client) CloseWithSMTPClient(client *smtp.Client) error { + if client == nil || !client.HasConnection() { return nil } - if err := c.smtpClient.Quit(); err != nil { + if err := client.Quit(); err != nil { return fmt.Errorf("failed to close SMTP client: %w", err) } @@ -1013,12 +1086,29 @@ func (c *Client) Close() error { // the command fails, an error is returned. // // Returns: -// - An error if the connection check fails or if sending the RSET command fails; otherwise, returns nil. +// - An error if the connection check fails or if sending the RSET command fails; +// otherwise, returns nil. func (c *Client) Reset() error { - if err := c.checkConn(); err != nil { + return c.ResetWithSMTPClient(c.smtpClient) +} + +// ResetWithSMTPClient sends an SMTP RSET command to the provided smtp.Client, to reset +// the state of the current SMTP session. +// +// This method checks the connection to the SMTP server and, if the connection is valid, +// it sends an RSET command to reset the session state. If the connection is invalid or +// the command fails, an error is returned. +// +// Parameters: +// - client: A pointer to the smtp.Client that handles the connection to the server. +// +// Returns: +// - An error if the connection check fails or if sending the RSET command fails; otherwise, returns nil. +func (c *Client) ResetWithSMTPClient(client *smtp.Client) error { + if err := c.checkConn(client); err != nil { return err } - if err := c.smtpClient.Reset(); err != nil { + if err := client.Reset(); err != nil { return fmt.Errorf("failed to send RSET to SMTP client: %w", err) } @@ -1058,22 +1148,47 @@ func (c *Client) DialAndSend(messages ...*Msg) error { // - An error if the connection fails, if sending the messages fails, or if closing the // connection fails; otherwise, returns nil. func (c *Client) DialAndSendWithContext(ctx context.Context, messages ...*Msg) error { - if err := c.DialWithContext(ctx); err != nil { + client, err := c.DialToSMTPClientWithContext(ctx) + if err != nil { return fmt.Errorf("dial failed: %w", err) } defer func() { - _ = c.Close() + _ = c.CloseWithSMTPClient(client) }() - if err := c.Send(messages...); err != nil { + if err = c.SendWithSMTPClient(client, messages...); err != nil { return fmt.Errorf("send failed: %w", err) } - if err := c.Close(); err != nil { + if err = c.CloseWithSMTPClient(client); err != nil { return fmt.Errorf("failed to close connection: %w", err) } return nil } +// Send attempts to send one or more Msg using the SMTP client that is assigned to the Client. +// If the Client has no active connection to the server, Send will fail with an error. For +// each of the provided Msg, it will associate a SendError with the Msg in case of a +// transmission or delivery error. +// +// This method first checks for an active connection to the SMTP server. If the connection is +// not valid, it returns a SendError. It then iterates over the provided messages, attempting +// to send each one. If an error occurs during sending, the method records the error and +// associates it with the corresponding Msg. If multiple errors are encountered, it aggregates +// them into a single SendError to be returned. +// +// Parameters: +// - client: A pointer to the smtp.Client that holds the connection to the SMTP server +// - messages: A variadic list of pointers to Msg objects to be sent. +// +// Returns: +// - An error that represents the sending result, which may include multiple SendErrors if +// any occurred; otherwise, returns nil. +func (c *Client) Send(messages ...*Msg) (returnErr error) { + c.sendMutex.Lock() + defer c.sendMutex.Unlock() + return c.SendWithSMTPClient(c.smtpClient, messages...) +} + // auth attempts to authenticate the client using SMTP AUTH mechanisms. It checks the connection, // determines the supported authentication methods, and applies the appropriate authentication // type. An error is returned if authentication fails. @@ -1093,85 +1208,125 @@ func (c *Client) DialAndSendWithContext(ctx context.Context, messages ...*Msg) e // Returns: // - An error if the connection check fails, if no supported authentication method is found, // or if the authentication process fails. -func (c *Client) auth() error { +func (c *Client) auth(client *smtp.Client, isEnc bool) error { + var smtpAuth smtp.Auth if c.smtpAuth == nil && c.smtpAuthType != SMTPAuthNoAuth { - hasSMTPAuth, smtpAuthType := c.smtpClient.Extension("AUTH") + hasSMTPAuth, smtpAuthType := client.Extension("AUTH") if !hasSMTPAuth { return fmt.Errorf("server does not support SMTP AUTH") } - switch c.smtpAuthType { + authType := c.smtpAuthType + if c.smtpAuthType == SMTPAuthAutoDiscover { + discoveredType, err := c.authTypeAutoDiscover(smtpAuthType, isEnc) + if err != nil { + return err + } + authType = discoveredType + } + + switch authType { case SMTPAuthPlain: if !strings.Contains(smtpAuthType, string(SMTPAuthPlain)) { return ErrPlainAuthNotSupported } - c.smtpAuth = smtp.PlainAuth("", c.user, c.pass, c.host, false) + smtpAuth = smtp.PlainAuth("", c.user, c.pass, c.host, false) case SMTPAuthPlainNoEnc: if !strings.Contains(smtpAuthType, string(SMTPAuthPlain)) { return ErrPlainAuthNotSupported } - c.smtpAuth = smtp.PlainAuth("", c.user, c.pass, c.host, true) + smtpAuth = smtp.PlainAuth("", c.user, c.pass, c.host, true) case SMTPAuthLogin: if !strings.Contains(smtpAuthType, string(SMTPAuthLogin)) { return ErrLoginAuthNotSupported } - c.smtpAuth = smtp.LoginAuth(c.user, c.pass, c.host, false) + smtpAuth = smtp.LoginAuth(c.user, c.pass, c.host, false) case SMTPAuthLoginNoEnc: if !strings.Contains(smtpAuthType, string(SMTPAuthLogin)) { return ErrLoginAuthNotSupported } - c.smtpAuth = smtp.LoginAuth(c.user, c.pass, c.host, true) + smtpAuth = smtp.LoginAuth(c.user, c.pass, c.host, true) case SMTPAuthCramMD5: if !strings.Contains(smtpAuthType, string(SMTPAuthCramMD5)) { return ErrCramMD5AuthNotSupported } - c.smtpAuth = smtp.CRAMMD5Auth(c.user, c.pass) + smtpAuth = smtp.CRAMMD5Auth(c.user, c.pass) case SMTPAuthXOAUTH2: if !strings.Contains(smtpAuthType, string(SMTPAuthXOAUTH2)) { return ErrXOauth2AuthNotSupported } - c.smtpAuth = smtp.XOAuth2Auth(c.user, c.pass) + smtpAuth = smtp.XOAuth2Auth(c.user, c.pass) case SMTPAuthSCRAMSHA1: if !strings.Contains(smtpAuthType, string(SMTPAuthSCRAMSHA1)) { return ErrSCRAMSHA1AuthNotSupported } - c.smtpAuth = smtp.ScramSHA1Auth(c.user, c.pass) + smtpAuth = smtp.ScramSHA1Auth(c.user, c.pass) case SMTPAuthSCRAMSHA256: if !strings.Contains(smtpAuthType, string(SMTPAuthSCRAMSHA256)) { return ErrSCRAMSHA256AuthNotSupported } - c.smtpAuth = smtp.ScramSHA256Auth(c.user, c.pass) + smtpAuth = smtp.ScramSHA256Auth(c.user, c.pass) case SMTPAuthSCRAMSHA1PLUS: if !strings.Contains(smtpAuthType, string(SMTPAuthSCRAMSHA1PLUS)) { return ErrSCRAMSHA1PLUSAuthNotSupported } - tlsConnState, err := c.smtpClient.GetTLSConnectionState() + tlsConnState, err := client.GetTLSConnectionState() if err != nil { return err } - c.smtpAuth = smtp.ScramSHA1PlusAuth(c.user, c.pass, tlsConnState) + smtpAuth = smtp.ScramSHA1PlusAuth(c.user, c.pass, tlsConnState) case SMTPAuthSCRAMSHA256PLUS: if !strings.Contains(smtpAuthType, string(SMTPAuthSCRAMSHA256PLUS)) { return ErrSCRAMSHA256PLUSAuthNotSupported } - tlsConnState, err := c.smtpClient.GetTLSConnectionState() + tlsConnState, err := client.GetTLSConnectionState() if err != nil { return err } - c.smtpAuth = smtp.ScramSHA256PlusAuth(c.user, c.pass, tlsConnState) + smtpAuth = smtp.ScramSHA256PlusAuth(c.user, c.pass, tlsConnState) default: return fmt.Errorf("unsupported SMTP AUTH type %q", c.smtpAuthType) } } - if c.smtpAuth != nil { - if err := c.smtpClient.Auth(c.smtpAuth); err != nil { + if smtpAuth != nil { + if err := client.Auth(smtpAuth); err != nil { return fmt.Errorf("SMTP AUTH failed: %w", err) } } return nil } +func (c *Client) authTypeAutoDiscover(supported string, isEnc bool) (SMTPAuthType, error) { + if supported == "" { + return "", ErrNoSupportedAuthDiscovered + } + preferList := []SMTPAuthType{ + SMTPAuthSCRAMSHA256PLUS, SMTPAuthSCRAMSHA256, SMTPAuthSCRAMSHA1PLUS, SMTPAuthSCRAMSHA1, + SMTPAuthXOAUTH2, SMTPAuthCramMD5, SMTPAuthPlain, SMTPAuthLogin, + } + if !isEnc { + preferList = []SMTPAuthType{SMTPAuthSCRAMSHA256, SMTPAuthSCRAMSHA1, SMTPAuthXOAUTH2, SMTPAuthCramMD5} + } + mechs := strings.Split(supported, " ") + + for _, item := range preferList { + if sliceContains(mechs, string(item)) { + return item, nil + } + } + return "", ErrNoSupportedAuthDiscovered +} + +func sliceContains(slice []string, item string) bool { + for _, s := range slice { + if s == item { + return true + } + } + return false +} + // sendSingleMsg sends out a single message and returns an error if the transmission or // delivery fails. It is invoked by the public Send methods. // @@ -1187,12 +1342,13 @@ func (c *Client) auth() error { // // Returns: // - An error if any part of the sending process fails; otherwise, returns nil. -func (c *Client) sendSingleMsg(message *Msg) error { - c.mutex.Lock() - defer c.mutex.Unlock() +func (c *Client) sendSingleMsg(client *smtp.Client, message *Msg) error { + c.mutex.RLock() + defer c.mutex.RUnlock() + escSupport, _ := client.Extension("ENHANCEDSTATUSCODES") if message.encoding == NoEncoding { - if ok, _ := c.smtpClient.Extension("8BITMIME"); !ok { + if ok, _ := client.Extension("8BITMIME"); !ok { return &SendError{Reason: ErrNoUnencoded, isTemp: false, affectedMsg: message} } } @@ -1200,28 +1356,31 @@ func (c *Client) sendSingleMsg(message *Msg) error { if err != nil { return &SendError{ Reason: ErrGetSender, errlist: []error{err}, isTemp: isTempError(err), - affectedMsg: message, + affectedMsg: message, errcode: errorCode(err), + enhancedStatusCode: enhancedStatusCode(err, escSupport), } } rcpts, err := message.GetRecipients() if err != nil { return &SendError{ Reason: ErrGetRcpts, errlist: []error{err}, isTemp: isTempError(err), - affectedMsg: message, + affectedMsg: message, errcode: errorCode(err), + enhancedStatusCode: enhancedStatusCode(err, escSupport), } } if c.requestDSN { if c.dsnReturnType != "" { - c.smtpClient.SetDSNMailReturnOption(string(c.dsnReturnType)) + client.SetDSNMailReturnOption(string(c.dsnReturnType)) } } - if err = c.smtpClient.Mail(from); err != nil { + if err = client.Mail(from); err != nil { retError := &SendError{ Reason: ErrSMTPMailFrom, errlist: []error{err}, isTemp: isTempError(err), - affectedMsg: message, + affectedMsg: message, errcode: errorCode(err), + enhancedStatusCode: enhancedStatusCode(err, escSupport), } - if resetSendErr := c.smtpClient.Reset(); resetSendErr != nil { + if resetSendErr := client.Reset(); resetSendErr != nil { retError.errlist = append(retError.errlist, resetSendErr) } return retError @@ -1231,48 +1390,54 @@ func (c *Client) sendSingleMsg(message *Msg) error { rcptSendErr.errlist = make([]error, 0) rcptSendErr.rcpt = make([]string, 0) rcptNotifyOpt := strings.Join(c.dsnRcptNotifyType, ",") - c.smtpClient.SetDSNRcptNotifyOption(rcptNotifyOpt) + client.SetDSNRcptNotifyOption(rcptNotifyOpt) for _, rcpt := range rcpts { - if err = c.smtpClient.Rcpt(rcpt); err != nil { + if err = client.Rcpt(rcpt); err != nil { rcptSendErr.Reason = ErrSMTPRcptTo rcptSendErr.errlist = append(rcptSendErr.errlist, err) rcptSendErr.rcpt = append(rcptSendErr.rcpt, rcpt) rcptSendErr.isTemp = isTempError(err) + rcptSendErr.errcode = errorCode(err) + rcptSendErr.enhancedStatusCode = enhancedStatusCode(err, escSupport) hasError = true } } if hasError { - if resetSendErr := c.smtpClient.Reset(); resetSendErr != nil { + if resetSendErr := client.Reset(); resetSendErr != nil { rcptSendErr.errlist = append(rcptSendErr.errlist, resetSendErr) } return rcptSendErr } - writer, err := c.smtpClient.Data() + writer, err := client.Data() if err != nil { return &SendError{ Reason: ErrSMTPData, errlist: []error{err}, isTemp: isTempError(err), - affectedMsg: message, + affectedMsg: message, errcode: errorCode(err), + enhancedStatusCode: enhancedStatusCode(err, escSupport), } } _, err = message.WriteTo(writer) if err != nil { return &SendError{ Reason: ErrWriteContent, errlist: []error{err}, isTemp: isTempError(err), - affectedMsg: message, + affectedMsg: message, errcode: errorCode(err), + enhancedStatusCode: enhancedStatusCode(err, escSupport), } } if err = writer.Close(); err != nil { return &SendError{ Reason: ErrSMTPDataClose, errlist: []error{err}, isTemp: isTempError(err), - affectedMsg: message, + affectedMsg: message, errcode: errorCode(err), + enhancedStatusCode: enhancedStatusCode(err, escSupport), } } message.isDelivered = true - if err = c.Reset(); err != nil { + if err = c.ResetWithSMTPClient(client); err != nil { return &SendError{ Reason: ErrSMTPReset, errlist: []error{err}, isTemp: isTempError(err), - affectedMsg: message, + affectedMsg: message, errcode: errorCode(err), + enhancedStatusCode: enhancedStatusCode(err, escSupport), } } return nil @@ -1290,21 +1455,24 @@ func (c *Client) sendSingleMsg(message *Msg) error { // Returns: // - An error if there is no active connection, if the NOOP command fails, or if extending // the deadline fails; otherwise, returns nil. -func (c *Client) checkConn() error { - if c.smtpClient == nil { +func (c *Client) checkConn(client *smtp.Client) error { + if client == nil { return ErrNoActiveConnection } - if !c.smtpClient.HasConnection() { + if !client.HasConnection() { return ErrNoActiveConnection } - if !c.noNoop { - if err := c.smtpClient.Noop(); err != nil { + c.mutex.RLock() + noNoop := c.noNoop + c.mutex.RUnlock() + if !noNoop { + if err := client.Noop(); err != nil { return ErrNoActiveConnection } } - if err := c.smtpClient.UpdateDeadline(c.connTimeout); err != nil { + if err := client.UpdateDeadline(c.connTimeout); err != nil { return ErrDeadlineExtendFailed } return nil @@ -1351,10 +1519,10 @@ func (c *Client) setDefaultHelo() error { // Returns: // - An error if there is no active connection, if STARTTLS is required but not supported, // or if there are issues during the TLS handshake; otherwise, returns nil. -func (c *Client) tls() error { +func (c *Client) tls(client *smtp.Client, isEnc *bool) error { if !c.useSSL && c.tlspolicy != NoTLS { hasStartTLS := false - extension, _ := c.smtpClient.Extension("STARTTLS") + extension, _ := client.Extension("STARTTLS") if c.tlspolicy == TLSMandatory { hasStartTLS = true if !extension { @@ -1368,21 +1536,21 @@ func (c *Client) tls() error { } } if hasStartTLS { - if err := c.smtpClient.StartTLS(c.tlsconfig); err != nil { + if err := client.StartTLS(c.tlsconfig); err != nil { return err } } - tlsConnState, err := c.smtpClient.GetTLSConnectionState() + tlsConnState, err := client.GetTLSConnectionState() if err != nil { switch { case errors.Is(err, smtp.ErrNonTLSConnection): - c.isEncrypted = false + *isEnc = false return nil default: return fmt.Errorf("failed to get TLS connection state: %w", err) } } - c.isEncrypted = tlsConnState.HandshakeComplete + *isEnc = tlsConnState.HandshakeComplete } return nil } diff --git a/client_119.go b/client_119.go index 093967e..96d18a7 100644 --- a/client_119.go +++ b/client_119.go @@ -7,12 +7,16 @@ package mail -import "errors" +import ( + "errors" -// Send attempts to send one or more Msg using the Client connection to the SMTP server. -// If the Client has no active connection to the server, Send will fail with an error. For each -// of the provided Msg, it will associate a SendError with the Msg in case of a transmission -// or delivery error. + "github.com/wneessen/go-mail/smtp" +) + +// SendWithSMTPClient attempts to send one or more Msg using a provided smtp.Client with an +// established connection to the SMTP server. If the smtp.Client has no active connection to +// the server, SendWithSMTPClient will fail with an error. For each of the provided Msg, it +// will associate a SendError with the Msg in case of a transmission or delivery error. // // This method first checks for an active connection to the SMTP server. If the connection is // not valid, it returns a SendError. It then iterates over the provided messages, attempting @@ -21,18 +25,26 @@ import "errors" // them into a single SendError to be returned. // // Parameters: +// - client: A pointer to the smtp.Client that holds the connection to the SMTP server // - messages: A variadic list of pointers to Msg objects to be sent. // // Returns: // - An error that represents the sending result, which may include multiple SendErrors if // any occurred; otherwise, returns nil. -func (c *Client) Send(messages ...*Msg) error { - if err := c.checkConn(); err != nil { - return &SendError{Reason: ErrConnCheck, errlist: []error{err}, isTemp: isTempError(err)} +func (c *Client) SendWithSMTPClient(client *smtp.Client, messages ...*Msg) error { + escSupport := false + if client != nil { + escSupport, _ = client.Extension("ENHANCEDSTATUSCODES") + } + if err := c.checkConn(client); err != nil { + return &SendError{ + Reason: ErrConnCheck, errlist: []error{err}, isTemp: isTempError(err), + errcode: errorCode(err), enhancedStatusCode: enhancedStatusCode(err, escSupport), + } } var errs []*SendError for id, message := range messages { - if sendErr := c.sendSingleMsg(message); sendErr != nil { + if sendErr := c.sendSingleMsg(client, message); sendErr != nil { messages[id].sendError = sendErr var msgSendErr *SendError @@ -50,9 +62,11 @@ func (c *Client) Send(messages ...*Msg) error { returnErr.rcpt = append(returnErr.rcpt, errs[i].rcpt...) } - // We assume that the isTemp flag from the last error we received should be the + // We assume that the error codes and flags from the last error we received should be the // indicator for the returned isTemp flag as well returnErr.isTemp = errs[len(errs)-1].isTemp + returnErr.errcode = errs[len(errs)-1].errcode + returnErr.enhancedStatusCode = errs[len(errs)-1].enhancedStatusCode return returnErr } diff --git a/client_120.go b/client_120.go index 012a4f7..622e149 100644 --- a/client_120.go +++ b/client_120.go @@ -9,26 +9,38 @@ package mail import ( "errors" + + "github.com/wneessen/go-mail/smtp" ) -// Send attempts to send one or more Msg using the Client connection to the SMTP server. -// If the Client has no active connection to the server, Send will fail with an error. For each -// of the provided Msg, it will associate a SendError with the Msg in case of a transmission -// or delivery error. +// SendWithSMTPClient attempts to send one or more Msg using a provided smtp.Client with an +// established connection to the SMTP server. If the smtp.Client has no active connection to +// the server, SendWithSMTPClient will fail with an error. For each of the provided Msg, it +// will associate a SendError with the Msg in case of a transmission or delivery error. // // This method first checks for an active connection to the SMTP server. If the connection is -// not valid, it returns an error wrapped in a SendError. It then iterates over the provided -// messages, attempting to send each one. If an error occurs during sending, the method records -// the error and associates it with the corresponding Msg. +// not valid, it returns a SendError. It then iterates over the provided messages, attempting +// to send each one. If an error occurs during sending, the method records the error and +// associates it with the corresponding Msg. If multiple errors are encountered, it aggregates +// them into a single SendError to be returned. // // Parameters: +// - client: A pointer to the smtp.Client that holds the connection to the SMTP server // - messages: A variadic list of pointers to Msg objects to be sent. // // Returns: -// - An error that aggregates any SendErrors encountered during the sending process; otherwise, returns nil. -func (c *Client) Send(messages ...*Msg) (returnErr error) { - if err := c.checkConn(); err != nil { - returnErr = &SendError{Reason: ErrConnCheck, errlist: []error{err}, isTemp: isTempError(err)} +// - An error that represents the sending result, which may include multiple SendErrors if +// any occurred; otherwise, returns nil. +func (c *Client) SendWithSMTPClient(client *smtp.Client, messages ...*Msg) (returnErr error) { + escSupport := false + if client != nil { + escSupport, _ = client.Extension("ENHANCEDSTATUSCODES") + } + if err := c.checkConn(client); err != nil { + returnErr = &SendError{ + Reason: ErrConnCheck, errlist: []error{err}, isTemp: isTempError(err), + errcode: errorCode(err), enhancedStatusCode: enhancedStatusCode(err, escSupport), + } return } @@ -38,7 +50,7 @@ func (c *Client) Send(messages ...*Msg) (returnErr error) { }() for id, message := range messages { - if sendErr := c.sendSingleMsg(message); sendErr != nil { + if sendErr := c.sendSingleMsg(client, message); sendErr != nil { messages[id].sendError = sendErr errs = append(errs, sendErr) } diff --git a/client_test.go b/client_test.go index 664f0fd..195a797 100644 --- a/client_test.go +++ b/client_test.go @@ -17,6 +17,7 @@ import ( "net/mail" "os" "reflect" + "strconv" "strings" "sync" "sync/atomic" @@ -34,14 +35,15 @@ const ( TestServerProto = "tcp" // TestServerAddr is the address the simple SMTP test server listens on TestServerAddr = "127.0.0.1" - // TestServerPortBase is the base port for the simple SMTP test server - TestServerPortBase = 12025 // TestSenderValid is a test sender email address considered valid for sending test emails. TestSenderValid = "valid-from@domain.tld" // TestRcptValid is a test recipient email address considered valid for sending test emails. TestRcptValid = "valid-to@domain.tld" ) +// TestServerPortBase is the base port for the simple SMTP test server +var TestServerPortBase int32 = 30025 + // PortAdder is an atomic counter used to increment port numbers for the test SMTP server instances. var PortAdder atomic.Int32 @@ -98,6 +100,18 @@ type logData struct { Lines []logLine `json:"lines"` } +func init() { + testPort := os.Getenv("TEST_BASEPORT") + if testPort == "" { + return + } + if port, err := strconv.Atoi(testPort); err == nil { + if port <= 65000 && port > 1023 { + TestServerPortBase = int32(port) + } + } +} + func TestNewClient(t *testing.T) { t.Run("create new Client", func(t *testing.T) { client, err := NewClient(DefaultHost) @@ -1647,6 +1661,15 @@ func TestClient_Close(t *testing.T) { t.Errorf("close was supposed to fail, but didn't") } }) + t.Run("close on a nil smtpclient should return nil", func(t *testing.T) { + client, err := NewClient(DefaultHost) + if err != nil { + t.Fatalf("failed to create new client: %s", err) + } + if err = client.Close(); err != nil { + t.Errorf("failed to close the client: %s", err) + } + }) } func TestClient_DialWithContext(t *testing.T) { @@ -1749,11 +1772,8 @@ func TestClient_DialWithContext(t *testing.T) { t.Errorf("failed to close the client: %s", err) } }) - if client.smtpClient == nil { - t.Errorf("client with invalid HELO should still have a smtp client, got nil") - } - if !client.smtpClient.HasConnection() { - t.Errorf("client with invalid HELO should still have a smtp client connection, got nil") + if client.smtpClient != nil { + t.Error("client with invalid HELO should not have a smtp client") } }) t.Run("fail on base port and fallback", func(t *testing.T) { @@ -1802,11 +1822,8 @@ func TestClient_DialWithContext(t *testing.T) { if err = client.DialWithContext(ctxDial); err == nil { t.Fatalf("connection was supposed to fail, but didn't") } - if client.smtpClient == nil { - t.Fatalf("client has no smtp client") - } - if !client.smtpClient.HasConnection() { - t.Errorf("client has no connection") + if client.smtpClient != nil { + t.Fatalf("client is not supposed to have a smtp client") } }) t.Run("connect with failing auth", func(t *testing.T) { @@ -2274,6 +2291,84 @@ func TestClient_DialAndSendWithContext(t *testing.T) { t.Errorf("client was supposed to fail on dial") } }) + // https://github.com/wneessen/go-mail/issues/380 + t.Run("concurrent sending via DialAndSendWithContext", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := NewClient(DefaultHost, WithPort(serverPort), WithTLSPolicy(NoTLS)) + if err != nil { + t.Fatalf("failed to create new client: %s", err) + } + + wg := sync.WaitGroup{} + for i := 0; i < 50; i++ { + wg.Add(1) + go func() { + defer wg.Done() + msg := testMessage(t) + msg.SetMessageIDWithValue("this.is.a.message.id") + + ctxDial, cancelDial := context.WithTimeout(ctx, time.Minute) + defer cancelDial() + if goroutineErr := client.DialAndSendWithContext(ctxDial, msg); goroutineErr != nil { + t.Errorf("failed to dial and send message: %s", goroutineErr) + } + }() + } + wg.Wait() + }) + // https://github.com/wneessen/go-mail/issues/385 + t.Run("concurrent sending via DialAndSendWithContext on receiver func", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := NewClient(DefaultHost, WithPort(serverPort), WithTLSPolicy(NoTLS)) + if err != nil { + t.Fatalf("failed to create new client: %s", err) + } + sender := testSender{client} + + ctxDial := context.Background() + wg := sync.WaitGroup{} + for i := 0; i < 5; i++ { + wg.Add(1) + msg := testMessage(t) + go func() { + defer wg.Done() + if goroutineErr := sender.Send(ctxDial, msg); goroutineErr != nil { + t.Errorf("failed to send message: %s", goroutineErr) + } + }() + } + wg.Wait() + }) } func TestClient_auth(t *testing.T) { @@ -2281,6 +2376,11 @@ func TestClient_auth(t *testing.T) { name string authType SMTPAuthType }{ + {"LOGIN via AUTODISCOVER", SMTPAuthAutoDiscover}, + {"PLAIN via AUTODISCOVER", SMTPAuthAutoDiscover}, + {"SCRAM-SHA-1 via AUTODISCOVER", SMTPAuthAutoDiscover}, + {"SCRAM-SHA-256 via AUTODISCOVER", SMTPAuthAutoDiscover}, + {"XOAUTH2 via AUTODISCOVER", SMTPAuthAutoDiscover}, {"CRAM-MD5", SMTPAuthCramMD5}, {"LOGIN", SMTPAuthLogin}, {"LOGIN-NOENC", SMTPAuthLoginNoEnc}, @@ -2486,6 +2586,42 @@ func TestClient_auth(t *testing.T) { }) } +func TestClient_authTypeAutoDiscover(t *testing.T) { + tests := []struct { + supported string + tls bool + expect SMTPAuthType + shouldFail bool + }{ + {"LOGIN SCRAM-SHA-256 SCRAM-SHA-1 SCRAM-SHA-256-PLUS SCRAM-SHA-1-PLUS", true, SMTPAuthSCRAMSHA256PLUS, false}, + {"LOGIN SCRAM-SHA-256 SCRAM-SHA-1 SCRAM-SHA-256-PLUS SCRAM-SHA-1-PLUS", false, SMTPAuthSCRAMSHA256, false}, + {"LOGIN PLAIN SCRAM-SHA-1 SCRAM-SHA-1-PLUS", true, SMTPAuthSCRAMSHA1PLUS, false}, + {"LOGIN PLAIN SCRAM-SHA-1 SCRAM-SHA-1-PLUS", false, SMTPAuthSCRAMSHA1, false}, + {"LOGIN XOAUTH2 SCRAM-SHA-1-PLUS", false, SMTPAuthXOAUTH2, false}, + {"PLAIN LOGIN CRAM-MD5", false, SMTPAuthCramMD5, false}, + {"CRAM-MD5", false, SMTPAuthCramMD5, false}, + {"PLAIN", true, SMTPAuthPlain, false}, + {"LOGIN PLAIN", true, SMTPAuthPlain, false}, + {"LOGIN PLAIN", false, "no secure mechanism", true}, + {"", false, "supported list empty", true}, + } + for _, tt := range tests { + t.Run("AutoDiscover selects the strongest auth type: "+string(tt.expect), func(t *testing.T) { + client := &Client{smtpAuthType: SMTPAuthAutoDiscover} + authType, err := client.authTypeAutoDiscover(tt.supported, tt.tls) + if err != nil && !tt.shouldFail { + t.Fatalf("failed to auto discover auth type: %s", err) + } + if tt.shouldFail && err == nil { + t.Fatal("expected auto discover to fail") + } + if !tt.shouldFail && authType != tt.expect { + t.Errorf("expected strongest auth type: %s, got: %s", tt.expect, authType) + } + }) + } +} + func TestClient_Send(t *testing.T) { message := testMessage(t) t.Run("connect and send email", func(t *testing.T) { @@ -2606,6 +2742,82 @@ func TestClient_Send(t *testing.T) { }) } +func TestClient_DialToSMTPClientWithContext(t *testing.T) { + t.Run("establish a new client connection", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + ctxDial, cancelDial := context.WithTimeout(ctx, time.Millisecond*500) + t.Cleanup(cancelDial) + + client, err := NewClient(DefaultHost, WithPort(serverPort), WithTLSPolicy(NoTLS)) + if err != nil { + t.Fatalf("failed to create new client: %s", err) + } + smtpClient, err := client.DialToSMTPClientWithContext(ctxDial) + if err != nil { + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + t.Skip("failed to connect to the test server due to timeout") + } + t.Fatalf("failed to connect to test server: %s", err) + } + t.Cleanup(func() { + if err := client.CloseWithSMTPClient(smtpClient); err != nil { + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + t.Skip("failed to close the test server connection due to timeout") + } + t.Errorf("failed to close client: %s", err) + } + }) + if smtpClient == nil { + t.Fatal("expected SMTP client, got nil") + } + if !smtpClient.HasConnection() { + t.Fatal("expected connection on smtp client") + } + if ok, _ := smtpClient.Extension("DSN"); !ok { + t.Error("expected DSN extension but it was not found") + } + }) + t.Run("dial to SMTP server fails on first client writeFile", func(t *testing.T) { + var fake faker + fake.ReadWriter = struct { + io.Reader + io.Writer + }{ + failReadWriteSeekCloser{}, + failReadWriteSeekCloser{}, + } + + ctxDial, cancelDial := context.WithTimeout(context.Background(), time.Millisecond*500) + t.Cleanup(cancelDial) + + client, err := NewClient(DefaultHost, WithDialContextFunc(getFakeDialFunc(fake))) + if err != nil { + t.Fatalf("failed to create new client: %s", err) + } + _, err = client.DialToSMTPClientWithContext(ctxDial) + if err == nil { + t.Fatal("expected connection to fake to fail") + } + }) +} + func TestClient_sendSingleMsg(t *testing.T) { t.Run("connect and send email", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) @@ -2645,7 +2857,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err != nil { + if err = client.sendSingleMsg(client.smtpClient, message); err != nil { t.Errorf("failed to send message: %s", err) } }) @@ -2688,7 +2900,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err == nil { + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { t.Errorf("client should have failed to send message") } }) @@ -2733,7 +2945,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err == nil { + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { t.Errorf("client should have failed to send message") } var sendErr *SendError @@ -2783,7 +2995,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err == nil { + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { t.Errorf("client should have failed to send message") } var sendErr *SendError @@ -2794,7 +3006,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("expected ErrGetSender, got %s", sendErr.Reason) } }) - t.Run("fail with no recepient addresses", func(t *testing.T) { + t.Run("fail with no recipient addresses", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() PortAdder.Add(1) @@ -2833,7 +3045,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err == nil { + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { t.Errorf("client should have failed to send message") } var sendErr *SendError @@ -2883,7 +3095,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err != nil { + if err = client.sendSingleMsg(client.smtpClient, message); err != nil { t.Errorf("failed to send message: %s", err) } }) @@ -2926,7 +3138,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err == nil { + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { t.Errorf("client should have failed to send message") } var sendErr *SendError @@ -2977,7 +3189,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err == nil { + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { t.Errorf("client should have failed to send message") } var sendErr *SendError @@ -3028,7 +3240,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err == nil { + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { t.Errorf("client should have failed to send message") } var sendErr *SendError @@ -3078,7 +3290,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err == nil { + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { t.Errorf("client should have failed to send message") } var sendErr *SendError @@ -3128,7 +3340,7 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.sendSingleMsg(message); err == nil { + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { t.Errorf("client should have failed to send message") } var sendErr *SendError @@ -3139,6 +3351,59 @@ func TestClient_sendSingleMsg(t *testing.T) { t.Errorf("expected ErrSMTPDataClose, got %s", sendErr.Reason) } }) + t.Run("error code and enhanced status code support", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-ENHANCEDSTATUSCODES\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnMailFrom: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + message := testMessage(t) + + ctxDial, cancelDial := context.WithTimeout(ctx, time.Millisecond*500) + t.Cleanup(cancelDial) + + client, err := NewClient(DefaultHost, WithPort(serverPort), WithTLSPolicy(NoTLS)) + if err != nil { + t.Fatalf("failed to create new client: %s", err) + } + if err = client.DialWithContext(ctxDial); err != nil { + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + t.Skip("failed to connect to the test server due to timeout") + } + t.Fatalf("failed to connect to test server: %s", err) + } + t.Cleanup(func() { + if err := client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if err = client.sendSingleMsg(client.smtpClient, message); err == nil { + t.Error("expected mail delivery to fail") + } + var sendErr *SendError + if !errors.As(err, &sendErr) { + t.Fatalf("expected SendError, got %s", err) + } + if sendErr.errcode != 500 { + t.Errorf("expected error code 500, got %d", sendErr.errcode) + } + if !strings.EqualFold(sendErr.enhancedStatusCode, "5.5.2") { + t.Errorf("expected enhanced status code 5.5.2, got %s", sendErr.enhancedStatusCode) + } + }) } func TestClient_checkConn(t *testing.T) { @@ -3178,7 +3443,7 @@ func TestClient_checkConn(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.checkConn(); err != nil { + if err = client.checkConn(client.smtpClient); err != nil { t.Errorf("failed to check connection: %s", err) } }) @@ -3219,7 +3484,7 @@ func TestClient_checkConn(t *testing.T) { t.Errorf("failed to close client: %s", err) } }) - if err = client.checkConn(); err == nil { + if err = client.checkConn(client.smtpClient); err == nil { t.Errorf("client should have failed on connection check") } if !errors.Is(err, ErrNoActiveConnection) { @@ -3231,7 +3496,7 @@ func TestClient_checkConn(t *testing.T) { if err != nil { t.Fatalf("failed to create new client: %s", err) } - if err = client.checkConn(); err == nil { + if err = client.checkConn(client.smtpClient); err == nil { t.Errorf("client should have failed on connection check") } if !errors.Is(err, ErrNoActiveConnection) { @@ -3455,24 +3720,20 @@ func TestClient_XOAuth2OnFaker(t *testing.T) { } if err = c.DialWithContext(context.Background()); err == nil { t.Fatal("expected dial error got nil") - } else { - if !errors.Is(err, ErrXOauth2AuthNotSupported) { - t.Fatalf("expected %v; got %v", ErrXOauth2AuthNotSupported, err) - } + } + if !errors.Is(err, ErrXOauth2AuthNotSupported) { + t.Fatalf("expected %v; got %v", ErrXOauth2AuthNotSupported, err) } if err = c.Close(); err != nil { t.Fatalf("disconnect from test server failed: %v", err) } client := strings.Split(wrote.String(), "\r\n") - if len(client) != 3 { - t.Fatalf("unexpected number of client requests got %d; want 3", len(client)) + if len(client) != 2 { + t.Fatalf("unexpected number of client requests got %d; want 2", len(client)) } if !strings.HasPrefix(client[0], "EHLO") { t.Fatalf("expected EHLO, got %q", client[0]) } - if client[1] != "QUIT" { - t.Fatalf("expected QUIT, got %q", client[3]) - } }) } @@ -3496,6 +3757,17 @@ func (f faker) SetDeadline(time.Time) error { return nil } func (f faker) SetReadDeadline(time.Time) error { return nil } func (f faker) SetWriteDeadline(time.Time) error { return nil } +type testSender struct { + client *Client +} + +func (t *testSender) Send(ctx context.Context, m *Msg) error { + if err := t.client.DialAndSendWithContext(ctx, m); err != nil { + return fmt.Errorf("failed to dial and send mail: %w", err) + } + return nil +} + // parseJSONLog parses a JSON encoded log from the provided buffer and returns a slice of logLine structs. // In case of a decode error, it reports the error to the testing framework. func parseJSONLog(t *testing.T, buf *bytes.Buffer) logData { @@ -3548,6 +3820,8 @@ func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", " // serverProps represents the configuration properties for the SMTP server. type serverProps struct { + BufferMutex sync.RWMutex + EchoBuffer io.Writer FailOnAuth bool FailOnDataInit bool FailOnDataClose bool @@ -3637,6 +3911,13 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server if err != nil { t.Logf("failed to write line: %s", err) } + if props.EchoBuffer != nil { + props.BufferMutex.Lock() + if _, berr := props.EchoBuffer.Write([]byte(data + "\r\n")); berr != nil { + t.Errorf("failed write to echo buffer: %s", berr) + } + props.BufferMutex.Unlock() + } _ = writer.Flush() } writeOK := func() { @@ -3653,6 +3934,13 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server break } time.Sleep(time.Millisecond) + if props.EchoBuffer != nil { + props.BufferMutex.Lock() + if _, berr := props.EchoBuffer.Write([]byte(data)); berr != nil { + t.Errorf("failed write to echo buffer: %s", berr) + } + props.BufferMutex.Unlock() + } var datastring string data = strings.TrimSpace(data) @@ -3713,6 +4001,13 @@ func handleTestServerConnection(connection net.Conn, t *testing.T, props *server t.Logf("failed to read data from connection: %s", derr) break } + if props.EchoBuffer != nil { + props.BufferMutex.Lock() + if _, berr := props.EchoBuffer.Write([]byte(ddata)); berr != nil { + t.Errorf("failed write to echo buffer: %s", berr) + } + props.BufferMutex.Unlock() + } ddata = strings.TrimSpace(ddata) if ddata == "." { if props.FailOnDataClose { diff --git a/codecov.yml b/codecov.yml index a9f998e..7c99de7 100644 --- a/codecov.yml +++ b/codecov.yml @@ -1,19 +1,19 @@ # SPDX-FileCopyrightText: 2022-2023 The go-mail Authors # -# SPDX-License-Identifier: CC0-1.0 +# SPDX-License-Identifier: MIT coverage: status: project: default: - target: 90% + target: 95% threshold: 2% base: auto if_ci_failed: error only_pulls: false patch: default: - target: 90% + target: 95% base: auto if_ci_failed: error threshold: 2% diff --git a/doc.go b/doc.go index 88a12ee..6878d86 100644 --- a/doc.go +++ b/doc.go @@ -11,4 +11,4 @@ package mail // VERSION indicates the current version of the package. It is also attached to the default user // agent string. -const VERSION = "0.5.1" +const VERSION = "0.5.2" diff --git a/go.mod b/go.mod index 84fc4bd..b98e082 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,6 @@ module github.com/wneessen/go-mail go 1.16 require ( - golang.org/x/crypto v0.28.0 - golang.org/x/text v0.19.0 -) \ No newline at end of file + golang.org/x/crypto v0.30.0 + golang.org/x/text v0.21.0 +) diff --git a/go.sum b/go.sum index 2eb65e0..c80524b 100644 --- a/go.sum +++ b/go.sum @@ -5,8 +5,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= -golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= -golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= +golang.org/x/crypto v0.30.0 h1:RwoQn3GkWiMkzlX562cLB7OxWvjH1L8xutO2WoJcRoY= +golang.org/x/crypto v0.30.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= @@ -26,7 +26,7 @@ golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -37,7 +37,7 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -46,7 +46,7 @@ golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= -golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M= +golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= @@ -55,8 +55,8 @@ golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= -golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= diff --git a/msg.go b/msg.go index 36141c3..d2583f1 100644 --- a/msg.go +++ b/msg.go @@ -16,6 +16,7 @@ import ( "fmt" ht "html/template" "io" + "io/fs" "mime" "net/mail" "os" @@ -2033,9 +2034,28 @@ func (m *Msg) AttachTextTemplate( // - https://datatracker.ietf.org/doc/html/rfc2183 func (m *Msg) AttachFromEmbedFS(name string, fs *embed.FS, opts ...FileOption) error { if fs == nil { - return fmt.Errorf("embed.FS must not be nil") + return errors.New("embed.FS must not be nil") } - file, err := fileFromEmbedFS(name, fs) + return m.AttachFromIOFS(name, *fs, opts...) +} + +// AttachFromIOFS attaches a file from a generic file system to the message. +// +// This function retrieves a file by name from an fs.FS instance, processes it, and appends it to the +// message's attachment collection. Additional file options can be provided for further customization. +// +// Parameters: +// - name: The name of the file to retrieve from the file system. +// - iofs: The file system (must not be nil). +// - opts: Optional file options to customize the attachment process. +// +// Returns: +// - An error if the file cannot be retrieved, the fs.FS is nil, or any other issue occurs. +func (m *Msg) AttachFromIOFS(name string, iofs fs.FS, opts ...FileOption) error { + if iofs == nil { + return errors.New("fs.FS must not be nil") + } + file, err := fileFromIOFS(name, iofs) if err != nil { return err } @@ -2179,9 +2199,28 @@ func (m *Msg) EmbedTextTemplate( // - https://datatracker.ietf.org/doc/html/rfc2183 func (m *Msg) EmbedFromEmbedFS(name string, fs *embed.FS, opts ...FileOption) error { if fs == nil { - return fmt.Errorf("embed.FS must not be nil") + return errors.New("embed.FS must not be nil") } - file, err := fileFromEmbedFS(name, fs) + return m.EmbedFromIOFS(name, *fs, opts...) +} + +// EmbedFromIOFS embeds a file from a generic file system into the message. +// +// This function retrieves a file by name from an fs.FS instance, processes it, and appends it to the +// message's embed collection. Additional file options can be provided for further customization. +// +// Parameters: +// - name: The name of the file to retrieve from the file system. +// - iofs: The file system (must not be nil). +// - opts: Optional file options to customize the embedding process. +// +// Returns: +// - An error if the file cannot be retrieved, the fs.FS is nil, or any other issue occurs. +func (m *Msg) EmbedFromIOFS(name string, iofs fs.FS, opts ...FileOption) error { + if iofs == nil { + return errors.New("fs.FS must not be nil") + } + file, err := fileFromIOFS(name, iofs) if err != nil { return err } @@ -2788,15 +2827,15 @@ func (m *Msg) addDefaultHeader() { m.SetGenHeader(HeaderMIMEVersion, string(m.mimever)) } -// fileFromEmbedFS returns a File pointer from a given file in the provided embed.FS. +// fileFromIOFS returns a File pointer from a given file in the provided fs.FS. // -// This method retrieves a file from the embedded filesystem (embed.FS) and returns a File structure +// This method retrieves a file from the provided io/fs (fs.FS) and returns a File structure // that can be used as an attachment or embed in the email message. The file's content is read when // writing to an io.Writer, and the file is identified by its base name. // // Parameters: // - name: The name of the file to retrieve from the embedded filesystem. -// - fs: A pointer to the embed.FS from which the file will be opened. +// - fs: An instance that satisfies the fs.FS interface // // Returns: // - A pointer to the File structure representing the embedded file. @@ -2804,23 +2843,27 @@ func (m *Msg) addDefaultHeader() { // // References: // - https://datatracker.ietf.org/doc/html/rfc2183 -func fileFromEmbedFS(name string, fs *embed.FS) (*File, error) { - _, err := fs.Open(name) +func fileFromIOFS(name string, iofs fs.FS) (*File, error) { + if iofs == nil { + return nil, errors.New("fs.FS is nil") + } + + _, err := iofs.Open(name) if err != nil { - return nil, fmt.Errorf("failed to open file from embed.FS: %w", err) + return nil, fmt.Errorf("failed to open file from fs.FS: %w", err) } return &File{ Name: filepath.Base(name), Header: make(map[string][]string), Writer: func(writer io.Writer) (int64, error) { - file, err := fs.Open(name) - if err != nil { - return 0, err + file, ferr := iofs.Open(name) + if ferr != nil { + return 0, fmt.Errorf("failed to open file from fs.FS: %w", ferr) } - numBytes, err := io.Copy(writer, file) - if err != nil { + numBytes, ferr := io.Copy(writer, file) + if ferr != nil { _ = file.Close() - return numBytes, fmt.Errorf("failed to copy file to io.Writer: %w", err) + return numBytes, fmt.Errorf("failed to copy file from fs.FS to io.Writer: %w", ferr) } return numBytes, file.Close() }, diff --git a/msg_test.go b/msg_test.go index 563f217..c1bce59 100644 --- a/msg_test.go +++ b/msg_test.go @@ -126,8 +126,8 @@ var ( {`" "@domain.tld`, true}, // Still valid, since quoted {`"<\"@\".!#%$@domain.tld"`, false}, // Quoting with illegal characters is not allowed {`<\"@\\".!#%$@domain.tld`, false}, // Still a bunch of random illegal characters - {`hi"@"there@domain.tld`, false}, // Quotes must be dot-seperated - {`"<\"@\\".!.#%$@domain.tld`, false}, // Quote is escaped and dot-seperated which would be RFC822 compliant, but not RFC5322 compliant + {`hi"@"there@domain.tld`, false}, // Quotes must be dot-separated + {`"<\"@\\".!.#%$@domain.tld`, false}, // Quote is escaped and dot-separated which would be RFC822 compliant, but not RFC5322 compliant {`hi\ there@domain.tld`, false}, // Spaces must be quoted {"hello@tld", true}, // TLD is enough {`äŊ åĨŊ@域名.éĄļįē§åŸŸå`, true}, // We speak RFC6532 @@ -4527,12 +4527,12 @@ func TestMsg_AttachFile(t *testing.T) { t.Errorf("expected message body to be %s, got: %s", "This is a test attachment", got) } }) - t.Run("AttachFile with non-existant file", func(t *testing.T) { + t.Run("AttachFile with non-existent file", func(t *testing.T) { message := NewMsg() if message == nil { t.Fatal("message is nil") } - message.AttachFile("testdata/non-existant-file.txt") + message.AttachFile("testdata/non-existent-file.txt") attachments := message.GetAttachments() if len(attachments) != 0 { t.Fatalf("failed to retrieve attachments list") @@ -4970,6 +4970,75 @@ func TestMsg_AttachFromEmbedFS(t *testing.T) { }) } +func TestMsg_AttachFromIOFS(t *testing.T) { + t.Run("AttachFromIOFS successful", func(t *testing.T) { + message := NewMsg() + if message == nil { + t.Fatal("message is nil") + } + if err := message.AttachFromIOFS("testdata/attachment.txt", efs, + WithFileName("attachment.txt")); err != nil { + t.Fatalf("failed to attach from embed FS: %s", err) + } + attachments := message.GetAttachments() + if len(attachments) != 1 { + t.Fatalf("failed to retrieve attachments list") + } + if attachments[0] == nil { + t.Fatal("expected attachment to be not nil") + } + if attachments[0].Name != "attachment.txt" { + t.Errorf("expected attachment name to be %s, got: %s", "attachment.txt", attachments[0].Name) + } + messageBuf := bytes.NewBuffer(nil) + _, err := attachments[0].Writer(messageBuf) + if err != nil { + t.Errorf("writer func failed: %s", err) + } + got := strings.TrimSpace(messageBuf.String()) + if !strings.EqualFold(got, "This is a test attachment") { + t.Errorf("expected message body to be %s, got: %s", "This is a test attachment", got) + } + }) + t.Run("AttachFromIOFS with invalid path", func(t *testing.T) { + message := NewMsg() + if message == nil { + t.Fatal("message is nil") + } + err := message.AttachFromIOFS("testdata/invalid.txt", efs, WithFileName("attachment.txt")) + if err == nil { + t.Fatal("expected error, got nil") + } + }) + t.Run("AttachFromIOFS with nil embed FS", func(t *testing.T) { + message := NewMsg() + if message == nil { + t.Fatal("message is nil") + } + err := message.AttachFromIOFS("testdata/invalid.txt", nil, WithFileName("attachment.txt")) + if err == nil { + t.Fatal("expected error, got nil") + } + }) + t.Run("AttachFromIOFS with fs.FS fails on copy", func(t *testing.T) { + message := NewMsg() + if message == nil { + t.Fatal("message is nil") + } + if err := message.AttachFromIOFS("testdata/attachment.txt", efs); err != nil { + t.Fatalf("failed to attach file from fs.FS: %s", err) + } + attachments := message.GetAttachments() + if len(attachments) != 1 { + t.Fatalf("failed to get attachments, expected 1, got: %d", len(attachments)) + } + _, err := attachments[0].Writer(failReadWriteSeekCloser{}) + if err == nil { + t.Error("writer func expected to fail, but didn't") + } + }) +} + func TestMsg_EmbedFile(t *testing.T) { t.Run("EmbedFile with file", func(t *testing.T) { message := NewMsg() @@ -4997,12 +5066,12 @@ func TestMsg_EmbedFile(t *testing.T) { t.Errorf("expected message body to be %s, got: %s", "This is a test embed", got) } }) - t.Run("EmbedFile with non-existant file", func(t *testing.T) { + t.Run("EmbedFile with non-existent file", func(t *testing.T) { message := NewMsg() if message == nil { t.Fatal("message is nil") } - message.EmbedFile("testdata/non-existant-file.txt") + message.EmbedFile("testdata/non-existent-file.txt") embeds := message.GetEmbeds() if len(embeds) != 0 { t.Fatalf("failed to retrieve attachments list") @@ -5435,6 +5504,58 @@ func TestMsg_EmbedFromEmbedFS(t *testing.T) { }) } +func TestMsg_EmbedFromIOFS(t *testing.T) { + t.Run("EmbedFromIOFS successful", func(t *testing.T) { + message := NewMsg() + if message == nil { + t.Fatal("message is nil") + } + if err := message.EmbedFromIOFS("testdata/embed.txt", efs, + WithFileName("embed.txt")); err != nil { + t.Fatalf("failed to embed from embed FS: %s", err) + } + embeds := message.GetEmbeds() + if len(embeds) != 1 { + t.Fatalf("failed to retrieve embeds list") + } + if embeds[0] == nil { + t.Fatal("expected embed to be not nil") + } + if embeds[0].Name != "embed.txt" { + t.Errorf("expected embed name to be %s, got: %s", "embed.txt", embeds[0].Name) + } + messageBuf := bytes.NewBuffer(nil) + _, err := embeds[0].Writer(messageBuf) + if err != nil { + t.Errorf("writer func failed: %s", err) + } + got := strings.TrimSpace(messageBuf.String()) + if !strings.EqualFold(got, "This is a test embed") { + t.Errorf("expected message body to be %s, got: %s", "This is a test embed", got) + } + }) + t.Run("EmbedFromIOFS with invalid path", func(t *testing.T) { + message := NewMsg() + if message == nil { + t.Fatal("message is nil") + } + err := message.EmbedFromIOFS("testdata/invalid.txt", efs, WithFileName("embed.txt")) + if err == nil { + t.Fatal("expected error, got nil") + } + }) + t.Run("EmbedFromIOFS with nil embed FS", func(t *testing.T) { + message := NewMsg() + if message == nil { + t.Fatal("message is nil") + } + err := message.EmbedFromIOFS("testdata/invalid.txt", nil, WithFileName("embed.txt")) + if err == nil { + t.Fatal("expected error, got nil") + } + }) +} + func TestMsg_Reset(t *testing.T) { message := NewMsg() if message == nil { @@ -6537,6 +6658,15 @@ func TestMsg_addDefaultHeader(t *testing.T) { }) } +func TestMsg_fileFromIOFS(t *testing.T) { + t.Run("file from fs.FS where fs is nil ", func(t *testing.T) { + _, err := fileFromIOFS("testfile.txt", nil) + if err == nil { + t.Fatal("expected error for fs.FS that is nil") + } + }) +} + // TestSignWithSMime_ValidRSAKeyPair tests WithSMimeSinging with given rsa key pair func TestSignWithSMime_ValidRSAKeyPair(t *testing.T) { privateKey, certificate, intermediateCertificate, err := getDummyRSACryptoMaterial() diff --git a/msgwriter.go b/msgwriter.go index 331138c..9245df4 100644 --- a/msgwriter.go +++ b/msgwriter.go @@ -281,7 +281,7 @@ func (mw *msgWriter) addFiles(files []*File, isAttachment bool) { mimeType = string(file.ContentType) } file.setHeader(HeaderContentType, fmt.Sprintf(`%s; name="%s"`, mimeType, - mw.encoder.Encode(mw.charset.String(), file.Name))) + mw.encoder.Encode(mw.charset.String(), sanitizeFilename(file.Name)))) } if _, ok := file.getHeader(HeaderContentTransferEnc); !ok { @@ -293,7 +293,7 @@ func (mw *msgWriter) addFiles(files []*File, isAttachment bool) { if file.Desc != "" { if _, ok := file.getHeader(HeaderContentDescription); !ok { - file.setHeader(HeaderContentDescription, file.Desc) + file.setHeader(HeaderContentDescription, mw.encoder.Encode(mw.charset.String(), file.Desc)) } } @@ -303,12 +303,12 @@ func (mw *msgWriter) addFiles(files []*File, isAttachment bool) { disposition = "attachment" } file.setHeader(HeaderContentDisposition, fmt.Sprintf(`%s; filename="%s"`, - disposition, mw.encoder.Encode(mw.charset.String(), file.Name))) + disposition, mw.encoder.Encode(mw.charset.String(), sanitizeFilename(file.Name)))) } if !isAttachment { if _, ok := file.getHeader(HeaderContentID); !ok { - file.setHeader(HeaderContentID, fmt.Sprintf("<%s>", file.Name)) + file.setHeader(HeaderContentID, fmt.Sprintf("<%s>", sanitizeFilename(file.Name))) } } if mw.depth == 0 { @@ -511,3 +511,33 @@ func (mw *msgWriter) writeBody(writeFunc func(io.Writer) (int64, error), encodin mw.bytesWritten += n } } + +// sanitizeFilename sanitizes a given filename string by replacing specific unwanted characters with +// an underscore ('_'). +// +// This method replaces any control character and any special character that is problematic for +// MIME headers and file systems with an underscore ('_') character. +// +// The following characters are replaced +// - Any control character (US-ASCII < 32) +// - ", /, :, <, >, ?, \, |, [DEL] +// +// Parameters: +// - input: A string of a filename that is supposed to be sanitized +// +// Returns: +// - A string representing the sanitized version of the filename +func sanitizeFilename(input string) string { + var sanitized strings.Builder + for i := 0; i < len(input); i++ { + // We do not allow control characters in file names. + if input[i] < 32 || input[i] == 34 || input[i] == 47 || input[i] == 58 || + input[i] == 60 || input[i] == 62 || input[i] == 63 || input[i] == 92 || + input[i] == 124 || input[i] == 127 { + sanitized.WriteRune('_') + continue + } + sanitized.WriteByte(input[i]) + } + return sanitized.String() +} diff --git a/msgwriter_test.go b/msgwriter_test.go index 139454c..09bfa55 100644 --- a/msgwriter_test.go +++ b/msgwriter_test.go @@ -304,6 +304,65 @@ func TestMsgWriter_addFiles(t *testing.T) { charset: CharsetUTF8, encoder: getEncoder(EncodingQP), } + tests := []struct { + name string + filename string + expect string + }{ + {"normal US-ASCII filename", "test.txt", "test.txt"}, + {"normal US-ASCII filename with space", "test file.txt", "test file.txt"}, + {"filename with new lines", "test\r\n.txt", "test__.txt"}, + {"filename with disallowed character:\x22", "test\x22.txt", "test_.txt"}, + {"filename with disallowed character:\x2f", "test\x2f.txt", "test_.txt"}, + {"filename with disallowed character:\x3a", "test\x3a.txt", "test_.txt"}, + {"filename with disallowed character:\x3c", "test\x3c.txt", "test_.txt"}, + {"filename with disallowed character:\x3e", "test\x3e.txt", "test_.txt"}, + {"filename with disallowed character:\x3f", "test\x3f.txt", "test_.txt"}, + {"filename with disallowed character:\x5c", "test\x5c.txt", "test_.txt"}, + {"filename with disallowed character:\x7c", "test\x7c.txt", "test_.txt"}, + {"filename with disallowed character:\x7f", "test\x7f.txt", "test_.txt"}, + { + "japanese characters filename", "æˇģäģ˜ãƒ•ã‚Ąã‚¤ãƒĢ.txt", + "=?UTF-8?q?=E6=B7=BB=E4=BB=98=E3=83=95=E3=82=A1=E3=82=A4=E3=83=AB.txt?=", + }, + { + "simplified chinese characters filename", "æĩ‹č¯•é™„äģļ文äģļ.txt", + "=?UTF-8?q?=E6=B5=8B=E8=AF=95=E9=99=84=E4=BB=B6=E6=96=87=E4=BB=B6.txt?=", + }, + { + "cyrillic characters filename", "ĐĸĐĩŅŅ‚ОвŅ‹Đš ĐŋŅ€Đ¸ĐēŅ€ĐĩĐŋĐģĐĩĐŊĐŊŅ‹Đš Ņ„Đ°ĐšĐģ.txt", + "=?UTF-8?q?=D0=A2=D0=B5=D1=81=D1=82=D0=BE=D0=B2=D1=8B=D0=B9_=D0=BF=D1=80?= " + + "=?UTF-8?q?=D0=B8=D0=BA=D1=80=D0=B5=D0=BF=D0=BB=D0=B5=D0=BD=D0=BD=D1=8B?= " + + "=?UTF-8?q?=D0=B9_=D1=84=D0=B0=D0=B9=D0=BB.txt?=", + }, + } + for _, tt := range tests { + t.Run("addFile with filename sanitization: "+tt.name, func(t *testing.T) { + buffer := bytes.NewBuffer(nil) + msgwriter.writer = buffer + message := testMessage(t) + message.AttachFile("testdata/attachment.txt", WithFileName(tt.filename)) + msgwriter.writeMsg(message) + if msgwriter.err != nil { + t.Errorf("msgWriter failed to write: %s", msgwriter.err) + } + + var ctExpect string + cdExpect := fmt.Sprintf(`Content-Disposition: attachment; filename="%s"`, tt.expect) + switch runtime.GOOS { + case "freebsd": + ctExpect = fmt.Sprintf(`Content-Type: application/octet-stream; name="%s"`, tt.expect) + default: + ctExpect = fmt.Sprintf(`Content-Type: text/plain; charset=utf-8; name="%s"`, tt.expect) + } + if !strings.Contains(buffer.String(), ctExpect) { + t.Errorf("expected content-type: %q, got: %q", ctExpect, buffer.String()) + } + if !strings.Contains(buffer.String(), cdExpect) { + t.Errorf("expected content-disposition: %q, got: %q", cdExpect, buffer.String()) + } + }) + } t.Run("message with a single file attached", func(t *testing.T) { buffer := bytes.NewBuffer(nil) msgwriter.writer = buffer @@ -324,7 +383,7 @@ func TestMsgWriter_addFiles(t *testing.T) { } } if !strings.Contains(buffer.String(), `Content-Disposition: attachment; filename="attachment.txt"`) { - t.Errorf("Content-Dispositon header not found for attachment. Mail: %s", buffer.String()) + t.Errorf("Content-Disposition header not found for attachment. Mail: %s", buffer.String()) } switch runtime.GOOS { case "freebsd": @@ -357,7 +416,7 @@ func TestMsgWriter_addFiles(t *testing.T) { } } if !strings.Contains(buffer.String(), `Content-Disposition: attachment; filename="attachment"`) { - t.Errorf("Content-Dispositon header not found for attachment. Mail: %s", buffer.String()) + t.Errorf("Content-Disposition header not found for attachment. Mail: %s", buffer.String()) } if !strings.Contains(buffer.String(), `Content-Type: application/octet-stream; name="attachment"`) { t.Errorf("Content-Type header not found for attachment. Mail: %s", buffer.String()) @@ -383,7 +442,7 @@ func TestMsgWriter_addFiles(t *testing.T) { } } if !strings.Contains(buffer.String(), `Content-Disposition: attachment; filename="attachment.txt"`) { - t.Errorf("Content-Dispositon header not found for attachment. Mail: %s", buffer.String()) + t.Errorf("Content-Disposition header not found for attachment. Mail: %s", buffer.String()) } if !strings.Contains(buffer.String(), `Content-Type: application/octet-stream; name="attachment.txt"`) { t.Errorf("Content-Type header not found for attachment. Mail: %s", buffer.String()) @@ -402,7 +461,7 @@ func TestMsgWriter_addFiles(t *testing.T) { t.Errorf("attachment not found in mail message. Mail: %s", buffer.String()) } if !strings.Contains(buffer.String(), `Content-Disposition: attachment; filename="attachment.txt"`) { - t.Errorf("Content-Dispositon header not found for attachment. Mail: %s", buffer.String()) + t.Errorf("Content-Disposition header not found for attachment. Mail: %s", buffer.String()) } switch runtime.GOOS { case "freebsd": @@ -438,7 +497,7 @@ func TestMsgWriter_addFiles(t *testing.T) { } } if !strings.Contains(buffer.String(), `Content-Disposition: attachment; filename="attachment.txt"`) { - t.Errorf("Content-Dispositon header not found for attachment. Mail: %s", buffer.String()) + t.Errorf("Content-Disposition header not found for attachment. Mail: %s", buffer.String()) } switch runtime.GOOS { case "freebsd": @@ -478,7 +537,7 @@ func TestMsgWriter_addFiles(t *testing.T) { } } if !strings.Contains(buffer.String(), `Content-Disposition: attachment; filename="attachment.txt"`) { - t.Errorf("Content-Dispositon header not found for attachment. Mail: %s", buffer.String()) + t.Errorf("Content-Disposition header not found for attachment. Mail: %s", buffer.String()) } switch runtime.GOOS { case "freebsd": @@ -620,7 +679,7 @@ func TestMsgWriter_writeBody(t *testing.T) { buffer := bytes.NewBuffer(nil) msgwriter.writer = buffer message := testMessage(t) - msgwriter.writeBody(message.parts[0].writeFunc, NoEncoding, false) + msgwriter.writeBody(message.parts[0].writeFunc, NoEncoding) if msgwriter.err != nil { t.Errorf("writeBody failed to write: %s", msgwriter.err) } @@ -628,7 +687,7 @@ func TestMsgWriter_writeBody(t *testing.T) { t.Run("writeBody on NoEncoding fails on write", func(t *testing.T) { msgwriter.writer = failReadWriteSeekCloser{} message := testMessage(t) - msgwriter.writeBody(message.parts[0].writeFunc, NoEncoding, false) + msgwriter.writeBody(message.parts[0].writeFunc, NoEncoding) if msgwriter.err == nil { t.Errorf("writeBody succeeded, expected error") } @@ -642,7 +701,7 @@ func TestMsgWriter_writeBody(t *testing.T) { writeFunc := func(io.Writer) (int64, error) { return 0, errors.New("intentional write failure") } - msgwriter.writeBody(writeFunc, NoEncoding, false) + msgwriter.writeBody(writeFunc, NoEncoding) if msgwriter.err == nil { t.Errorf("writeBody succeeded, expected error") } @@ -653,7 +712,7 @@ func TestMsgWriter_writeBody(t *testing.T) { t.Run("writeBody Quoted-Printable fails on write", func(t *testing.T) { msgwriter.writer = failReadWriteSeekCloser{} message := testMessage(t) - msgwriter.writeBody(message.parts[0].writeFunc, EncodingQP, false) + msgwriter.writeBody(message.parts[0].writeFunc, EncodingQP) if msgwriter.err == nil { t.Errorf("writeBody succeeded, expected error") } @@ -677,6 +736,36 @@ func TestMsgWriter_writeBody(t *testing.T) { }) } +func TestMsgWriter_sanitizeFilename(t *testing.T) { + tests := []struct { + given string + want string + }{ + {"test.txt", "test.txt"}, + {"test file.txt", "test file.txt"}, + {"test\\ file.txt", "test_ file.txt"}, + {`"test" file.txt`, "_test_ file.txt"}, + {`test file .txt`, "test_file_.txt"}, + {"test\r\nfile.txt", "test__file.txt"}, + {"test\x22file.txt", "test_file.txt"}, + {"test\x2ffile.txt", "test_file.txt"}, + {"test\x3afile.txt", "test_file.txt"}, + {"test\x3cfile.txt", "test_file.txt"}, + {"test\x3efile.txt", "test_file.txt"}, + {"test\x3ffile.txt", "test_file.txt"}, + {"test\x5cfile.txt", "test_file.txt"}, + {"test\x7cfile.txt", "test_file.txt"}, + {"test\x7ffile.txt", "test_file.txt"}, + } + for _, tt := range tests { + t.Run(tt.given+"=>"+tt.want, func(t *testing.T) { + if got := sanitizeFilename(tt.given); got != tt.want { + t.Errorf("sanitizeFilename failed, expected: %q, got: %q", tt.want, got) + } + }) + } +} + // TestMsgWriter_writeMsg_SMime tests the writeMsg method of the msgWriter with S/MIME types set func TestMsgWriter_writeMsg_SMime(t *testing.T) { privateKey, certificate, intermediateCertificate, err := getDummyRSACryptoMaterial() @@ -715,3 +804,5 @@ func TestMsgWriter_writeMsg_SMime(t *testing.T) { t.Errorf("writeMsg failed. Unable to find Content-Type") } } + + diff --git a/quicksend.go b/quicksend.go new file mode 100644 index 0000000..204971f --- /dev/null +++ b/quicksend.go @@ -0,0 +1,111 @@ +// SPDX-FileCopyrightText: 2024 The go-mail Authors +// +// SPDX-License-Identifier: MIT + +package mail + +import ( + "bytes" + "crypto/tls" + "fmt" + "net" + "strconv" +) + +type AuthData struct { + Auth bool + Username string + Password string +} + +var testHookTLSConfig func() *tls.Config // nil, except for tests + +// QuickSend is an all-in-one method for quickly sending simple text mails in go-mail. +// +// This method will create a new client that connects to the server at addr, switches to TLS if possible, +// authenticates with the optional AuthData provided in auth and create a new simple Msg with the provided +// subject string and message bytes as body. The message will be sent using from as sender address and will +// be delivered to every address in rcpts. QuickSend will always send as text/plain ContentType. +// +// For the SMTP authentication, if auth is not nil and AuthData.Auth is set to true, it will try to +// autodiscover the best SMTP authentication mechanism supported by the server. If auth is set to true +// but autodiscover is not able to find a suitable authentication mechanism or if the authentication +// fails, the mail delivery will fail completely. +// +// The content parameter should be an RFC 822-style email body. The lines of content should be CRLF terminated. +// +// Parameters: +// - addr: The hostname and port of the mail server, it must include a port, as in "mail.example.com:smtp". +// - auth: A AuthData pointer. If nil or if AuthData.Auth is set to false, not SMTP authentication will be performed. +// - from: The from address of the sender as string. +// - rcpts: A slice of strings of receipient addresses. +// - subject: The subject line as string. +// - content: A byte slice of the mail content +// +// Returns: +// - A pointer to the generated Msg. +// - An error if any step in the process of mail generation or delivery failed. +func QuickSend(addr string, auth *AuthData, from string, rcpts []string, subject string, content []byte) (*Msg, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, fmt.Errorf("failed to split host and port from address: %w", err) + } + portnum, err := strconv.Atoi(port) + if err != nil { + return nil, fmt.Errorf("failed to convert port to int: %w", err) + } + client, err := NewClient(host, WithPort(portnum), WithTLSPolicy(TLSOpportunistic)) + if err != nil { + return nil, fmt.Errorf("failed to create new client: %w", err) + } + + if auth != nil && auth.Auth { + client.SetSMTPAuth(SMTPAuthAutoDiscover) + client.SetUsername(auth.Username) + client.SetPassword(auth.Password) + } + + tlsConfig := client.tlsconfig + if testHookTLSConfig != nil { + tlsConfig = testHookTLSConfig() + } + if err = client.SetTLSConfig(tlsConfig); err != nil { + return nil, fmt.Errorf("failed to set TLS config: %w", err) + } + + message := NewMsg() + if err = message.From(from); err != nil { + return nil, fmt.Errorf("failed to set MAIL FROM address: %w", err) + } + if err = message.To(rcpts...); err != nil { + return nil, fmt.Errorf("failed to set RCPT TO address: %w", err) + } + message.Subject(subject) + buffer := bytes.NewBuffer(content) + writeFunc := writeFuncFromBuffer(buffer) + message.SetBodyWriter(TypeTextPlain, writeFunc) + + if err = client.DialAndSend(message); err != nil { + return nil, fmt.Errorf("failed to dial and send message: %w", err) + } + return message, nil +} + +// NewAuthData creates a new AuthData instance with the provided username and password. +// +// This function initializes an AuthData struct with authentication enabled and sets the +// username and password fields. +// +// Parameters: +// - user: The username for authentication. +// - pass: The password for authentication. +// +// Returns: +// - A pointer to the initialized AuthData instance. +func NewAuthData(user, pass string) *AuthData { + return &AuthData{ + Auth: true, + Username: user, + Password: pass, + } +} diff --git a/quicksend_test.go b/quicksend_test.go new file mode 100644 index 0000000..497e2a0 --- /dev/null +++ b/quicksend_test.go @@ -0,0 +1,368 @@ +// SPDX-FileCopyrightText: 2024 The go-mail Authors +// +// SPDX-License-Identifier: MIT + +package mail + +import ( + "bytes" + "context" + "crypto/tls" + "fmt" + "strings" + "testing" + "time" +) + +func TestNewAuthData(t *testing.T) { + t.Run("AuthData with username and password", func(t *testing.T) { + auth := NewAuthData("username", "password") + if !auth.Auth { + t.Fatal("expected auth to be true") + } + if auth.Username != "username" { + t.Fatalf("expected username to be %s, got %s", "username", auth.Username) + } + if auth.Password != "password" { + t.Fatalf("expected password to be %s, got %s", "password", auth.Password) + } + }) + t.Run("AuthData with username and empty password", func(t *testing.T) { + auth := NewAuthData("username", "") + if !auth.Auth { + t.Fatal("expected auth to be true") + } + if auth.Username != "username" { + t.Fatalf("expected username to be %s, got %s", "username", auth.Username) + } + if auth.Password != "" { + t.Fatalf("expected password to be %s, got %s", "", auth.Password) + } + }) + t.Run("AuthData with empty username and set password", func(t *testing.T) { + auth := NewAuthData("", "password") + if !auth.Auth { + t.Fatal("expected auth to be true") + } + if auth.Username != "" { + t.Fatalf("expected username to be %s, got %s", "", auth.Username) + } + if auth.Password != "password" { + t.Fatalf("expected password to be %s, got %s", "password", auth.Password) + } + }) + t.Run("AuthData with empty data", func(t *testing.T) { + auth := NewAuthData("", "") + if !auth.Auth { + t.Fatal("expected auth to be true") + } + if auth.Username != "" { + t.Fatalf("expected username to be %s, got %s", "", auth.Username) + } + if auth.Password != "" { + t.Fatalf("expected password to be %s, got %s", "", auth.Password) + } + }) +} + +func TestQuickSend(t *testing.T) { + subject := "This is a test subject" + body := []byte("This is a test body\r\nWith multiple lines\r\n\r\nBest,\r\n The go-mail team") + sender := TestSenderValid + rcpts := []string{TestRcptValid} + t.Run("QuickSend with authentication and TLS", func(t *testing.T) { + ctxAuth, cancelAuth := context.WithCancel(context.Background()) + defer cancelAuth() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250-STARTTLS\r\n250 SMTPUTF8" + echoBuffer := bytes.NewBuffer(nil) + props := &serverProps{ + EchoBuffer: echoBuffer, + FeatureSet: featureSet, + ListenPort: serverPort, + } + go func() { + if err := simpleSMTPServer(ctxAuth, t, props); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + addr := TestServerAddr + ":" + fmt.Sprint(serverPort) + testHookTLSConfig = func() *tls.Config { return &tls.Config{InsecureSkipVerify: true} } + + _, err := QuickSend(addr, NewAuthData("username", "password"), sender, rcpts, subject, body) + if err != nil { + t.Fatalf("failed to send email: %s", err) + } + + props.BufferMutex.RLock() + resp := strings.Split(echoBuffer.String(), "\r\n") + props.BufferMutex.RUnlock() + + expects := []struct { + line int + data string + }{ + {8, "STARTTLS"}, + {17, "AUTH PLAIN AHVzZXJuYW1lAHBhc3N3b3Jk"}, + {21, "MAIL FROM: BODY=8BITMIME SMTPUTF8"}, + {23, "RCPT TO:"}, + {30, "Subject: " + subject}, + {33, "From: "}, + {34, "To: "}, + {35, "Content-Type: text/plain; charset=UTF-8"}, + {36, "Content-Transfer-Encoding: quoted-printable"}, + {38, "This is a test body"}, + {39, "With multiple lines"}, + {40, ""}, + {41, "Best,"}, + {42, " The go-mail team"}, + } + for _, expect := range expects { + if !strings.EqualFold(resp[expect.line], expect.data) { + t.Errorf("expected %q at line %d, got: %q", expect.data, expect.line, resp[expect.line]) + } + } + }) + t.Run("QuickSend with authentication and TLS and multiple receipients", func(t *testing.T) { + ctxAuth, cancelAuth := context.WithCancel(context.Background()) + defer cancelAuth() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250-STARTTLS\r\n250 SMTPUTF8" + echoBuffer := bytes.NewBuffer(nil) + props := &serverProps{ + EchoBuffer: echoBuffer, + FeatureSet: featureSet, + ListenPort: serverPort, + } + go func() { + if err := simpleSMTPServer(ctxAuth, t, props); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + addr := TestServerAddr + ":" + fmt.Sprint(serverPort) + testHookTLSConfig = func() *tls.Config { return &tls.Config{InsecureSkipVerify: true} } + + multiRcpts := []string{TestRcptValid, TestRcptValid, TestRcptValid} + _, err := QuickSend(addr, NewAuthData("username", "password"), sender, multiRcpts, subject, body) + if err != nil { + t.Fatalf("failed to send email: %s", err) + } + + props.BufferMutex.RLock() + resp := strings.Split(echoBuffer.String(), "\r\n") + props.BufferMutex.RUnlock() + + expects := []struct { + line int + data string + }{ + {8, "STARTTLS"}, + {17, "AUTH PLAIN AHVzZXJuYW1lAHBhc3N3b3Jk"}, + {21, "MAIL FROM: BODY=8BITMIME SMTPUTF8"}, + {23, "RCPT TO:"}, + {25, "RCPT TO:"}, + {27, "RCPT TO:"}, + {34, "Subject: " + subject}, + {37, "From: "}, + {38, "To: , , "}, + {39, "Content-Type: text/plain; charset=UTF-8"}, + {40, "Content-Transfer-Encoding: quoted-printable"}, + {42, "This is a test body"}, + {43, "With multiple lines"}, + {44, ""}, + {45, "Best,"}, + {46, " The go-mail team"}, + } + for _, expect := range expects { + if !strings.EqualFold(resp[expect.line], expect.data) { + t.Errorf("expected %q at line %d, got: %q", expect.data, expect.line, resp[expect.line]) + } + } + }) + t.Run("QuickSend uses stronged authentication method", func(t *testing.T) { + ctxAuth, cancelAuth := context.WithCancel(context.Background()) + defer cancelAuth() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-AUTH PLAIN CRAM-MD5 SCRAM-SHA-256-PLUS SCRAM-SHA-256\r\n250-8BITMIME\r\n250-DSN\r\n250-STARTTLS\r\n250 SMTPUTF8" + echoBuffer := bytes.NewBuffer(nil) + props := &serverProps{ + EchoBuffer: echoBuffer, + FeatureSet: featureSet, + ListenPort: serverPort, + } + go func() { + if err := simpleSMTPServer(ctxAuth, t, props); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + addr := TestServerAddr + ":" + fmt.Sprint(serverPort) + testHookTLSConfig = func() *tls.Config { return &tls.Config{InsecureSkipVerify: true} } + + _, err := QuickSend(addr, NewAuthData("username", "password"), sender, rcpts, subject, body) + if err != nil { + t.Fatalf("failed to send email: %s", err) + } + + props.BufferMutex.RLock() + resp := strings.Split(echoBuffer.String(), "\r\n") + props.BufferMutex.RUnlock() + + expects := []struct { + line int + data string + }{ + {17, "AUTH SCRAM-SHA-256-PLUS"}, + } + for _, expect := range expects { + if !strings.EqualFold(resp[expect.line], expect.data) { + t.Errorf("expected %q at line %d, got: %q", expect.data, expect.line, resp[expect.line]) + } + } + }) + t.Run("QuickSend uses stronged authentication method without TLS", func(t *testing.T) { + ctxAuth, cancelAuth := context.WithCancel(context.Background()) + defer cancelAuth() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-AUTH PLAIN CRAM-MD5 SCRAM-SHA-256-PLUS SCRAM-SHA-256\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + echoBuffer := bytes.NewBuffer(nil) + props := &serverProps{ + EchoBuffer: echoBuffer, + FeatureSet: featureSet, + ListenPort: serverPort, + } + go func() { + if err := simpleSMTPServer(ctxAuth, t, props); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + addr := TestServerAddr + ":" + fmt.Sprint(serverPort) + testHookTLSConfig = func() *tls.Config { return &tls.Config{InsecureSkipVerify: true} } + + _, err := QuickSend(addr, NewAuthData("username", "password"), sender, rcpts, subject, body) + if err != nil { + t.Fatalf("failed to send email: %s", err) + } + + props.BufferMutex.RLock() + resp := strings.Split(echoBuffer.String(), "\r\n") + props.BufferMutex.RUnlock() + + expects := []struct { + line int + data string + }{ + {7, "AUTH SCRAM-SHA-256"}, + } + for _, expect := range expects { + if !strings.EqualFold(resp[expect.line], expect.data) { + t.Errorf("expected %q at line %d, got: %q", expect.data, expect.line, resp[expect.line]) + } + } + }) + t.Run("QuickSend fails during DialAndSned", func(t *testing.T) { + ctxAuth, cancelAuth := context.WithCancel(context.Background()) + defer cancelAuth() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-AUTH PLAIN CRAM-MD5 SCRAM-SHA-256-PLUS SCRAM-SHA-256\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + props := &serverProps{ + FailOnMailFrom: true, + FeatureSet: featureSet, + ListenPort: serverPort, + } + go func() { + if err := simpleSMTPServer(ctxAuth, t, props); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + addr := TestServerAddr + ":" + fmt.Sprint(serverPort) + testHookTLSConfig = func() *tls.Config { return &tls.Config{InsecureSkipVerify: true} } + + _, err := QuickSend(addr, NewAuthData("username", "password"), sender, rcpts, subject, body) + if err == nil { + t.Error("expected QuickSend to fail during DialAndSend") + } + expect := `failed to dial and send message: send failed: sending SMTP MAIL FROM command: 500 ` + + `5.5.2 Error: fail on MAIL FROM` + if !strings.EqualFold(err.Error(), expect) { + t.Errorf("expected error to contain %s, got %s", expect, err) + } + }) + t.Run("QuickSend fails on server address without port", func(t *testing.T) { + addr := TestServerAddr + _, err := QuickSend(addr, NewAuthData("username", "password"), sender, rcpts, subject, body) + if err == nil { + t.Error("expected QuickSend to fail with invalid server address") + } + expect := "failed to split host and port from address: address 127.0.0.1: missing port in address" + if !strings.Contains(err.Error(), expect) { + t.Errorf("expected error to contain %s, got %s", expect, err) + } + }) + t.Run("QuickSend fails on server address with invalid port", func(t *testing.T) { + addr := TestServerAddr + ":invalid" + _, err := QuickSend(addr, NewAuthData("username", "password"), sender, rcpts, subject, body) + if err == nil { + t.Error("expected QuickSend to fail with invalid server port") + } + expect := `failed to convert port to int: strconv.Atoi: parsing "invalid": invalid syntax` + if !strings.Contains(err.Error(), expect) { + t.Errorf("expected error to contain %s, got %s", expect, err) + } + }) + t.Run("QuickSend fails on nil TLS config (test hook only)", func(t *testing.T) { + addr := TestServerAddr + ":587" + testHookTLSConfig = func() *tls.Config { return nil } + defer func() { + testHookTLSConfig = nil + }() + _, err := QuickSend(addr, NewAuthData("username", "password"), sender, rcpts, subject, body) + if err == nil { + t.Error("expected QuickSend to fail with nil-tlsConfig") + } + expect := `failed to set TLS config: invalid TLS config` + if !strings.Contains(err.Error(), expect) { + t.Errorf("expected error to contain %s, got %s", expect, err) + } + }) + t.Run("QuickSend fails with invalid from address", func(t *testing.T) { + addr := TestServerAddr + ":587" + invalid := "invalid-fromdomain.tld" + _, err := QuickSend(addr, NewAuthData("username", "password"), invalid, rcpts, subject, body) + if err == nil { + t.Error("expected QuickSend to fail with invalid from address") + } + expect := `failed to set MAIL FROM address: failed to parse mail address "invalid-fromdomain.tld": ` + + `mail: missing '@' or angle-addr` + if !strings.Contains(err.Error(), expect) { + t.Errorf("expected error to contain %s, got %s", expect, err) + } + }) + t.Run("QuickSend fails with invalid from address", func(t *testing.T) { + addr := TestServerAddr + ":587" + invalid := []string{"invalid-todomain.tld"} + _, err := QuickSend(addr, NewAuthData("username", "password"), sender, invalid, subject, body) + if err == nil { + t.Error("expected QuickSend to fail with invalid to address") + } + expect := `failed to set RCPT TO address: failed to parse mail address "invalid-todomain.tld": ` + + `mail: missing '@' or angle-add` + if !strings.Contains(err.Error(), expect) { + t.Errorf("expected error to contain %s, got %s", expect, err) + } + }) +} diff --git a/random_test.go b/random_test.go index a608c2a..b59eba6 100644 --- a/random_test.go +++ b/random_test.go @@ -5,45 +5,81 @@ package mail import ( + "crypto/rand" + "errors" "strings" "testing" ) // TestRandomStringSecure tests the randomStringSecure method func TestRandomStringSecure(t *testing.T) { - tt := []struct { - testName string - length int - mustNotMatch string - }{ - {"20 chars", 20, "'"}, - {"100 chars", 100, "'"}, - {"1000 chars", 1000, "'"}, - } + t.Run("randomStringSecure with varying length", func(t *testing.T) { + tt := []struct { + testName string + length int + mustNotMatch string + }{ + {"20 chars", 20, "'"}, + {"100 chars", 100, "'"}, + {"1000 chars", 1000, "'"}, + } - for _, tc := range tt { - t.Run(tc.testName, func(t *testing.T) { - rs, err := randomStringSecure(tc.length) - if err != nil { - t.Errorf("random string generation failed: %s", err) - } - if strings.Contains(rs, tc.mustNotMatch) { - t.Errorf("random string contains unexpected character. got: %s, not-expected: %s", - rs, tc.mustNotMatch) - } - if len(rs) != tc.length { - t.Errorf("random string length does not match. expected: %d, got: %d", tc.length, len(rs)) - } - }) - } + for _, tc := range tt { + t.Run(tc.testName, func(t *testing.T) { + rs, err := randomStringSecure(tc.length) + if err != nil { + t.Errorf("random string generation failed: %s", err) + } + if strings.Contains(rs, tc.mustNotMatch) { + t.Errorf("random string contains unexpected character. got: %s, not-expected: %s", + rs, tc.mustNotMatch) + } + if len(rs) != tc.length { + t.Errorf("random string length does not match. expected: %d, got: %d", tc.length, len(rs)) + } + }) + } + }) + t.Run("randomStringSecure fails on broken rand Reader (first read)", func(t *testing.T) { + defaultRandReader := rand.Reader + t.Cleanup(func() { rand.Reader = defaultRandReader }) + rand.Reader = &randReader{failon: 1} + if _, err := randomStringSecure(22); err == nil { + t.Fatalf("expected failure on broken rand Reader") + } + }) + t.Run("randomStringSecure fails on broken rand Reader (second read)", func(t *testing.T) { + defaultRandReader := rand.Reader + t.Cleanup(func() { rand.Reader = defaultRandReader }) + rand.Reader = &randReader{failon: 0} + if _, err := randomStringSecure(22); err == nil { + t.Fatalf("expected failure on broken rand Reader") + } + }) } func BenchmarkGenerator_RandomStringSecure(b *testing.B) { b.ReportAllocs() for i := 0; i < b.N; i++ { - _, err := randomStringSecure(22) + _, err := randomStringSecure(10) if err != nil { b.Errorf("RandomStringFromCharRange() failed: %s", err) } } } + +// randReader is type that satisfies the io.Reader interface. It can fail on a specific read +// operations and is therefore useful to test consecutive reads with errors +type randReader struct { + failon uint8 + call uint8 +} + +// Read implements the io.Reader interface for the randReader type +func (r *randReader) Read(p []byte) (int, error) { + if r.call == r.failon { + r.call++ + return len(p), nil + } + return 0, errors.New("broken reader") +} diff --git a/senderror.go b/senderror.go index 1943e28..e32d74b 100644 --- a/senderror.go +++ b/senderror.go @@ -6,6 +6,8 @@ package mail import ( "errors" + "regexp" + "strconv" "strings" ) @@ -60,11 +62,13 @@ const ( // details about the affected message, a list of errors, the recipient list, and whether // the error is temporary or permanent. It also includes a reason code for the error. type SendError struct { - affectedMsg *Msg - errlist []error - isTemp bool - rcpt []string - Reason SendErrReason + affectedMsg *Msg + errcode int + enhancedStatusCode string + errlist []error + isTemp bool + rcpt []string + Reason SendErrReason } // SendErrReason represents a comparable reason on why the delivery failed @@ -81,7 +85,7 @@ type SendErrReason int // Returns: // - A string representing the error message. func (e *SendError) Error() string { - if e.Reason > 10 { + if e.Reason > ErrAmbiguous { return "unknown reason" } @@ -93,7 +97,7 @@ func (e *SendError) Error() string { errMessage.WriteRune(' ') errMessage.WriteString(e.errlist[i].Error()) if i != len(e.errlist)-1 { - errMessage.WriteString(", ") + errMessage.WriteString(",") } } } @@ -175,6 +179,42 @@ func (e *SendError) Msg() *Msg { return e.affectedMsg } +// EnhancedStatusCode returns the enhanced status code of the server response if the +// server supports it, as described in RFC 2034. +// +// This function retrieves the enhanced status code of an error returned by the server. This +// requires that the receiving server supports this SMTP extension as described in RFC 2034. +// Since this is the SendError interface, we only collect status codes for error responses, +// meaning 4xx or 5xx. If the server does not support the ENHANCEDSTATUSCODES extension or +// the error did not include an enhanced status code, it will return an empty string. +// +// Returns: +// - The enhanced status code as returned by the server, or an empty string is not supported. +// +// References: +// - https://datatracker.ietf.org/doc/html/rfc2034 +func (e *SendError) EnhancedStatusCode() string { + if e == nil { + return "" + } + return e.enhancedStatusCode +} + +// ErrorCode returns the error code of the server response. +// +// This function retrieves the error code the error returned by the server. The error code will +// start with 5 on permanent errors and with 4 on a temporary error. If the error is not returned +// by the server, but is generated by go-mail, the code will be 0. +// +// Returns: +// - The error code as returned by the server, or 0 if not a server error. +func (e *SendError) ErrorCode() int { + if e == nil { + return 0 + } + return e.errcode +} + // String satisfies the fmt.Stringer interface for the SendErrReason type. // // This function converts the SendErrReason into a human-readable string representation based @@ -224,3 +264,39 @@ func (r SendErrReason) String() string { func isTempError(err error) bool { return err.Error()[0] == '4' } + +func errorCode(err error) int { + rootErr := errors.Unwrap(err) + if rootErr != nil { + err = rootErr + } + firstrune := err.Error()[0] + if firstrune < 52 || firstrune > 53 { + return 0 + } + code := err.Error()[0:3] + errcode, cerr := strconv.Atoi(code) + if cerr != nil { + return 0 + } + return errcode +} + +func enhancedStatusCode(err error, supported bool) string { + if err == nil || !supported { + return "" + } + rootErr := errors.Unwrap(err) + if rootErr != nil { + err = rootErr + } + firstrune := err.Error()[0] + if firstrune != 50 && firstrune != 52 && firstrune != 53 { + return "" + } + re, rerr := regexp.Compile(`\b([245])\.\d{1,3}\.\d{1,3}\b`) + if rerr != nil { + return "" + } + return re.FindString(err.Error()) +} diff --git a/senderror_test.go b/senderror_test.go index e04b7ee..63d4c73 100644 --- a/senderror_test.go +++ b/senderror_test.go @@ -13,156 +13,354 @@ import ( // TestSendError_Error tests the SendError and SendErrReason error handling methods func TestSendError_Error(t *testing.T) { - tl := []struct { - n string - r SendErrReason - te bool - }{ - {"ErrGetSender/temp", ErrGetSender, true}, - {"ErrGetSender/perm", ErrGetSender, false}, - {"ErrGetRcpts/temp", ErrGetRcpts, true}, - {"ErrGetRcpts/perm", ErrGetRcpts, false}, - {"ErrSMTPMailFrom/temp", ErrSMTPMailFrom, true}, - {"ErrSMTPMailFrom/perm", ErrSMTPMailFrom, false}, - {"ErrSMTPRcptTo/temp", ErrSMTPRcptTo, true}, - {"ErrSMTPRcptTo/perm", ErrSMTPRcptTo, false}, - {"ErrSMTPData/temp", ErrSMTPData, true}, - {"ErrSMTPData/perm", ErrSMTPData, false}, - {"ErrSMTPDataClose/temp", ErrSMTPDataClose, true}, - {"ErrSMTPDataClose/perm", ErrSMTPDataClose, false}, - {"ErrSMTPReset/temp", ErrSMTPReset, true}, - {"ErrSMTPReset/perm", ErrSMTPReset, false}, - {"ErrWriteContent/temp", ErrWriteContent, true}, - {"ErrWriteContent/perm", ErrWriteContent, false}, - {"ErrConnCheck/temp", ErrConnCheck, true}, - {"ErrConnCheck/perm", ErrConnCheck, false}, - {"ErrNoUnencoded/temp", ErrNoUnencoded, true}, - {"ErrNoUnencoded/perm", ErrNoUnencoded, false}, - {"ErrAmbiguous/temp", ErrAmbiguous, true}, - {"ErrAmbiguous/perm", ErrAmbiguous, false}, - {"Unknown/temp", 9999, true}, - {"Unknown/perm", 9999, false}, - } - - for _, tt := range tl { - t.Run(tt.n, func(t *testing.T) { - if err := returnSendError(tt.r, tt.te); err != nil { - exp := &SendError{Reason: tt.r, isTemp: tt.te} - if !errors.Is(err, exp) { - t.Errorf("error mismatch, expected: %s (temp: %t), got: %s (temp: %t)", tt.r, tt.te, - exp.Error(), exp.isTemp) + t.Run("TestSendError_Error with various reasons", func(t *testing.T) { + tests := []struct { + name string + reason SendErrReason + isTemp bool + }{ + {"ErrGetSender/temp", ErrGetSender, true}, + {"ErrGetSender/perm", ErrGetSender, false}, + {"ErrGetRcpts/temp", ErrGetRcpts, true}, + {"ErrGetRcpts/perm", ErrGetRcpts, false}, + {"ErrSMTPMailFrom/temp", ErrSMTPMailFrom, true}, + {"ErrSMTPMailFrom/perm", ErrSMTPMailFrom, false}, + {"ErrSMTPRcptTo/temp", ErrSMTPRcptTo, true}, + {"ErrSMTPRcptTo/perm", ErrSMTPRcptTo, false}, + {"ErrSMTPData/temp", ErrSMTPData, true}, + {"ErrSMTPData/perm", ErrSMTPData, false}, + {"ErrSMTPDataClose/temp", ErrSMTPDataClose, true}, + {"ErrSMTPDataClose/perm", ErrSMTPDataClose, false}, + {"ErrSMTPReset/temp", ErrSMTPReset, true}, + {"ErrSMTPReset/perm", ErrSMTPReset, false}, + {"ErrWriteContent/temp", ErrWriteContent, true}, + {"ErrWriteContent/perm", ErrWriteContent, false}, + {"ErrConnCheck/temp", ErrConnCheck, true}, + {"ErrConnCheck/perm", ErrConnCheck, false}, + {"ErrNoUnencoded/temp", ErrNoUnencoded, true}, + {"ErrNoUnencoded/perm", ErrNoUnencoded, false}, + {"ErrAmbiguous/temp", ErrAmbiguous, true}, + {"ErrAmbiguous/perm", ErrAmbiguous, false}, + {"Unknown/temp", 9999, true}, + {"Unknown/perm", 9999, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := returnSendError(tt.reason, tt.isTemp) + if err == nil { + t.Fatalf("error expected, got nil") } - if !strings.Contains(fmt.Sprintf("%s", err), tt.r.String()) { + want := &SendError{Reason: tt.reason, isTemp: tt.isTemp} + if !errors.Is(err, want) { + t.Errorf("error mismatch, expected: %s (temp: %t), got: %s (temp: %t)", + tt.reason, tt.isTemp, want.Error(), want.isTemp) + } + if !strings.Contains(err.Error(), tt.reason.String()) { t.Errorf("error string mismatch, expected: %s, got: %s", - tt.r.String(), fmt.Sprintf("%s", err)) + tt.reason.String(), err.Error()) } - } - }) - } + }) + } + }) + t.Run("TestSendError_Error with multiple errors", func(t *testing.T) { + message := testMessage(t) + err := &SendError{ + affectedMsg: message, + errlist: []error{ErrNoRcptAddresses, ErrNoFromAddress}, + rcpt: []string{"", ""}, + Reason: ErrAmbiguous, + } + if !strings.Contains(err.Error(), "ambiguous reason, check Msg.SendError for message specific reasons") { + t.Errorf("error string mismatch, expected: ambiguous reason, check Msg.SendError for message "+ + "specific reasons, got: %s", err.Error()) + } + if !strings.Contains(err.Error(), "no recipient addresses set, no FROM address set") { + t.Errorf("error string mismatch, expected: no recipient addresses set, no FROM address set, got: %s", + err.Error()) + } + if !strings.Contains(err.Error(), "affected recipient(s): , "+ + "") { + t.Errorf("error string mismatch, expected: affected recipient(s): , "+ + ", got: %s", err.Error()) + } + }) +} + +func TestSendError_Is(t *testing.T) { + t.Run("TestSendError_Is errors match", func(t *testing.T) { + err1 := returnSendError(ErrAmbiguous, false) + err2 := returnSendError(ErrAmbiguous, false) + if !errors.Is(err1, err2) { + t.Error("error mismatch, expected ErrAmbiguous to be equal to ErrAmbiguous") + } + }) + t.Run("TestSendError_Is errors mismatch", func(t *testing.T) { + err1 := returnSendError(ErrAmbiguous, false) + err2 := returnSendError(ErrSMTPMailFrom, false) + if errors.Is(err1, err2) { + t.Error("error mismatch, ErrAmbiguous should not be equal to ErrSMTPMailFrom") + } + }) + t.Run("TestSendError_Is on nil", func(t *testing.T) { + var err *SendError + if err.Is(ErrNoFromAddress) { + t.Error("expected false on nil-senderror") + } + }) } func TestSendError_IsTemp(t *testing.T) { - var se *SendError - err1 := returnSendError(ErrAmbiguous, true) - if !errors.As(err1, &se) { - t.Errorf("error mismatch, expected error to be of type *SendError") - return - } - if errors.As(err1, &se) && !se.IsTemp() { - t.Errorf("error mismatch, expected temporary error") - return - } - err2 := returnSendError(ErrAmbiguous, false) - if !errors.As(err2, &se) { - t.Errorf("error mismatch, expected error to be of type *SendError") - return - } - if errors.As(err2, &se) && se.IsTemp() { - t.Errorf("error mismatch, expected non-temporary error") - return - } -} - -func TestSendError_IsTempNil(t *testing.T) { - var se *SendError - if se.IsTemp() { - t.Error("expected false on nil-senderror") - } + t.Run("TestSendError_IsTemp is true", func(t *testing.T) { + err := returnSendError(ErrAmbiguous, true) + if err == nil { + t.Fatalf("error expected, got nil") + } + var sendErr *SendError + if !errors.As(err, &sendErr) { + t.Fatal("error expected to be of type *SendError") + } + if !sendErr.IsTemp() { + t.Errorf("expected temporary error, got: temperr: %t", sendErr.IsTemp()) + } + }) + t.Run("TestSendError_IsTemp is false", func(t *testing.T) { + err := returnSendError(ErrAmbiguous, false) + if err == nil { + t.Fatalf("error expected, got nil") + } + var sendErr *SendError + if !errors.As(err, &sendErr) { + t.Fatal("error expected to be of type *SendError") + } + if sendErr.IsTemp() { + t.Errorf("expected permanent error, got: temperr: %t", sendErr.IsTemp()) + } + }) + t.Run("TestSendError_IsTemp is nil", func(t *testing.T) { + var se *SendError + if se.IsTemp() { + t.Error("expected false on nil-senderror") + } + }) } func TestSendError_MessageID(t *testing.T) { - var se *SendError - err := returnSendError(ErrAmbiguous, false) - if !errors.As(err, &se) { - t.Errorf("error mismatch, expected error to be of type *SendError") - return - } - if errors.As(err, &se) { - if se.MessageID() == "" { - t.Errorf("sendError expected message-id, but got empty string") + t.Run("TestSendError_MessageID message ID is set", func(t *testing.T) { + var sendErr *SendError + err := returnSendError(ErrAmbiguous, false) + if !errors.As(err, &sendErr) { + t.Fatal("error mismatch, expected error to be of type *SendError") } - if !strings.EqualFold(se.MessageID(), "") { + if sendErr.MessageID() == "" { + t.Error("sendError expected message-id, but got empty string") + } + if !strings.EqualFold(sendErr.MessageID(), "") { t.Errorf("sendError message-id expected: %s, but got: %s", "", - se.MessageID()) + sendErr.MessageID()) } - } -} - -func TestSendError_MessageIDNil(t *testing.T) { - var se *SendError - if se.MessageID() != "" { - t.Error("expected empty string on nil-senderror") - } + }) + t.Run("TestSendError_MessageID message ID is not set", func(t *testing.T) { + var sendErr *SendError + message := testMessage(t) + err := &SendError{ + affectedMsg: message, + errlist: []error{ErrNoRcptAddresses}, + rcpt: []string{"", ""}, + Reason: ErrAmbiguous, + } + if !errors.As(err, &sendErr) { + t.Fatal("error mismatch, expected error to be of type *SendError") + } + if sendErr.MessageID() != "" { + t.Errorf("sendError expected empty message-id, got: %s", sendErr.MessageID()) + } + }) + t.Run("TestSendError_MessageID on nil error should return empty", func(t *testing.T) { + var sendErr *SendError + if sendErr.MessageID() != "" { + t.Error("expected empty message-id on nil-senderror") + } + }) } func TestSendError_Msg(t *testing.T) { - var se *SendError - err := returnSendError(ErrAmbiguous, false) - if !errors.As(err, &se) { - t.Errorf("error mismatch, expected error to be of type *SendError") - return - } - if errors.As(err, &se) { - if se.Msg() == nil { - t.Errorf("sendError expected msg pointer, but got nil") + t.Run("TestSendError_Msg message is set", func(t *testing.T) { + var sendErr *SendError + err := returnSendError(ErrAmbiguous, false) + if !errors.As(err, &sendErr) { + t.Fatal("error mismatch, expected error to be of type *SendError") } - from := se.Msg().GetFromString() + msg := sendErr.Msg() + if msg == nil { + t.Fatalf("sendError expected msg pointer, but got nil") + } + from := msg.GetFromString() if len(from) == 0 { - t.Errorf("sendError expected msg from, but got empty string") - return + t.Fatal("sendError expected msg from, but got empty string") } if !strings.EqualFold(from[0], "") { t.Errorf("sendError message from expected: %s, but got: %s", "", from[0]) } - } + }) + t.Run("TestSendError_Msg message is not set", func(t *testing.T) { + var sendErr *SendError + err := &SendError{ + errlist: []error{ErrNoRcptAddresses}, + rcpt: []string{"", ""}, + Reason: ErrAmbiguous, + } + if !errors.As(err, &sendErr) { + t.Fatal("error mismatch, expected error to be of type *SendError") + } + if sendErr.Msg() != nil { + t.Errorf("sendError expected nil msg pointer, got: %v", sendErr.Msg()) + } + }) } -func TestSendError_MsgNil(t *testing.T) { - var se *SendError - if se.Msg() != nil { - t.Error("expected nil on nil-senderror") - } +func TestSendError_EnhancedStatusCode(t *testing.T) { + t.Run("SendError with no enhanced status code", func(t *testing.T) { + err := &SendError{ + errlist: []error{ErrNoRcptAddresses}, + rcpt: []string{"", ""}, + Reason: ErrAmbiguous, + } + if err.EnhancedStatusCode() != "" { + t.Errorf("expected empty enhanced status code, got: %s", err.EnhancedStatusCode()) + } + }) + t.Run("SendError with enhanced status code", func(t *testing.T) { + err := &SendError{ + errlist: []error{ErrNoRcptAddresses}, + rcpt: []string{"", ""}, + Reason: ErrAmbiguous, + enhancedStatusCode: "5.7.1", + } + if err.EnhancedStatusCode() != "5.7.1" { + t.Errorf("expected enhanced status code: %s, got: %s", "5.7.1", err.EnhancedStatusCode()) + } + }) + t.Run("enhanced status code on nil error should return empty string", func(t *testing.T) { + var err *SendError + if err.EnhancedStatusCode() != "" { + t.Error("expected empty enhanced status code on nil-senderror") + } + }) } -func TestSendError_IsFail(t *testing.T) { - err1 := returnSendError(ErrAmbiguous, false) - err2 := returnSendError(ErrSMTPMailFrom, false) - if errors.Is(err1, err2) { - t.Errorf("error mismatch, ErrAmbiguous should not be equal to ErrSMTPMailFrom") - } +func TestSendError_ErrorCode(t *testing.T) { + t.Run("ErrorCode with a go-mail error should return 0", func(t *testing.T) { + err := &SendError{ + errlist: []error{ErrNoRcptAddresses}, + rcpt: []string{"", ""}, + Reason: ErrAmbiguous, + errcode: errorCode(ErrNoRcptAddresses), + } + if err.ErrorCode() != 0 { + t.Errorf("expected error code: %d, got: %d", 0, err.ErrorCode()) + } + }) + t.Run("SendError with permanent error", func(t *testing.T) { + err := &SendError{ + errlist: []error{ErrNoRcptAddresses}, + rcpt: []string{"", ""}, + Reason: ErrAmbiguous, + errcode: errorCode(errors.New("535 5.7.8 Error: authentication failed")), + } + if err.ErrorCode() != 535 { + t.Errorf("expected error code: %d, got: %d", 535, err.ErrorCode()) + } + }) + t.Run("SendError with temporary error", func(t *testing.T) { + err := &SendError{ + errlist: []error{ErrNoRcptAddresses}, + rcpt: []string{"", ""}, + Reason: ErrAmbiguous, + errcode: errorCode(errors.New("441 4.1.0 Server currently unavailable")), + } + if err.ErrorCode() != 441 { + t.Errorf("expected error code: %d, got: %d", 441, err.ErrorCode()) + } + }) + t.Run("error code on nil error should return 0", func(t *testing.T) { + var err *SendError + if err.ErrorCode() != 0 { + t.Error("expected 0 error code on nil-senderror") + } + }) } -func TestSendError_ErrorMulti(t *testing.T) { - expected := `ambiguous reason, check Msg.SendError for message specific reasons, ` + - `affected recipient(s): , ` - err := &SendError{ - Reason: ErrAmbiguous, isTemp: false, affectedMsg: nil, - rcpt: []string{"", ""}, - } - if err.Error() != expected { - t.Errorf("error mismatch, expected: %s, got: %s", expected, err.Error()) - } +func TestSendError_errorCode(t *testing.T) { + t.Run("errorCode with a go-mail error should return 0", func(t *testing.T) { + code := errorCode(ErrNoRcptAddresses) + if code != 0 { + t.Errorf("expected error code: %d, got: %d", 0, code) + } + }) + t.Run("errorCode with permanent error", func(t *testing.T) { + code := errorCode(errors.New("535 5.7.8 Error: authentication failed")) + if code != 535 { + t.Errorf("expected error code: %d, got: %d", 535, code) + } + }) + t.Run("errorCode with temporary error", func(t *testing.T) { + code := errorCode(errors.New("443 4.1.0 Server currently unavailable")) + if code != 443 { + t.Errorf("expected error code: %d, got: %d", 443, code) + } + }) + t.Run("errorCode with wrapper error", func(t *testing.T) { + code := errorCode(fmt.Errorf("an error occured: %w", errors.New("443 4.1.0 Server currently unavailable"))) + if code != 443 { + t.Errorf("expected error code: %d, got: %d", 443, code) + } + }) + t.Run("errorCode with non-4xx and non-5xx error", func(t *testing.T) { + code := errorCode(errors.New("220 2.1.0 This is not an error")) + if code != 0 { + t.Errorf("expected error code: %d, got: %d", 0, code) + } + }) + t.Run("errorCode with non 3-digit code", func(t *testing.T) { + code := errorCode(errors.New("4xx 4.1.0 The status code is invalid")) + if code != 0 { + t.Errorf("expected error code: %d, got: %d", 0, code) + } + }) +} + +func TestSendError_enhancedStatusCode(t *testing.T) { + t.Run("enhancedStatusCode with nil error should return empty string", func(t *testing.T) { + code := enhancedStatusCode(nil, true) + if code != "" { + t.Errorf("expected empty enhanced status code, got: %s", code) + } + }) + t.Run("enhancedStatusCode with error but no support should return empty string", func(t *testing.T) { + code := enhancedStatusCode(errors.New("553 5.5.3 something went wrong"), false) + if code != "" { + t.Errorf("expected empty enhanced status code, got: %s", code) + } + }) + t.Run("enhancedStatusCode with error and support", func(t *testing.T) { + code := enhancedStatusCode(errors.New("553 5.5.3 something went wrong"), true) + if code != "5.5.3" { + t.Errorf("expected enhanced status code: %s, got: %s", "5.5.3", code) + } + }) + t.Run("enhancedStatusCode with wrapped error and support", func(t *testing.T) { + code := enhancedStatusCode(fmt.Errorf("this error is wrapped: %w", errors.New("553 5.5.3 something went wrong")), true) + if code != "5.5.3" { + t.Errorf("expected enhanced status code: %s, got: %s", "5.5.3", code) + } + }) + t.Run("enhancedStatusCode with 3xx error", func(t *testing.T) { + code := enhancedStatusCode(errors.New("300 3.0.0 i don't know what i'm doing"), true) + if code != "" { + t.Errorf("expected enhanced status code to be empty, got: %s", code) + } + }) } // returnSendError is a helper method to retunr a SendError with a specific reason @@ -173,6 +371,5 @@ func returnSendError(r SendErrReason, t bool) error { message.Subject("This is the subject") message.SetBodyString(TypeTextPlain, "This is the message body") message.SetMessageIDWithValue("this.is.a.message.id") - return &SendError{Reason: r, isTemp: t, affectedMsg: message} } diff --git a/smtp/auth_login.go b/smtp/auth_login.go index b5f1065..c40e48c 100644 --- a/smtp/auth_login.go +++ b/smtp/auth_login.go @@ -36,8 +36,8 @@ type loginAuth struct { // LoginAuth will only send the credentials if the connection is using TLS // or is connected to localhost. Otherwise authentication will fail with an // error, without sending the credentials. -func LoginAuth(username, password, host string, allowUnEnc bool) Auth { - return &loginAuth{username, password, host, 0, allowUnEnc} +func LoginAuth(username, password, host string, allowUnenc bool) Auth { + return &loginAuth{username, password, host, 0, allowUnenc} } // Start begins the SMTP authentication process by validating server's TLS status and hostname. diff --git a/smtp/auth_plain.go b/smtp/auth_plain.go index f2ea8ac..39e2a2f 100644 --- a/smtp/auth_plain.go +++ b/smtp/auth_plain.go @@ -28,8 +28,8 @@ type plainAuth struct { // PlainAuth will only send the credentials if the connection is using TLS // or is connected to localhost. Otherwise authentication will fail with an // error, without sending the credentials. -func PlainAuth(identity, username, password, host string, allowUnEnc bool) Auth { - return &plainAuth{identity, username, password, host, allowUnEnc} +func PlainAuth(identity, username, password, host string, allowUnenc bool) Auth { + return &plainAuth{identity, username, password, host, allowUnenc} } func (a *plainAuth) Start(server *ServerInfo) (string, []byte, error) { diff --git a/smtp/auth_scram.go b/smtp/auth_scram.go index a21aef5..03de54c 100644 --- a/smtp/auth_scram.go +++ b/smtp/auth_scram.go @@ -154,7 +154,7 @@ func (a *scramAuth) initialClientMessage() ([]byte, error) { connState := a.tlsConnState bindData := connState.TLSUnique - // crypto/tl: no tls-unique channel binding value for this tls connection, possibly due to missing + // crypto/tls: no tls-unique channel binding value for this tls connection, possibly due to missing // extended master key support and/or resumed connection // RFC9266:122 tls-unique not defined for tls 1.3 and later if bindData == nil || connState.Version >= tls.VersionTLS13 { @@ -308,10 +308,7 @@ func (a *scramAuth) normalizeUsername() (string, error) { func (a *scramAuth) normalizeString(s string) (string, error) { s, err := precis.OpaqueString.String(s) if err != nil { - return "", fmt.Errorf("failled to normalize string: %w", err) - } - if s == "" { - return "", errors.New("normalized string is empty") + return "", fmt.Errorf("failed to normalize string: %w", err) } return s, nil } diff --git a/smtp/smtp.go b/smtp/smtp.go index 4841ec8..24a079f 100644 --- a/smtp/smtp.go +++ b/smtp/smtp.go @@ -587,7 +587,9 @@ func (c *Client) SetLogger(l log.Logger) { if l == nil { return } + c.mutex.Lock() c.logger = l + c.mutex.Unlock() } // SetLogAuthData enables logging of authentication data in the Client. @@ -599,12 +601,16 @@ func (c *Client) SetLogAuthData() { // SetDSNMailReturnOption sets the DSN mail return option for the Mail method func (c *Client) SetDSNMailReturnOption(d string) { + c.mutex.Lock() c.dsnmrtype = d + c.mutex.Unlock() } // SetDSNRcptNotifyOption sets the DSN recipient notify option for the Mail method func (c *Client) SetDSNRcptNotifyOption(d string) { + c.mutex.Lock() c.dsnrntype = d + c.mutex.Unlock() } // HasConnection checks if the client has an active connection. @@ -620,6 +626,9 @@ func (c *Client) HasConnection() bool { func (c *Client) UpdateDeadline(timeout time.Duration) error { c.mutex.Lock() defer c.mutex.Unlock() + if c.conn == nil { + return errors.New("smtp: client has no connection") + } if err := c.conn.SetDeadline(time.Now().Add(timeout)); err != nil { return fmt.Errorf("smtp: failed to update deadline: %w", err) } diff --git a/smtp/smtp_121_test.go b/smtp/smtp_121_test.go new file mode 100644 index 0000000..ed722e0 --- /dev/null +++ b/smtp/smtp_121_test.go @@ -0,0 +1,59 @@ +// SPDX-FileCopyrightText: Copyright 2010 The Go Authors. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2022-2023 The go-mail Authors +// +// Original net/smtp code from the Go stdlib by the Go Authors. +// Use of this source code is governed by a BSD-style +// LICENSE file that can be found in this directory. +// +// go-mail specific modifications by the go-mail Authors. +// Licensed under the MIT License. +// See [PROJECT ROOT]/LICENSES directory for more information. +// +// SPDX-License-Identifier: BSD-3-Clause AND MIT + +//go:build go1.21 +// +build go1.21 + +package smtp + +import ( + "fmt" + "os" + "strings" + "testing" + + "github.com/wneessen/go-mail/log" +) + +func TestClient_SetDebugLog_JSON(t *testing.T) { + t.Run("set debug loggging to on should not override logger", func(t *testing.T) { + client := &Client{logger: log.NewJSON(os.Stderr, log.LevelDebug)} + client.SetDebugLog(true) + if !client.debug { + t.Fatalf("expected debug log to be true") + } + if client.logger == nil { + t.Fatalf("expected logger to be defined") + } + if !strings.EqualFold(fmt.Sprintf("%T", client.logger), "*log.JSONlog") { + t.Errorf("expected logger to be of type *log.JSONlog, got: %T", client.logger) + } + }) +} + +func TestClient_SetLogger_JSON(t *testing.T) { + t.Run("set logger to JSONlog logger", func(t *testing.T) { + client := &Client{} + client.SetLogger(log.NewJSON(os.Stderr, log.LevelDebug)) + if !strings.EqualFold(fmt.Sprintf("%T", client.logger), "*log.JSONlog") { + t.Errorf("expected logger to be of type *log.JSONlog, got: %T", client.logger) + } + }) + t.Run("nil logger should just return and not set/override", func(t *testing.T) { + client := &Client{logger: log.NewJSON(os.Stderr, log.LevelDebug)} + client.SetLogger(nil) + if !strings.EqualFold(fmt.Sprintf("%T", client.logger), "*log.JSONlog") { + t.Errorf("expected logger to be of type *log.JSONlog, got: %T", client.logger) + } + }) +} diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index 4fe0481..471b1e1 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -16,21 +16,24 @@ package smtp import ( "bufio" "bytes" + "context" "crypto/hmac" + "crypto/rand" "crypto/sha1" "crypto/sha256" "crypto/tls" "crypto/x509" "encoding/base64" - "flag" + "errors" "fmt" "hash" "io" "net" - "net/textproto" "os" - "runtime" + "strconv" "strings" + "sync" + "sync/atomic" "testing" "time" @@ -39,6 +42,57 @@ import ( "github.com/wneessen/go-mail/log" ) +const ( + // TestServerProto is the protocol used for the simple SMTP test server + TestServerProto = "tcp" + // TestServerAddr is the address the simple SMTP test server listens on + TestServerAddr = "127.0.0.1" +) + +// PortAdder is an atomic counter used to increment port numbers for the test SMTP server instances. +var PortAdder atomic.Int32 + +// TestServerPortBase is the base port for the simple SMTP test server +var TestServerPortBase int32 = 20025 + +// localhostCert is a PEM-encoded TLS cert generated from src/crypto/tls: +// +// go run generate_cert.go --rsa-bits 1024 --host 127.0.0.1,::1,example.com \ +// --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h +var localhostCert = []byte(` +-----BEGIN CERTIFICATE----- +MIICFDCCAX2gAwIBAgIRAK0xjnaPuNDSreeXb+z+0u4wDQYJKoZIhvcNAQELBQAw +EjEQMA4GA1UEChMHQWNtZSBDbzAgFw03MDAxMDEwMDAwMDBaGA8yMDg0MDEyOTE2 +MDAwMFowEjEQMA4GA1UEChMHQWNtZSBDbzCBnzANBgkqhkiG9w0BAQEFAAOBjQAw +gYkCgYEA0nFbQQuOWsjbGtejcpWz153OlziZM4bVjJ9jYruNw5n2Ry6uYQAffhqa +JOInCmmcVe2siJglsyH9aRh6vKiobBbIUXXUU1ABd56ebAzlt0LobLlx7pZEMy30 +LqIi9E6zmL3YvdGzpYlkFRnRrqwEtWYbGBf3znO250S56CCWH2UCAwEAAaNoMGYw +DgYDVR0PAQH/BAQDAgKkMBMGA1UdJQQMMAoGCCsGAQUFBwMBMA8GA1UdEwEB/wQF +MAMBAf8wLgYDVR0RBCcwJYILZXhhbXBsZS5jb22HBH8AAAGHEAAAAAAAAAAAAAAA +AAAAAAEwDQYJKoZIhvcNAQELBQADgYEAbZtDS2dVuBYvb+MnolWnCNqvw1w5Gtgi +NmvQQPOMgM3m+oQSCPRTNGSg25e1Qbo7bgQDv8ZTnq8FgOJ/rbkyERw2JckkHpD4 +n4qcK27WkEDBtQFlPihIM8hLIuzWoi/9wygiElTy/tVL3y7fGCvY2/k1KBthtZGF +tN8URjVmyEo= +-----END CERTIFICATE-----`) + +// localhostKey is the private key for localhostCert. +var localhostKey = []byte(testingKey(` +-----BEGIN RSA TESTING KEY----- +MIICXgIBAAKBgQDScVtBC45ayNsa16NylbPXnc6XOJkzhtWMn2Niu43DmfZHLq5h +AB9+Gpok4icKaZxV7ayImCWzIf1pGHq8qKhsFshRddRTUAF3np5sDOW3QuhsuXHu +lkQzLfQuoiL0TrOYvdi90bOliWQVGdGurAS1ZhsYF/fOc7bnRLnoIJYfZQIDAQAB +AoGBAMst7OgpKyFV6c3JwyI/jWqxDySL3caU+RuTTBaodKAUx2ZEmNJIlx9eudLA +kucHvoxsM/eRxlxkhdFxdBcwU6J+zqooTnhu/FE3jhrT1lPrbhfGhyKnUrB0KKMM +VY3IQZyiehpxaeXAwoAou6TbWoTpl9t8ImAqAMY8hlULCUqlAkEA+9+Ry5FSYK/m +542LujIcCaIGoG1/Te6Sxr3hsPagKC2rH20rDLqXwEedSFOpSS0vpzlPAzy/6Rbb +PHTJUhNdwwJBANXkA+TkMdbJI5do9/mn//U0LfrCR9NkcoYohxfKz8JuhgRQxzF2 +6jpo3q7CdTuuRixLWVfeJzcrAyNrVcBq87cCQFkTCtOMNC7fZnCTPUv+9q1tcJyB +vNjJu3yvoEZeIeuzouX9TJE21/33FaeDdsXbRhQEj23cqR38qFHsF1qAYNMCQQDP +QXLEiJoClkR2orAmqjPLVhR3t2oB3INcnEjLNSq8LHyQEfXyaFfu4U9l5+fRPL2i +jiC0k/9L5dHUsF0XZothAkEA23ddgRs+Id/HxtojqqUT27B8MT/IGNrYsp4DvS/c +qgkeluku4GjxRlDMBuXk94xOBEinUs+p/hwP1Alll80Tpg== +-----END RSA TESTING KEY-----`)) + type authTest struct { auth Auth challenges []string @@ -179,2206 +233,3660 @@ var authTests = []authTest{ }, } +func init() { + testPort := os.Getenv("TEST_BASEPORT_SMTP") + if testPort == "" { + return + } + if port, err := strconv.Atoi(testPort); err == nil { + if port <= 65000 && port > 1023 { + TestServerPortBase = int32(port) + } + } +} + func TestAuth(t *testing.T) { -testLoop: - for i, test := range authTests { - name, resp, err := test.auth.Start(&ServerInfo{"testserver", true, nil}) - if name != test.name { - t.Errorf("#%d got name %s, expected %s", i, name, test.name) - } - if !bytes.Equal(resp, []byte(test.responses[0])) { - t.Errorf("#%d got response %s, expected %s", i, resp, test.responses[0]) - } - if err != nil { - t.Errorf("#%d error: %s", i, err) - } - for j := range test.challenges { - challenge := []byte(test.challenges[j]) - expected := []byte(test.responses[j+1]) - sf := test.sf[j] - resp, err := test.auth.Next(challenge, true) - if err != nil && !sf { - t.Errorf("#%d error: %s", i, err) - continue testLoop - } - if test.hasNonce { - if !bytes.HasPrefix(resp, expected) { - t.Errorf("#%d got response: %s, expected response to start with: %s", i, resp, expected) + t.Run("Auth for all supported auth methods", func(t *testing.T) { + for i, tt := range authTests { + t.Run(tt.name, func(t *testing.T) { + name, resp, err := tt.auth.Start(&ServerInfo{"testserver", true, nil}) + if name != tt.name { + t.Errorf("test #%d got name %s, expected %s", i, name, tt.name) } - continue testLoop + if len(tt.responses) <= 0 { + t.Fatalf("test #%d got no responses, expected at least one", i) + } + if !bytes.Equal(resp, []byte(tt.responses[0])) { + t.Errorf("#%d got response %s, expected %s", i, resp, tt.responses[0]) + } + if err != nil { + t.Errorf("#%d error: %s", i, err) + } + testLoop: + for j := range tt.challenges { + challenge := []byte(tt.challenges[j]) + expected := []byte(tt.responses[j+1]) + sf := tt.sf[j] + resp, err := tt.auth.Next(challenge, true) + if err != nil && !sf { + t.Errorf("#%d error: %s", i, err) + continue testLoop + } + if tt.hasNonce { + if !bytes.HasPrefix(resp, expected) { + t.Errorf("#%d got response: %s, expected response to start with: %s", i, resp, expected) + } + continue testLoop + } + if !bytes.Equal(resp, expected) { + t.Errorf("#%d got %s, expected %s", i, resp, expected) + continue testLoop + } + _, err = tt.auth.Next([]byte("2.7.0 Authentication successful"), false) + if err != nil { + t.Errorf("#%d success message error: %s", i, err) + } + } + }) + } + }) +} + +func TestPlainAuth(t *testing.T) { + tests := []struct { + name string + authName string + server *ServerInfo + shouldFail bool + wantErr error + }{ + { + name: "PLAIN auth succeeds", + authName: "servername", + server: &ServerInfo{Name: "servername", TLS: true}, + shouldFail: false, + }, + { + // OK to use PlainAuth on localhost without TLS + name: "PLAIN on localhost is allowed to go unencrypted", + authName: "localhost", + server: &ServerInfo{Name: "localhost", TLS: false}, + shouldFail: false, + }, + { + // NOT OK on non-localhost, even if server says PLAIN is OK. + // (We don't know that the server is the real server.) + name: "PLAIN on non-localhost is not allowed to go unencrypted", + authName: "servername", + server: &ServerInfo{Name: "servername", Auth: []string{"PLAIN"}}, + shouldFail: true, + wantErr: ErrUnencrypted, + }, + { + name: "PLAIN on non-localhost with no PLAIN announcement, is not allowed to go unencrypted", + authName: "servername", + server: &ServerInfo{Name: "servername", Auth: []string{"CRAM-MD5"}}, + shouldFail: true, + wantErr: ErrUnencrypted, + }, + { + name: "PLAIN with wrong hostname", + authName: "servername", + server: &ServerInfo{Name: "attacker", TLS: true}, + shouldFail: true, + wantErr: ErrWrongHostname, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + identity := "foo" + user := "toni.tester@example.com" + pass := "v3ryS3Cur3P4ssw0rd" + auth := PlainAuth(identity, user, pass, tt.authName, false) + method, resp, err := auth.Start(tt.server) + if err != nil && !tt.shouldFail { + t.Errorf("plain authentication failed: %s", err) } - if !bytes.Equal(resp, expected) { - t.Errorf("#%d got %s, expected %s", i, resp, expected) - continue testLoop + if err == nil && tt.shouldFail { + t.Error("plain authentication was expected to fail") } - _, err = test.auth.Next([]byte("2.7.0 Authentication successful"), false) + if tt.wantErr != nil { + if !errors.Is(err, tt.wantErr) { + t.Errorf("expected error to be: %s, got: %s", tt.wantErr, err) + } + return + } + if method != "PLAIN" { + t.Errorf("expected method return to be: %q, got: %q", "PLAIN", method) + } + if !bytes.Equal([]byte(identity+"\x00"+user+"\x00"+pass), resp) { + t.Errorf("expected response to be: %q, got: %q", identity+"\x00"+user+"\x00"+pass, resp) + } + }) + } + t.Run("PLAIN sends second server response should fail", func(t *testing.T) { + identity := "foo" + user := "toni.tester@example.com" + pass := "v3ryS3Cur3P4ssw0rd" + server := &ServerInfo{Name: "servername", TLS: true} + auth := PlainAuth(identity, user, pass, "servername", false) + method, resp, err := auth.Start(server) + if err != nil { + t.Fatalf("plain authentication failed: %s", err) + } + if method != "PLAIN" { + t.Errorf("expected method return to be: %q, got: %q", "PLAIN", method) + } + if !bytes.Equal([]byte(identity+"\x00"+user+"\x00"+pass), resp) { + t.Errorf("expected response to be: %q, got: %q", identity+"\x00"+user+"\x00"+pass, resp) + } + _, err = auth.Next([]byte("nonsense"), true) + if err == nil { + t.Fatal("expected second server challange to fail") + } + if !errors.Is(err, ErrUnexpectedServerChallange) { + t.Errorf("expected error to be: %s, got: %s", ErrUnexpectedServerChallange, err) + } + }) + t.Run("PLAIN authentication on test server", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + auth := PlainAuth("", "user", "pass", TestServerAddr, false) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to connect to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client connection: %s", err) + } + }) + if err = client.Auth(auth); err != nil { + t.Errorf("failed to authenticate to test server: %s", err) + } + }) + t.Run("PLAIN authentication on test server should fail", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnAuth: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + auth := PlainAuth("", "user", "pass", TestServerAddr, false) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to connect to test server: %s", err) + } + if err = client.Auth(auth); err == nil { + t.Errorf("expected authentication to fail") + } + }) +} + +func TestPlainAuth_noEnc(t *testing.T) { + tests := []struct { + name string + authName string + server *ServerInfo + shouldFail bool + wantErr error + }{ + { + name: "PLAIN-NOENC auth succeeds", + authName: "servername", + server: &ServerInfo{Name: "servername", TLS: true}, + shouldFail: false, + }, + { + // OK to use PlainAuth on localhost without TLS + name: "PLAIN-NOENC on localhost is allowed to go unencrypted", + authName: "localhost", + server: &ServerInfo{Name: "localhost", TLS: false}, + shouldFail: false, + }, + { + // ALSO OK on non-localhost. This auth mode is specificly for that. + name: "PLAIN-NOENC on non-localhost is allowed to go unencrypted", + authName: "servername", + server: &ServerInfo{Name: "servername", Auth: []string{"PLAIN"}}, + shouldFail: false, + }, + { + name: "PLAIN-NOENC on non-localhost with no PLAIN announcement, is allowed to go unencrypted", + authName: "servername", + server: &ServerInfo{Name: "servername", Auth: []string{"CRAM-MD5"}}, + shouldFail: false, + }, + { + name: "PLAIN-NOENC with wrong hostname", + authName: "servername", + server: &ServerInfo{Name: "attacker", TLS: true}, + shouldFail: true, + wantErr: ErrWrongHostname, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + identity := "foo" + user := "toni.tester@example.com" + pass := "v3ryS3Cur3P4ssw0rd" + auth := PlainAuth(identity, user, pass, tt.authName, true) + method, resp, err := auth.Start(tt.server) + if err != nil && !tt.shouldFail { + t.Errorf("plain authentication failed: %s", err) + } + if err == nil && tt.shouldFail { + t.Error("plain authentication was expected to fail") + } + if tt.wantErr != nil { + if !errors.Is(err, tt.wantErr) { + t.Errorf("expected error to be: %s, got: %s", tt.wantErr, err) + } + return + } + if method != "PLAIN" { + t.Errorf("expected method return to be: %q, got: %q", "PLAIN", method) + } + if !bytes.Equal([]byte(identity+"\x00"+user+"\x00"+pass), resp) { + t.Errorf("expected response to be: %q, got: %q", identity+"\x00"+user+"\x00"+pass, resp) + } + }) + } + t.Run("PLAIN-NOENC sends second server response should fail", func(t *testing.T) { + identity := "foo" + user := "toni.tester@example.com" + pass := "v3ryS3Cur3P4ssw0rd" + server := &ServerInfo{Name: "servername", TLS: true} + auth := PlainAuth(identity, user, pass, "servername", true) + method, resp, err := auth.Start(server) + if err != nil { + t.Fatalf("plain authentication failed: %s", err) + } + if method != "PLAIN" { + t.Errorf("expected method return to be: %q, got: %q", "PLAIN", method) + } + if !bytes.Equal([]byte(identity+"\x00"+user+"\x00"+pass), resp) { + t.Errorf("expected response to be: %q, got: %q", identity+"\x00"+user+"\x00"+pass, resp) + } + _, err = auth.Next([]byte("nonsense"), true) + if err == nil { + t.Fatal("expected second server challange to fail") + } + if !errors.Is(err, ErrUnexpectedServerChallange) { + t.Errorf("expected error to be: %s, got: %s", ErrUnexpectedServerChallange, err) + } + }) + t.Run("PLAIN-NOENC authentication on test server", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + auth := PlainAuth("", "user", "pass", TestServerAddr, true) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to connect to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client connection: %s", err) + } + }) + if err = client.Auth(auth); err != nil { + t.Errorf("failed to authenticate to test server: %s", err) + } + }) + t.Run("PLAIN-NOENC authentication on test server should fail", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-AUTH PLAIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnAuth: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + auth := PlainAuth("", "user", "pass", TestServerAddr, true) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to connect to test server: %s", err) + } + if err = client.Auth(auth); err == nil { + t.Errorf("expected authentication to fail") + } + }) +} + +func TestLoginAuth(t *testing.T) { + tests := []struct { + name string + authName string + server *ServerInfo + shouldFail bool + wantErr error + }{ + { + name: "LOGIN auth succeeds", + authName: "servername", + server: &ServerInfo{Name: "servername", TLS: true}, + shouldFail: false, + }, + { + // OK to use PlainAuth on localhost without TLS + name: "LOGIN on localhost is allowed to go unencrypted", + authName: "localhost", + server: &ServerInfo{Name: "localhost", TLS: false}, + shouldFail: false, + }, + { + // NOT OK on non-localhost, even if server says LOGIN is OK. + // (We don't know that the server is the real server.) + name: "LOGIN on non-localhost is not allowed to go unencrypted", + authName: "servername", + server: &ServerInfo{Name: "servername", Auth: []string{"LOGIN"}}, + shouldFail: true, + wantErr: ErrUnencrypted, + }, + { + name: "LOGIN on non-localhost with no LOGIN announcement, is not allowed to go unencrypted", + authName: "servername", + server: &ServerInfo{Name: "servername", Auth: []string{"CRAM-MD5"}}, + shouldFail: true, + wantErr: ErrUnencrypted, + }, + { + name: "LOGIN with wrong hostname", + authName: "servername", + server: &ServerInfo{Name: "attacker", TLS: true}, + shouldFail: true, + wantErr: ErrWrongHostname, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + user := "toni.tester@example.com" + pass := "v3ryS3Cur3P4ssw0rd" + auth := LoginAuth(user, pass, tt.authName, false) + method, _, err := auth.Start(tt.server) + if err != nil && !tt.shouldFail { + t.Errorf("login authentication failed: %s", err) + } + if err == nil && tt.shouldFail { + t.Error("login authentication was expected to fail") + } + if tt.wantErr != nil { + if !errors.Is(err, tt.wantErr) { + t.Errorf("expected error to be: %s, got: %s", tt.wantErr, err) + } + return + } + if method != "LOGIN" { + t.Errorf("expected method return to be: %q, got: %q", "LOGIN", method) + } + resp, err := auth.Next([]byte(user), true) if err != nil { - t.Errorf("#%d success message error: %s", i, err) + t.Errorf("failed on first server challange: %s", err) + } + if !bytes.Equal([]byte(user), resp) { + t.Errorf("expected response to first challange to be: %q, got: %q", user, resp) + } + resp, err = auth.Next([]byte(pass), true) + if err != nil { + t.Errorf("failed on second server challange: %s", err) + } + if !bytes.Equal([]byte(pass), resp) { + t.Errorf("expected response to second challange to be: %q, got: %q", pass, resp) + } + _, err = auth.Next([]byte("nonsense"), true) + if err == nil { + t.Error("expected third server challange to fail, but didn't") + } + if !errors.Is(err, ErrUnexpectedServerResponse) { + t.Errorf("expected error to be: %s, got: %s", ErrUnexpectedServerResponse, err) + } + }) + } + t.Run("LOGIN authentication on test server", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-AUTH LOGIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + auth := LoginAuth("user", "pass", TestServerAddr, false) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to connect to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client connection: %s", err) + } + }) + if err = client.Auth(auth); err != nil { + t.Errorf("failed to authenticate to test server: %s", err) + } + }) + t.Run("LOGIN authentication on test server should fail", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-AUTH LOGIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnAuth: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + auth := LoginAuth("user", "pass", TestServerAddr, false) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to connect to test server: %s", err) + } + if err = client.Auth(auth); err == nil { + t.Errorf("expected authentication to fail") + } + }) +} + +func TestLoginAuth_noEnc(t *testing.T) { + tests := []struct { + name string + authName string + server *ServerInfo + shouldFail bool + wantErr error + }{ + { + name: "LOGIN-NOENC auth succeeds", + authName: "servername", + server: &ServerInfo{Name: "servername", TLS: true}, + shouldFail: false, + }, + { + // OK to use PlainAuth on localhost without TLS + name: "LOGIN-NOENC on localhost is allowed to go unencrypted", + authName: "localhost", + server: &ServerInfo{Name: "localhost", TLS: false}, + shouldFail: false, + }, + { + // ALSO OK on non-localhost. This auth mode is specificly for that. + name: "LOGIN-NOENC on non-localhost is allowed to go unencrypted", + authName: "servername", + server: &ServerInfo{Name: "servername", Auth: []string{"LOGIN"}}, + shouldFail: false, + }, + { + name: "LOGIN-NOENC on non-localhost with no LOGIN announcement, is not allowed to go unencrypted", + authName: "servername", + server: &ServerInfo{Name: "servername", Auth: []string{"CRAM-MD5"}}, + shouldFail: false, + }, + { + name: "LOGIN-NOENC with wrong hostname", + authName: "servername", + server: &ServerInfo{Name: "attacker", TLS: true}, + shouldFail: true, + wantErr: ErrWrongHostname, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + user := "toni.tester@example.com" + pass := "v3ryS3Cur3P4ssw0rd" + auth := LoginAuth(user, pass, tt.authName, true) + method, _, err := auth.Start(tt.server) + if err != nil && !tt.shouldFail { + t.Errorf("login authentication failed: %s", err) + } + if err == nil && tt.shouldFail { + t.Error("login authentication was expected to fail") + } + if tt.wantErr != nil { + if !errors.Is(err, tt.wantErr) { + t.Errorf("expected error to be: %s, got: %s", tt.wantErr, err) + } + return + } + if method != "LOGIN" { + t.Errorf("expected method return to be: %q, got: %q", "LOGIN", method) + } + resp, err := auth.Next([]byte(user), true) + if err != nil { + t.Errorf("failed on first server challange: %s", err) + } + if !bytes.Equal([]byte(user), resp) { + t.Errorf("expected response to first challange to be: %q, got: %q", user, resp) + } + resp, err = auth.Next([]byte(pass), true) + if err != nil { + t.Errorf("failed on second server challange: %s", err) + } + if !bytes.Equal([]byte(pass), resp) { + t.Errorf("expected response to second challange to be: %q, got: %q", pass, resp) + } + _, err = auth.Next([]byte("nonsense"), true) + if err == nil { + t.Error("expected third server challange to fail, but didn't") + } + if !errors.Is(err, ErrUnexpectedServerResponse) { + t.Errorf("expected error to be: %s, got: %s", ErrUnexpectedServerResponse, err) + } + }) + } + t.Run("LOGIN-NOENC authentication on test server", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-AUTH LOGIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + auth := LoginAuth("user", "pass", TestServerAddr, true) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to connect to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client connection: %s", err) + } + }) + if err = client.Auth(auth); err != nil { + t.Errorf("failed to authenticate to test server: %s", err) + } + }) + t.Run("LOGIN-NOENC authentication on test server should fail", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-AUTH LOGIN\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnAuth: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + auth := LoginAuth("user", "pass", TestServerAddr, true) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to connect to test server: %s", err) + } + if err = client.Auth(auth); err == nil { + t.Errorf("expected authentication to fail") + } + }) +} + +func TestXOAuth2Auth(t *testing.T) { + t.Run("XOAuth2 authentication all steps", func(t *testing.T) { + auth := XOAuth2Auth("user", "token") + proto, toserver, err := auth.Start(&ServerInfo{Name: "servername", TLS: true}) + if err != nil { + t.Fatalf("failed to start XOAuth2 authentication: %s", err) + } + if proto != "XOAUTH2" { + t.Errorf("expected protocol to be XOAUTH2, got: %q", proto) + } + expected := []byte("user=user\x01auth=Bearer token\x01\x01") + if !bytes.Equal(expected, toserver) { + t.Errorf("expected server response to be: %q, got: %q", expected, toserver) + } + resp, err := auth.Next([]byte("nonsense"), true) + if err != nil { + t.Errorf("failed on first server challange: %s", err) + } + if !bytes.Equal([]byte(""), resp) { + t.Errorf("expected server response to be empty, got: %q", resp) + } + _, err = auth.Next([]byte("nonsense"), false) + if err != nil { + t.Errorf("failed on first server challange: %s", err) + } + }) + t.Run("XOAuth2 succeeds with faker", func(t *testing.T) { + server := []string{ + "220 Fake server ready ESMTP", + "250-fake.server", + "250-AUTH XOAUTH2", + "250 8BITMIME", + "235 2.7.0 Accepted", + } + var wrote strings.Builder + var fake faker + fake.ReadWriter = struct { + io.Reader + io.Writer + }{ + strings.NewReader(strings.Join(server, "\r\n")), + &wrote, + } + client, err := NewClient(fake, "fake.host") + if err != nil { + t.Fatalf("failed to create client on faker server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client connection: %s", err) + } + }) + + auth := XOAuth2Auth("user", "token") + if err = client.Auth(auth); err != nil { + t.Errorf("failed to authenticate to faker server: %s", err) + } + + // the Next method returns a nil response. It must not be sent. + // The client request must end with the authentication. + if !strings.HasSuffix(wrote.String(), "AUTH XOAUTH2 dXNlcj11c2VyAWF1dGg9QmVhcmVyIHRva2VuAQE=\r\n") { + t.Fatalf("got %q; want AUTH XOAUTH2 dXNlcj11c2VyAWF1dGg9QmVhcmVyIHRva2VuAQE=\r\n", wrote.String()) + } + }) + t.Run("XOAuth2 fails with faker", func(t *testing.T) { + serverResp := []string{ + "220 Fake server ready ESMTP", + "250-fake.server", + "250-AUTH XOAUTH2", + "250 8BITMIME", + "334 eyJzdGF0dXMiOiI0MDAiLCJzY2hlbWVzIjoiQmVhcmVyIiwic2NvcGUiOiJodHRwczovL21haWwuZ29vZ2xlLmNvbS8ifQ==", + "535 5.7.8 Username and Password not accepted", + "221 2.0.0 closing connection", + } + var wrote strings.Builder + var fake faker + fake.ReadWriter = struct { + io.Reader + io.Writer + }{ + strings.NewReader(strings.Join(serverResp, "\r\n")), + &wrote, + } + client, err := NewClient(fake, "fake.host") + if err != nil { + t.Fatalf("failed to create client on faker server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client connection: %s", err) + } + }) + + auth := XOAuth2Auth("user", "token") + if err = client.Auth(auth); err == nil { + t.Errorf("expected authentication to fail") + } + resp := strings.Split(wrote.String(), "\r\n") + if len(resp) != 5 { + t.Fatalf("unexpected number of client requests got %d; want 5", len(resp)) + } + if resp[1] != "AUTH XOAUTH2 dXNlcj11c2VyAWF1dGg9QmVhcmVyIHRva2VuAQE=" { + t.Fatalf("got %q; want AUTH XOAUTH2 dXNlcj11c2VyAWF1dGg9QmVhcmVyIHRva2VuAQE=", resp[1]) + } + // the Next method returns an empty response. It must be sent + if resp[2] != "" { + t.Fatalf("got %q; want empty response", resp[2]) + } + }) + t.Run("XOAuth2 authentication on test server succeeds", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-AUTH XOAUTH2\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + auth := XOAuth2Auth("user", "token") + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to connect to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client connection: %s", err) + } + }) + if err = client.Auth(auth); err != nil { + t.Errorf("failed to authenticate to test server: %s", err) + } + }) + t.Run("XOAuth2 authentication on test server fails", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-AUTH XOAUTH2\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnAuth: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + auth := XOAuth2Auth("user", "token") + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to connect to test server: %s", err) + } + if err = client.Auth(auth); err == nil { + t.Errorf("expected authentication to fail") + } + }) +} + +func TestScramAuth(t *testing.T) { + tests := []struct { + name string + tls bool + authString string + hash func() hash.Hash + isPlus bool + }{ + {"SCRAM-SHA-1 (no TLS)", false, "SCRAM-SHA-1", sha1.New, false}, + {"SCRAM-SHA-256 (no TLS)", false, "SCRAM-SHA-256", sha256.New, false}, + {"SCRAM-SHA-1 (with TLS)", true, "SCRAM-SHA-1", sha1.New, false}, + {"SCRAM-SHA-256 (with TLS)", true, "SCRAM-SHA-256", sha256.New, false}, + {"SCRAM-SHA-1-PLUS", true, "SCRAM-SHA-1-PLUS", sha1.New, true}, + {"SCRAM-SHA-256-PLUS", true, "SCRAM-SHA-256-PLUS", sha256.New, true}, + } + for _, tt := range tests { + t.Run(tt.name+" succeeds on test server", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := fmt.Sprintf("250-AUTH %s\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8", tt.authString) + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + TestSCRAM: true, + HashFunc: tt.hash, + FeatureSet: featureSet, + ListenPort: serverPort, + SSLListener: tt.tls, + IsSCRAMPlus: tt.isPlus, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + var client *Client + switch tt.tls { + case true: + tlsConfig := getTLSConfig(t) + conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", TestServerAddr, serverPort), tlsConfig) + if err != nil { + t.Fatalf("failed to dial TLS server: %v", err) + } + client, err = NewClient(conn, TestServerAddr) + if err != nil { + t.Fatalf("failed to connect to test server: %s", err) + } + case false: + var err error + client, err = Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to connect to test server: %s", err) + } + } + t.Cleanup(func() { + if err := client.Close(); err != nil { + t.Errorf("failed to close client connection: %s", err) + } + }) + + var auth Auth + switch tt.authString { + case "SCRAM-SHA-1": + auth = ScramSHA1Auth("username", "password") + case "SCRAM-SHA-256": + auth = ScramSHA256Auth("username", "password") + case "SCRAM-SHA-1-PLUS": + tlsConnState, err := client.GetTLSConnectionState() + if err != nil { + t.Fatalf("failed to get TLS connection state: %s", err) + } + auth = ScramSHA1PlusAuth("username", "password", tlsConnState) + case "SCRAM-SHA-256-PLUS": + tlsConnState, err := client.GetTLSConnectionState() + if err != nil { + t.Fatalf("failed to get TLS connection state: %s", err) + } + auth = ScramSHA256PlusAuth("username", "password", tlsConnState) + default: + t.Fatalf("unexpected auth string: %s", tt.authString) + } + if err := client.Auth(auth); err != nil { + t.Errorf("failed to authenticate to test server: %s", err) + } + }) + t.Run(tt.name+" fails on test server", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := fmt.Sprintf("250-AUTH %s\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8", tt.authString) + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + TestSCRAM: true, + HashFunc: tt.hash, + FeatureSet: featureSet, + ListenPort: serverPort, + SSLListener: tt.tls, + IsSCRAMPlus: tt.isPlus, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + var client *Client + switch tt.tls { + case true: + tlsConfig := getTLSConfig(t) + conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", TestServerAddr, serverPort), tlsConfig) + if err != nil { + t.Fatalf("failed to dial TLS server: %v", err) + } + client, err = NewClient(conn, TestServerAddr) + if err != nil { + t.Fatalf("failed to connect to test server: %s", err) + } + case false: + var err error + client, err = Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to connect to test server: %s", err) + } + } + + var auth Auth + switch tt.authString { + case "SCRAM-SHA-1": + auth = ScramSHA1Auth("invalid", "password") + case "SCRAM-SHA-256": + auth = ScramSHA256Auth("invalid", "password") + case "SCRAM-SHA-1-PLUS": + tlsConnState, err := client.GetTLSConnectionState() + if err != nil { + t.Fatalf("failed to get TLS connection state: %s", err) + } + auth = ScramSHA1PlusAuth("invalid", "password", tlsConnState) + case "SCRAM-SHA-256-PLUS": + tlsConnState, err := client.GetTLSConnectionState() + if err != nil { + t.Fatalf("failed to get TLS connection state: %s", err) + } + auth = ScramSHA256PlusAuth("invalid", "password", tlsConnState) + default: + t.Fatalf("unexpected auth string: %s", tt.authString) + } + if err := client.Auth(auth); err == nil { + t.Error("expected authentication to fail") + } + }) + } + t.Run("ScramAuth_Next with nonsense parameter", func(t *testing.T) { + auth := ScramSHA1Auth("username", "password") + _, err := auth.Next([]byte("x=nonsense"), true) + if err == nil { + t.Fatal("expected authentication to fail") + } + if !errors.Is(err, ErrUnexpectedServerResponse) { + t.Errorf("expected ErrUnexpectedServerResponse, got %s", err) + } + }) +} + +func TestScramAuth_normalizeString(t *testing.T) { + t.Run("normalizeString with invalid input should fail", func(t *testing.T) { + auth := scramAuth{} + value := "\u0000example\uFFFEstring\u001F" + _, err := auth.normalizeString(value) + if err == nil { + t.Fatal("normalizeString should fail on disallowed runes") + } + if !strings.Contains(err.Error(), "precis: disallowed rune encountered") { + t.Errorf("expected error to be %q, got %q", "precis: disallowed rune encountered", err) + } + }) + t.Run("normalizeString on empty string should fail", func(t *testing.T) { + auth := scramAuth{} + _, err := auth.normalizeString("") + if err == nil { + t.Error("normalizeString should fail on disallowed runes") + } + if !strings.Contains(err.Error(), "precis: transformation resulted in empty string") { + t.Errorf("expected error to be %q, got %q", "precis: transformation resulted in empty string", err) + } + }) + t.Run("normalizeUsername with invalid input should fail", func(t *testing.T) { + auth := scramAuth{username: "\u0000example\uFFFEstring\u001F"} + _, err := auth.normalizeUsername() + if err == nil { + t.Error("normalizeUsername should fail on disallowed runes") + } + if !strings.Contains(err.Error(), "precis: disallowed rune encountered") { + t.Errorf("expected error to be %q, got %q", "precis: disallowed rune encountered", err) + } + }) + t.Run("normalizeUsername with empty input should fail", func(t *testing.T) { + auth := scramAuth{username: ""} + _, err := auth.normalizeUsername() + if err == nil { + t.Error("normalizeUsername should fail on empty input") + } + if !strings.Contains(err.Error(), "precis: transformation resulted in empty string") { + t.Errorf("expected error to be %q, got %q", "precis: transformation resulted in empty string", err) + } + }) +} + +func TestScramAuth_initialClientMessage(t *testing.T) { + t.Run("initialClientMessage with invalid username should fail", func(t *testing.T) { + auth := scramAuth{username: "\u0000example\uFFFEstring\u001F"} + _, err := auth.initialClientMessage() + if err == nil { + t.Error("initialClientMessage should fail on disallowed runes") + } + if !strings.Contains(err.Error(), "precis: disallowed rune encountered") { + t.Errorf("expected error to be %q, got %q", "precis: disallowed rune encountered", err) + } + }) + t.Run("initialClientMessage with empty username should fail", func(t *testing.T) { + auth := scramAuth{username: ""} + _, err := auth.initialClientMessage() + if err == nil { + t.Error("initialClientMessage should fail on empty username") + } + if !strings.Contains(err.Error(), "precis: transformation resulted in empty string") { + t.Errorf("expected error to be %q, got %q", "precis: transformation resulted in empty string", err) + } + }) + t.Run("initialClientMessage fails on broken rand.Reader", func(t *testing.T) { + defaultRandReader := rand.Reader + t.Cleanup(func() { rand.Reader = defaultRandReader }) + rand.Reader = &randReader{} + auth := scramAuth{username: "username"} + _, err := auth.initialClientMessage() + if err == nil { + t.Error("initialClientMessage should fail with broken rand.Reader") + } + if !strings.Contains(err.Error(), "unable to generate client secret: broken reader") { + t.Errorf("expected error to be %q, got %q", "unable to generate client secret: broken reader", err) + } + }) +} + +func TestScramAuth_handleServerFirstResponse(t *testing.T) { + t.Run("handleServerFirstResponse fails if not at least 3 parts", func(t *testing.T) { + auth := scramAuth{} + _, err := auth.handleServerFirstResponse([]byte("r=0")) + if err == nil { + t.Error("handleServerFirstResponse should fail on invalid response") + } + expectedErr := "not enough fields in the first server response" + if !strings.EqualFold(err.Error(), expectedErr) { + t.Errorf("expected error to be %q, got %q", expectedErr, err) + } + }) + t.Run("handleServerFirstResponse fails with first part does not start with r=", func(t *testing.T) { + auth := scramAuth{} + _, err := auth.handleServerFirstResponse([]byte("x=0,y=0,z=0,r=0")) + if err == nil { + t.Error("handleServerFirstResponse should fail on invalid response") + } + expectedErr := "first part of the server response does not start with r=" + if !strings.EqualFold(err.Error(), expectedErr) { + t.Errorf("expected error to be %q, got %q", expectedErr, err) + } + }) + t.Run("handleServerFirstResponse fails with second part does not start with s=", func(t *testing.T) { + auth := scramAuth{} + _, err := auth.handleServerFirstResponse([]byte("r=0,x=0,y=0,z=0")) + if err == nil { + t.Error("handleServerFirstResponse should fail on invalid response") + } + expectedErr := "second part of the server response does not start with s=" + if !strings.EqualFold(err.Error(), expectedErr) { + t.Errorf("expected error to be %q, got %q", expectedErr, err) + } + }) + t.Run("handleServerFirstResponse fails with third part does not start with i=", func(t *testing.T) { + auth := scramAuth{} + _, err := auth.handleServerFirstResponse([]byte("r=0,s=0,y=0,z=0")) + if err == nil { + t.Error("handleServerFirstResponse should fail on invalid response") + } + expectedErr := "third part of the server response does not start with i=" + if !strings.EqualFold(err.Error(), expectedErr) { + t.Errorf("expected error to be %q, got %q", expectedErr, err) + } + }) + t.Run("handleServerFirstResponse fails with empty nonce", func(t *testing.T) { + auth := scramAuth{} + _, err := auth.handleServerFirstResponse([]byte("r=,s=0,i=0")) + if err == nil { + t.Error("handleServerFirstResponse should fail on invalid response") + } + expectedErr := "server nonce does not start with our nonce" + if !strings.EqualFold(err.Error(), expectedErr) { + t.Errorf("expected error to be %q, got %q", expectedErr, err) + } + }) + t.Run("handleServerFirstResponse fails with non-base64 nonce", func(t *testing.T) { + auth := scramAuth{nonce: []byte("Test123")} + _, err := auth.handleServerFirstResponse([]byte("r=Test123,s=0,i=0")) + if err == nil { + t.Error("handleServerFirstResponse should fail on invalid response") + } + expectedErr := "illegal base64 data at input byte 0" + if !strings.Contains(err.Error(), expectedErr) { + t.Errorf("expected error to be %q, got %q", expectedErr, err) + } + }) + t.Run("handleServerFirstResponse fails with non-number iterations", func(t *testing.T) { + auth := scramAuth{nonce: []byte("VGVzdDEyMw==")} + _, err := auth.handleServerFirstResponse([]byte("r=VGVzdDEyMw==,s=VGVzdDEyMw==,i=abc")) + if err == nil { + t.Error("handleServerFirstResponse should fail on invalid response") + } + expectedErr := `invalid iterations: strconv.Atoi: parsing "abc": invalid syntax` + if !strings.Contains(err.Error(), expectedErr) { + t.Errorf("expected error to be %q, got %q", expectedErr, err) + } + }) + t.Run("handleServerFirstResponse fails with invalid password runes", func(t *testing.T) { + auth := scramAuth{ + nonce: []byte("VGVzdDEyMw=="), + username: "username", + password: "\u0000example\uFFFEstring\u001F", + } + _, err := auth.handleServerFirstResponse([]byte("r=VGVzdDEyMw==,s=VGVzdDEyMw==,i=0")) + if err == nil { + t.Error("handleServerFirstResponse should fail on invalid response") + } + expectedErr := `unable to normalize password: failed to normalize string: precis: disallowed rune encountered` + if !strings.Contains(err.Error(), expectedErr) { + t.Errorf("expected error to be %q, got %q", expectedErr, err) + } + }) +} + +func TestCRAMMD5Auth(t *testing.T) { + t.Run("CRAM-MD5 on test server succeeds", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-AUTH CRAM-MD5\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + auth := CRAMMD5Auth("username", "password") + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + if err = client.Auth(auth); err != nil { + t.Errorf("failed to auth to test server: %s", err) + } + }) + t.Run("CRAM-MD5 on test server fails", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-AUTH CRAM-MD5\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnAuth: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + auth := CRAMMD5Auth("username", "password") + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + if err = client.Auth(auth); err == nil { + t.Error("auth should fail on test server") + } + }) +} + +func TestNewClient(t *testing.T) { + t.Run("new client via Dial succeeds", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to create client: %s", err) + } + if err := client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + t.Run("new client via Dial fails on server not started", func(t *testing.T) { + _, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, 64000)) + if err == nil { + t.Error("dial on non-existent server should fail") + } + }) + t.Run("new client fails on server not available", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnDial: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + _, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err == nil { + t.Error("connection to non-available server should fail") + } + }) + t.Run("new client fails on faker that fails on close", func(t *testing.T) { + server := "442 service not available\r\n" + var wrote strings.Builder + var fake faker + fake.failOnClose = true + fake.ReadWriter = struct { + io.Reader + io.Writer + }{ + strings.NewReader(server), + &wrote, + } + _, err := NewClient(fake, "faker.host") + if err == nil { + t.Error("connection to non-available server should fail on close") + } + }) +} + +func TestClient_hello(t *testing.T) { + t.Run("client fails on EHLO but not on HELO", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnEhlo: true, + FailOnHelo: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + if err = client.hello(); err == nil { + t.Error("helo should fail on test server") + } + }) +} + +func TestClient_Hello(t *testing.T) { + t.Run("normal client HELO/EHLO", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + if err = client.Hello(TestServerAddr); err != nil { + t.Errorf("failed to send HELO/EHLO to test server: %s", err) + } + }) + t.Run("client HELO/EHLO with empty name should fail", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + if err = client.Hello(""); err == nil { + t.Error("HELO/EHLO with empty name should fail") + } + }) + t.Run("client HELO/EHLO with newline in name should fail", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Errorf("failed to dial to test server: %s", err) + } + if err = client.Hello(TestServerAddr + "\r\n"); err == nil { + t.Error("HELO/EHLO with newline should fail") + } + }) + t.Run("client double HELO/EHLO should fail", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + if err = client.Hello(TestServerAddr); err != nil { + t.Errorf("failed to send HELO/EHLO to test server: %s", err) + } + if err = client.Hello(TestServerAddr); err == nil { + t.Error("double HELO/EHLO should fail") + } + }) +} + +func TestClient_cmd(t *testing.T) { + t.Run("cmd fails on textproto cmd", func(t *testing.T) { + server := "220 server ready\r\n" + var fake faker + fake.failOnClose = true + fake.ReadWriter = struct { + io.Reader + io.Writer + }{ + strings.NewReader(server), + &failWriter{}, + } + client, err := NewClient(fake, "faker.host") + if err != nil { + t.Errorf("failed to create client: %s", err) + } + _, _, err = client.cmd(250, "HELO faker.host") + if err == nil { + t.Error("cmd should fail on textproto cmd with broken writer") + } + }) +} + +func TestClient_StartTLS(t *testing.T) { + t.Run("normal STARTTLS should succeed", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-STARTTLS\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + tlsConfig := getTLSConfig(t) + if err = client.StartTLS(tlsConfig); err != nil { + t.Errorf("failed to initialize STARTTLS session: %s", err) + } + }) + t.Run("STARTTLS fails on EHLO/HELO", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-STARTTLS\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnEhlo: true, + FailOnHelo: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + tlsConfig := getTLSConfig(t) + if err = client.StartTLS(tlsConfig); err == nil { + t.Error("STARTTLS should fail on EHLO") + } + }) + t.Run("STARTTLS fails on server not supporting STARTTLS", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnSTARTTLS: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + tlsConfig := getTLSConfig(t) + if err = client.StartTLS(tlsConfig); err == nil { + t.Error("STARTTLS should fail for server not supporting it") + } + }) +} + +func TestClient_TLSConnectionState(t *testing.T) { + t.Run("normal TLS connection should return a state", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-STARTTLS\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + tlsConfig := getTLSConfig(t) + tlsConfig.MinVersion = tls.VersionTLS12 + if err = client.StartTLS(tlsConfig); err != nil { + t.Errorf("failed to initialize STARTTLS session: %s", err) + } + state, ok := client.TLSConnectionState() + if !ok { + t.Errorf("failed to get TLS connection state") + } + if state.Version < tls.VersionTLS12 { + t.Errorf("TLS connection state version is %d, should be >= %d", state.Version, tls.VersionTLS12) + } + }) + t.Run("no TLS state on non-TLS connection", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-STARTTLS\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + _, ok := client.TLSConnectionState() + if ok { + t.Error("non-TLS connection should not have TLS connection state") + } + }) +} + +func TestClient_Verify(t *testing.T) { + t.Run("Verify on existing user succeeds", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-STARTTLS\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if err = client.Verify("toni.tester@example.com"); err != nil { + t.Errorf("failed to verify user: %s", err) + } + }) + t.Run("Verify on non-existing user fails", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-STARTTLS\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + VRFYUserUnknown: true, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if err = client.Verify("toni.tester@example.com"); err == nil { + t.Error("verify on non-existing user should fail") + } + }) + t.Run("Verify with newlines should fails", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-STARTTLS\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if err = client.Verify("toni.tester@example.com\r\n"); err == nil { + t.Error("verify with new lines should fail") + } + }) + t.Run("Verify should fail on HELO/EHLO", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-STARTTLS\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnEhlo: true, + FailOnHelo: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if err = client.Verify("toni.tester@example.com"); err == nil { + t.Error("verify with new lines should fail") + } + }) +} + +func TestClient_Auth(t *testing.T) { + t.Run("Auth fails on EHLO/HELO", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-STARTTLS\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnEhlo: true, + FailOnHelo: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + auth := LoginAuth("username", "password", TestServerAddr, false) + if err = client.Auth(auth); err == nil { + t.Error("auth should fail on EHLO/HELO") + } + }) + t.Run("Auth fails on auth-start", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-STARTTLS\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnAuth: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + auth := LoginAuth("username", "password", "not.localhost.com", false) + if err = client.Auth(auth); err == nil { + t.Error("auth should fail on auth-start, then on quit") + } + expErr := "wrong host name" + if !strings.EqualFold(expErr, err.Error()) { + t.Errorf("expected error: %q, got: %q", expErr, err.Error()) + } + }) + t.Run("Auth fails on auth-start and then on quit", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-STARTTLS\r\n250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnAuth: true, + FailOnQuit: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + auth := LoginAuth("username", "password", "not.localhost.com", false) + if err = client.Auth(auth); err == nil { + t.Error("auth should fail on auth-start, then on quit") + } + expErr := "wrong host name, 500 5.1.2 Error: quit failed" + if !strings.EqualFold(expErr, err.Error()) { + t.Errorf("expected error: %q, got: %q", expErr, err.Error()) + } + }) + // Issue 17794: don't send a trailing space on AUTH command when there's no password. + t.Run("No trailing space on AUTH when there is no password (Issue 17794)", func(t *testing.T) { + server := "220 hello world\r\n" + + "200 some more" + var wrote strings.Builder + var fake faker + fake.ReadWriter = struct { + io.Reader + io.Writer + }{ + strings.NewReader(server), + &wrote, + } + c, err := NewClient(fake, "fake.host") + if err != nil { + t.Fatalf("NewClient: %v", err) + } + c.tls = true + c.didHello = true + _ = c.Auth(toServerEmptyAuth{}) + if err = c.Close(); err != nil { + t.Errorf("close failed: %s", err) + } + if got, want := wrote.String(), "AUTH FOOAUTH\r\n*\r\nQUIT\r\n"; got != want { + t.Errorf("wrote %q; want %q", got, want) + } + }) +} + +func TestClient_Mail(t *testing.T) { + t.Run("normal from address succeeds", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if err = client.Mail("valid-from@domain.tld"); err != nil { + t.Errorf("failed to set mail from address: %s", err) + } + }) + t.Run("from address with new lines fails", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if err = client.Mail("valid-from@domain.tld\r\n"); err == nil { + t.Error("mail from address with new lines should fail") + } + }) + t.Run("from address fails on EHLO/HELO", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnEhlo: true, + FailOnHelo: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if err = client.Mail("valid-from@domain.tld"); err == nil { + t.Error("mail from address should fail on EHLO/HELO") + } + }) + t.Run("from address and server supports 8BITMIME", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-8BITMIME\r\n250 STARTTLS" + echoBuffer := bytes.NewBuffer(nil) + props := &serverProps{ + EchoBuffer: echoBuffer, + FeatureSet: featureSet, + ListenPort: serverPort, + } + go func() { + if err := simpleSMTPServer(ctx, t, props); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if err = client.Mail("valid-from@domain.tld"); err != nil { + t.Errorf("failed to set mail from address: %s", err) + } + expected := "MAIL FROM: BODY=8BITMIME" + props.BufferMutex.RLock() + resp := strings.Split(echoBuffer.String(), "\r\n") + props.BufferMutex.RUnlock() + if !strings.EqualFold(resp[5], expected) { + t.Errorf("expected mail from command to be %q, but sent %q", expected, resp[5]) + } + }) + t.Run("from address and server supports SMTPUTF8", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-SMTPUTF8\r\n250 STARTTLS" + echoBuffer := bytes.NewBuffer(nil) + props := &serverProps{ + EchoBuffer: echoBuffer, + FeatureSet: featureSet, + ListenPort: serverPort, + } + go func() { + if err := simpleSMTPServer(ctx, t, props); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if err = client.Mail("valid-from@domain.tld"); err != nil { + t.Errorf("failed to set mail from address: %s", err) + } + expected := "MAIL FROM: SMTPUTF8" + props.BufferMutex.RLock() + resp := strings.Split(echoBuffer.String(), "\r\n") + props.BufferMutex.RUnlock() + if !strings.EqualFold(resp[5], expected) { + t.Errorf("expected mail from command to be %q, but sent %q", expected, resp[5]) + } + }) + t.Run("from address and server supports SMTPUTF8 with unicode address", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-SMTPUTF8\r\n250 STARTTLS" + echoBuffer := bytes.NewBuffer(nil) + props := &serverProps{ + EchoBuffer: echoBuffer, + FeatureSet: featureSet, + ListenPort: serverPort, + } + go func() { + if err := simpleSMTPServer(ctx, t, props); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if err = client.Mail("valid-from+📧@domain.tld"); err != nil { + t.Errorf("failed to set mail from address: %s", err) + } + expected := "MAIL FROM: SMTPUTF8" + props.BufferMutex.RLock() + resp := strings.Split(echoBuffer.String(), "\r\n") + props.BufferMutex.RUnlock() + if !strings.EqualFold(resp[5], expected) { + t.Errorf("expected mail from command to be %q, but sent %q", expected, resp[5]) + } + }) + t.Run("from address and server supports DSN", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250 STARTTLS" + echoBuffer := bytes.NewBuffer(nil) + props := &serverProps{ + EchoBuffer: echoBuffer, + FeatureSet: featureSet, + ListenPort: serverPort, + } + go func() { + if err := simpleSMTPServer(ctx, t, props); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + client.dsnmrtype = "FULL" + if err = client.Mail("valid-from@domain.tld"); err != nil { + t.Errorf("failed to set mail from address: %s", err) + } + expected := "MAIL FROM: RET=FULL" + props.BufferMutex.RLock() + resp := strings.Split(echoBuffer.String(), "\r\n") + props.BufferMutex.RUnlock() + if !strings.EqualFold(resp[5], expected) { + t.Errorf("expected mail from command to be %q, but sent %q", expected, resp[5]) + } + }) + t.Run("from address and server supports DSN, SMTPUTF8 and 8BITMIME", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250-8BITMIME\r\n250-SMTPUTF8\r\n250 STARTTLS" + echoBuffer := bytes.NewBuffer(nil) + props := &serverProps{ + EchoBuffer: echoBuffer, + FeatureSet: featureSet, + ListenPort: serverPort, + } + go func() { + if err := simpleSMTPServer(ctx, t, props); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + client.dsnmrtype = "FULL" + if err = client.Mail("valid-from@domain.tld"); err != nil { + t.Errorf("failed to set mail from address: %s", err) + } + expected := "MAIL FROM: BODY=8BITMIME SMTPUTF8 RET=FULL" + props.BufferMutex.RLock() + resp := strings.Split(echoBuffer.String(), "\r\n") + props.BufferMutex.RUnlock() + if !strings.EqualFold(resp[7], expected) { + t.Errorf("expected mail from command to be %q, but sent %q", expected, resp[7]) + } + }) +} + +func TestClient_Rcpt(t *testing.T) { + t.Run("normal recipient address succeeds", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250-8BITMIME\r\n250-SMTPUTF8\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if err = client.Rcpt("valid-to@domain.tld"); err != nil { + t.Errorf("failed to set recipient address: %s", err) + } + }) + t.Run("recipient address with newlines fails", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if err = client.Rcpt("valid-to@domain.tld\r\n"); err == nil { + t.Error("recpient address with newlines should fail") + } + }) + t.Run("recipient address with DSN", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250 STARTTLS" + echoBuffer := bytes.NewBuffer(nil) + props := &serverProps{ + EchoBuffer: echoBuffer, + FeatureSet: featureSet, + ListenPort: serverPort, + } + go func() { + if err := simpleSMTPServer(ctx, t, props); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if err = client.Hello(TestServerAddr); err != nil { + t.Fatalf("failed to send hello to test server: %s", err) + } + client.dsnrntype = "SUCCESS" + if err = client.Rcpt("valid-to@domain.tld"); err == nil { + t.Error("recpient address with newlines should fail") + } + expected := "RCPT TO: NOTIFY=SUCCESS" + props.BufferMutex.RLock() + resp := strings.Split(echoBuffer.String(), "\r\n") + props.BufferMutex.RUnlock() + if !strings.EqualFold(resp[5], expected) { + t.Errorf("expected rcpt to command to be %q, but sent %q", expected, resp[5]) + } + }) +} + +func TestClient_Data(t *testing.T) { + t.Run("normal mail data transmission succeeds", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + writer, err := client.Data() + if err != nil { + t.Fatalf("failed to create data writer: %s", err) + } + t.Cleanup(func() { + if err = writer.Close(); err != nil { + t.Errorf("failed to close data writer: %s", err) + } + }) + if _, err = writer.Write([]byte("test message")); err != nil { + t.Errorf("failed to write data to test server: %s", err) + } + }) + t.Run("mail data transmission fails on DATA command", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnDataInit: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if _, err = client.Data(); err == nil { + t.Error("expected data writer to fail") + } + }) +} + +func TestSendMail(t *testing.T) { + tests := []struct { + name string + featureSet string + hostname string + tlsConfig *tls.Config + props *serverProps + fromAddr string + toAddr string + message []byte + }{ + { + "fail on newline in MAIL FROM address", + "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS", + TestServerAddr, + getTLSConfig(t), + &serverProps{}, + "valid-from@domain.tld\r\n", + "valid-to@domain.tld", + []byte("test message"), + }, + { + "fail on newline in RCPT TO address", + "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS", + TestServerAddr, + getTLSConfig(t), + &serverProps{}, + "valid-from@domain.tld", + "valid-to@domain.tld\r\n", + []byte("test message"), + }, + { + "fail on invalid host address", + "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS", + "invalid.invalid-host@domain.tld", + getTLSConfig(t), + &serverProps{}, + "valid-from@domain.tld", + "valid-to@domain.tld", + []byte("test message"), + }, + { + "fail on EHLO/HELO", + "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS", + TestServerAddr, + getTLSConfig(t), + &serverProps{FailOnEhlo: true, FailOnHelo: true}, + "valid-from@domain.tld", + "valid-to@domain.tld", + []byte("test message"), + }, + { + "fail on STARTTLS", + "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS", + TestServerAddr, + &tls.Config{ServerName: "invalid.invalid-host@domain.tld"}, + &serverProps{}, + "valid-from@domain.tld", + "valid-to@domain.tld", + []byte("test message"), + }, + { + "fail on no server AUTH support", + "250-DSN\r\n250 STARTTLS", + TestServerAddr, + getTLSConfig(t), + &serverProps{}, + "valid-from@domain.tld", + "valid-to@domain.tld", + []byte("test message"), + }, + { + "fail on AUTH", + "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS", + TestServerAddr, + getTLSConfig(t), + &serverProps{FailOnAuth: true}, + "valid-from@domain.tld", + "valid-to@domain.tld", + []byte("test message"), + }, + { + "fail on MAIL FROM", + "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS", + TestServerAddr, + getTLSConfig(t), + &serverProps{FailOnMailFrom: true}, + "valid-from@domain.tld", + "valid-to@domain.tld", + []byte("test message"), + }, + { + "fail on RCPT TO", + "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS", + TestServerAddr, + getTLSConfig(t), + &serverProps{FailOnRcptTo: true}, + "valid-from@domain.tld", + "valid-to@domain.tld", + []byte("test message"), + }, + { + "fail on DATA (init phase)", + "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS", + TestServerAddr, + getTLSConfig(t), + &serverProps{FailOnDataInit: true}, + "valid-from@domain.tld", + "valid-to@domain.tld", + []byte("test message"), + }, + { + "fail on DATA (closing phase)", + "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS", + TestServerAddr, + getTLSConfig(t), + &serverProps{FailOnDataClose: true}, + "valid-from@domain.tld", + "valid-to@domain.tld", + []byte("test message"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + tt.props.ListenPort = int(TestServerPortBase + PortAdder.Load()) + tt.props.FeatureSet = tt.featureSet + go func() { + if err := simpleSMTPServer(ctx, t, tt.props); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + addr := fmt.Sprintf("%s:%d", tt.hostname, tt.props.ListenPort) + testHookStartTLS = func(config *tls.Config) { + config.ServerName = tt.tlsConfig.ServerName + config.RootCAs = tt.tlsConfig.RootCAs + config.Certificates = tt.tlsConfig.Certificates + } + auth := LoginAuth("username", "password", TestServerAddr, false) + if err := SendMail(addr, auth, tt.fromAddr, []string{tt.toAddr}, tt.message); err == nil { + t.Error("expected SendMail to " + tt.name) + } + }) + } + t.Run("full SendMail transaction with TLS and auth", func(t *testing.T) { + want := []string{ + "220 go-mail test server ready ESMTP", + "EHLO localhost", + "250-localhost.localdomain", + "250-AUTH LOGIN", + "250-DSN", + "250 STARTTLS", + "STARTTLS", + "220 Ready to start TLS", + "EHLO localhost", + "250-localhost.localdomain", + "250-AUTH LOGIN", + "250-DSN", + "250 STARTTLS", + "AUTH LOGIN", + "235 2.7.0 Authentication successful", + "MAIL FROM:", + "250 2.0.0 OK", + "RCPT TO:", + "250 2.0.0 OK", + "DATA", + "354 End data with .", + "test message", + ".", + "250 2.0.0 Ok: queued as 1234567890", + "QUIT", + "221 2.0.0 Bye", + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS" + echoBuffer := bytes.NewBuffer(nil) + props := &serverProps{ + EchoBuffer: echoBuffer, + FeatureSet: featureSet, + ListenPort: serverPort, + } + go func() { + if err := simpleSMTPServer(ctx, t, props); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + addr := fmt.Sprintf("%s:%d", TestServerAddr, serverPort) + testHookStartTLS = func(config *tls.Config) { + testConfig := getTLSConfig(t) + config.ServerName = testConfig.ServerName + config.RootCAs = testConfig.RootCAs + config.Certificates = testConfig.Certificates + } + auth := LoginAuth("username", "password", TestServerAddr, false) + if err := SendMail(addr, auth, "valid-from@domain.tld", []string{"valid-to@domain.tld"}, + []byte("test message")); err != nil { + t.Fatalf("failed to send mail: %s", err) + } + props.BufferMutex.RLock() + resp := strings.Split(echoBuffer.String(), "\r\n") + props.BufferMutex.RUnlock() + if len(resp)-1 != len(want) { + t.Fatalf("expected %d lines, but got %d", len(want), len(resp)) + } + for i := 0; i < len(want); i++ { + if !strings.EqualFold(resp[i], want[i]) { + t.Errorf("expected line %d to be %q, but got %q", i, resp[i], want[i]) } } - } + }) + t.Run("full SendMail transaction with leading dots", func(t *testing.T) { + want := []string{ + "220 go-mail test server ready ESMTP", + "EHLO localhost", + "250-localhost.localdomain", + "250-AUTH LOGIN", + "250-DSN", + "250 STARTTLS", + "STARTTLS", + "220 Ready to start TLS", + "EHLO localhost", + "250-localhost.localdomain", + "250-AUTH LOGIN", + "250-DSN", + "250 STARTTLS", + "AUTH LOGIN", + "235 2.7.0 Authentication successful", + "MAIL FROM:", + "250 2.0.0 OK", + "RCPT TO:", + "250 2.0.0 OK", + "DATA", + "354 End data with .", + "From: user@gmail.com", + "To: golang-nuts@googlegroups.com", + "Subject: Hooray for Go", + "", + "Line 1", + "..Leading dot line .", + "Goodbye.", + ".", + "250 2.0.0 Ok: queued as 1234567890", + "QUIT", + "221 2.0.0 Bye", + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-AUTH LOGIN\r\n250-DSN\r\n250 STARTTLS" + echoBuffer := bytes.NewBuffer(nil) + props := &serverProps{ + EchoBuffer: echoBuffer, + FeatureSet: featureSet, + ListenPort: serverPort, + } + go func() { + if err := simpleSMTPServer(ctx, t, props); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + addr := fmt.Sprintf("%s:%d", TestServerAddr, serverPort) + testHookStartTLS = func(config *tls.Config) { + testConfig := getTLSConfig(t) + config.ServerName = testConfig.ServerName + config.RootCAs = testConfig.RootCAs + config.Certificates = testConfig.Certificates + } + message := []byte(`From: user@gmail.com +To: golang-nuts@googlegroups.com +Subject: Hooray for Go + +Line 1 +.Leading dot line . +Goodbye.`) + auth := LoginAuth("username", "password", TestServerAddr, false) + if err := SendMail(addr, auth, "valid-from@domain.tld", []string{"valid-to@domain.tld"}, message); err != nil { + t.Fatalf("failed to send mail: %s", err) + } + props.BufferMutex.RLock() + resp := strings.Split(echoBuffer.String(), "\r\n") + props.BufferMutex.RUnlock() + if len(resp)-1 != len(want) { + t.Errorf("expected %d lines, but got %d", len(want), len(resp)) + } + for i := 0; i < len(want); i++ { + if !strings.EqualFold(resp[i], want[i]) { + t.Errorf("expected line %d to be %q, but got %q", i, resp[i], want[i]) + } + } + }) } -func TestAuthPlain(t *testing.T) { - tests := []struct { - authName string - server *ServerInfo - err string - }{ - { - authName: "servername", - server: &ServerInfo{Name: "servername", TLS: true}, - }, - { - // OK to use PlainAuth on localhost without TLS - authName: "localhost", - server: &ServerInfo{Name: "localhost", TLS: false}, - }, - { - // NOT OK on non-localhost, even if server says PLAIN is OK. - // (We don't know that the server is the real server.) - authName: "servername", - server: &ServerInfo{Name: "servername", Auth: []string{"PLAIN"}}, - err: "unencrypted connection", - }, - { - authName: "servername", - server: &ServerInfo{Name: "servername", Auth: []string{"CRAM-MD5"}}, - err: "unencrypted connection", - }, - { - authName: "servername", - server: &ServerInfo{Name: "attacker", TLS: true}, - err: "wrong host name", - }, - } - for i, tt := range tests { - auth := PlainAuth("foo", "bar", "baz", tt.authName, false) - _, _, err := auth.Start(tt.server) - got := "" +func TestClient_Extension(t *testing.T) { + t.Run("extension check fails on EHLO/HELO", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnEhlo: true, + FailOnHelo: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) if err != nil { - got = err.Error() + t.Fatalf("failed to dial to test server: %s", err) } - if got != tt.err { - t.Errorf("%d. got error = %q; want %q", i, got, tt.err) + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if ok, _ := client.Extension("DSN"); ok { + t.Error("expected client extension check to fail on EHLO/HELO") } - } + }) } -func TestAuthPlainNoEnc(t *testing.T) { - tests := []struct { - authName string - server *ServerInfo - err string - }{ - { - authName: "servername", - server: &ServerInfo{Name: "servername", TLS: true}, - }, - { - // OK to use PlainAuth on localhost without TLS - authName: "localhost", - server: &ServerInfo{Name: "localhost", TLS: false}, - }, - { - // Also OK on non-TLS secured connections. The NoEnc mechanism is meant to allow - // non-encrypted connections. - authName: "servername", - server: &ServerInfo{Name: "servername", Auth: []string{"PLAIN"}}, - }, - { - authName: "servername", - server: &ServerInfo{Name: "attacker", TLS: true}, - err: "wrong host name", - }, - } - for i, tt := range tests { - auth := PlainAuth("foo", "bar", "baz", tt.authName, true) - _, _, err := auth.Start(tt.server) - got := "" +func TestClient_Reset(t *testing.T) { + t.Run("reset on functioning client conneciton", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) if err != nil { - got = err.Error() + t.Fatalf("failed to dial to test server: %s", err) } - if got != tt.err { - t.Errorf("%d. got error = %q; want %q", i, got, tt.err) + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if err = client.Reset(); err != nil { + t.Errorf("failed to reset client: %s", err) } - } -} - -func TestAuthLogin(t *testing.T) { - tests := []struct { - authName string - server *ServerInfo - err string - }{ - { - authName: "servername", - server: &ServerInfo{Name: "servername", TLS: true}, - }, - { - // OK to use LoginAuth on localhost without TLS - authName: "localhost", - server: &ServerInfo{Name: "localhost", TLS: false}, - }, - { - // NOT OK on non-localhost, even if server says PLAIN is OK. - // (We don't know that the server is the real server.) - authName: "servername", - server: &ServerInfo{Name: "servername", Auth: []string{"LOGIN"}}, - err: "unencrypted connection", - }, - { - authName: "servername", - server: &ServerInfo{Name: "servername", Auth: []string{"CRAM-MD5"}}, - err: "unencrypted connection", - }, - { - authName: "servername", - server: &ServerInfo{Name: "attacker", TLS: true}, - err: "wrong host name", - }, - } - for i, tt := range tests { - auth := LoginAuth("foo", "bar", tt.authName, false) - _, _, err := auth.Start(tt.server) - got := "" + }) + t.Run("reset fails on RSET", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnReset: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) if err != nil { - got = err.Error() + t.Fatalf("failed to dial to test server: %s", err) } - if got != tt.err { - t.Errorf("%d. got error = %q; want %q", i, got, tt.err) + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if err = client.Reset(); err == nil { + t.Error("expected client reset to fail") } - } -} - -func TestAuthLoginNoEnc(t *testing.T) { - tests := []struct { - authName string - server *ServerInfo - err string - }{ - { - authName: "servername", - server: &ServerInfo{Name: "servername", TLS: true}, - }, - { - // OK to use LoginAuth on localhost without TLS - authName: "localhost", - server: &ServerInfo{Name: "localhost", TLS: false}, - }, - { - // Also OK on non-TLS secured connections. The NoEnc mechanism is meant to allow - // non-encrypted connections. - authName: "servername", - server: &ServerInfo{Name: "servername", Auth: []string{"LOGIN"}}, - }, - { - authName: "servername", - server: &ServerInfo{Name: "servername", Auth: []string{"CRAM-MD5"}}, - }, - { - authName: "servername", - server: &ServerInfo{Name: "attacker", TLS: true}, - err: "wrong host name", - }, - } - for i, tt := range tests { - auth := LoginAuth("foo", "bar", tt.authName, true) - _, _, err := auth.Start(tt.server) - got := "" + }) + t.Run("reset fails on EHLO/HELO", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnEhlo: true, + FailOnHelo: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) if err != nil { - got = err.Error() + t.Fatalf("failed to dial to test server: %s", err) } - if got != tt.err { - t.Errorf("%d. got error = %q; want %q", i, got, tt.err) + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if err = client.Reset(); err == nil { + t.Error("expected client reset to fail") } + }) +} + +func TestClient_Noop(t *testing.T) { + t.Run("noop on functioning client conneciton", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if err = client.Noop(); err != nil { + t.Errorf("failed client no-operation: %s", err) + } + }) + t.Run("noop fails on EHLO/HELO", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnEhlo: true, + FailOnHelo: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if err = client.Noop(); err == nil { + t.Error("expected client no-operation to fail") + } + }) + t.Run("noop fails on NOOP", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FailOnNoop: true, + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if err = client.Noop(); err == nil { + t.Error("expected client no-operation to fail") + } + }) +} + +func TestClient_SetDebugLog(t *testing.T) { + t.Run("set debug loggging to on with no logger defined", func(t *testing.T) { + client := &Client{} + client.SetDebugLog(true) + if !client.debug { + t.Fatalf("expected debug log to be true") + } + if client.logger == nil { + t.Fatalf("expected logger to be defined") + } + if !strings.EqualFold(fmt.Sprintf("%T", client.logger), "*log.Stdlog") { + t.Errorf("expected logger to be of type *log.Stdlog, got: %T", client.logger) + } + }) + t.Run("set debug logggin to off with no logger defined", func(t *testing.T) { + client := &Client{} + client.SetDebugLog(false) + if client.debug { + t.Fatalf("expected debug log to be false") + } + if client.logger != nil { + t.Fatalf("expected logger to be nil") + } + }) + t.Run("set active logging to off should cancel out logger", func(t *testing.T) { + client := &Client{debug: true, logger: log.New(os.Stderr, log.LevelDebug)} + client.SetDebugLog(false) + if client.debug { + t.Fatalf("expected debug log to be false") + } + if client.logger != nil { + t.Fatalf("expected logger to be nil") + } + }) +} + +func TestClient_SetLogger(t *testing.T) { + t.Run("set logger to Stdlog logger", func(t *testing.T) { + client := &Client{} + client.SetLogger(log.New(os.Stderr, log.LevelDebug)) + if !strings.EqualFold(fmt.Sprintf("%T", client.logger), "*log.Stdlog") { + t.Errorf("expected logger to be of type *log.Stdlog, got: %T", client.logger) + } + }) +} + +func TestClient_SetLogAuthData(t *testing.T) { + t.Run("set log auth data to true", func(t *testing.T) { + client := &Client{} + client.SetLogAuthData() + if !client.logAuthData { + t.Fatalf("expected log auth data to be true") + } + }) +} + +func TestClient_SetDSNRcptNotifyOption(t *testing.T) { + tests := []string{"NEVER", "SUCCESS", "FAILURE", "DELAY"} + for _, test := range tests { + t.Run("set dsn rcpt notify option to "+test, func(t *testing.T) { + client := &Client{} + client.SetDSNRcptNotifyOption(test) + if !strings.EqualFold(client.dsnrntype, test) { + t.Errorf("expected dsn rcpt notify option to be %s, got %s", test, client.dsnrntype) + } + }) } } -func TestXOAuth2OK(t *testing.T) { - server := []string{ - "220 Fake server ready ESMTP", - "250-fake.server", - "250-AUTH XOAUTH2", - "250 8BITMIME", - "235 2.7.0 Accepted", - } - var wrote strings.Builder - var fake faker - fake.ReadWriter = struct { - io.Reader - io.Writer - }{ - strings.NewReader(strings.Join(server, "\r\n")), - &wrote, +func TestClient_SetDSNMailReturnOption(t *testing.T) { + tests := []string{"HDRS", "FULL"} + for _, test := range tests { + t.Run("set dsn mail return option to "+test, func(t *testing.T) { + client := &Client{} + client.SetDSNMailReturnOption(test) + if !strings.EqualFold(client.dsnmrtype, test) { + t.Errorf("expected dsn mail return option to be %s, got %s", test, client.dsnmrtype) + } + }) } +} - c, err := NewClient(fake, "fake.host") - if err != nil { - t.Fatalf("NewClient: %v", err) - } - defer func() { - if err := c.Close(); err != nil { +func TestClient_HasConnection(t *testing.T) { + t.Run("client has connection", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if !client.HasConnection() { + t.Error("expected client to have a connection") + } + }) + t.Run("client has no connection", func(t *testing.T) { + client := &Client{} + if client.HasConnection() { + t.Error("expected client to have no connection") + } + }) + t.Run("client has no connection after close", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + if err = client.Close(); err != nil { t.Errorf("failed to close client: %s", err) } - }() - - auth := XOAuth2Auth("user", "token") - err = c.Auth(auth) - if err != nil { - t.Fatalf("XOAuth2 error: %v", err) - } - // the Next method returns a nil response. It must not be sent. - // The client request must end with the authentication. - if !strings.HasSuffix(wrote.String(), "AUTH XOAUTH2 dXNlcj11c2VyAWF1dGg9QmVhcmVyIHRva2VuAQE=\r\n") { - t.Fatalf("got %q; want AUTH XOAUTH2 dXNlcj11c2VyAWF1dGg9QmVhcmVyIHRva2VuAQE=\r\n", wrote.String()) - } + if client.HasConnection() { + t.Error("expected client to have no connection after close") + } + }) + t.Run("client has no connection after quit", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + if err = client.Quit(); err != nil { + t.Errorf("failed to quit client: %s", err) + } + if client.HasConnection() { + t.Error("expected client to have no connection after quit") + } + }) } -func TestXOAuth2Error(t *testing.T) { - serverResp := []string{ - "220 Fake server ready ESMTP", - "250-fake.server", - "250-AUTH XOAUTH2", - "250 8BITMIME", - "334 eyJzdGF0dXMiOiI0MDAiLCJzY2hlbWVzIjoiQmVhcmVyIiwic2NvcGUiOiJodHRwczovL21haWwuZ29vZ2xlLmNvbS8ifQ==", - "535 5.7.8 Username and Password not accepted", - "221 2.0.0 closing connection", - } - var wrote strings.Builder - var fake faker - fake.ReadWriter = struct { - io.Reader - io.Writer - }{ - strings.NewReader(strings.Join(serverResp, "\r\n")), - &wrote, - } - - c, err := NewClient(fake, "fake.host") - if err != nil { - t.Fatalf("NewClient: %v", err) - } - defer func() { - if err := c.Close(); err != nil { +func TestClient_UpdateDeadline(t *testing.T) { + t.Run("update deadline on sane client succeeds", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + if err = client.UpdateDeadline(time.Millisecond * 500); err != nil { + t.Errorf("failed to update connection deadline: %s", err) + } + }) + t.Run("update deadline on no connection should fail", func(t *testing.T) { + client := &Client{} + var err error + if err = client.UpdateDeadline(time.Millisecond * 500); err == nil { + t.Error("expected client deadline update to fail on no connection") + } + expError := "smtp: client has no connection" + if !strings.EqualFold(err.Error(), expError) { + t.Errorf("expected error to be %q, got: %q", expError, err) + } + }) + t.Run("update deadline on closed client should fail", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + if err = client.Close(); err != nil { t.Errorf("failed to close client: %s", err) } - }() - - auth := XOAuth2Auth("user", "token") - err = c.Auth(auth) - if err == nil { - t.Fatal("expected auth error, got nil") - } - client := strings.Split(wrote.String(), "\r\n") - if len(client) != 5 { - t.Fatalf("unexpected number of client requests got %d; want 5", len(client)) - } - if client[1] != "AUTH XOAUTH2 dXNlcj11c2VyAWF1dGg9QmVhcmVyIHRva2VuAQE=" { - t.Fatalf("got %q; want AUTH XOAUTH2 dXNlcj11c2VyAWF1dGg9QmVhcmVyIHRva2VuAQE=", client[1]) - } - // the Next method returns an empty response. It must be sent - if client[2] != "" { - t.Fatalf("got %q; want empty response", client[2]) - } + if err = client.UpdateDeadline(time.Millisecond * 500); err == nil { + t.Error("expected client deadline update to fail on closed client") + } + }) } -func TestAuthSCRAMSHA1_OK(t *testing.T) { - hostname := "127.0.0.1" - port := "2585" - - go func() { - startSMTPServer(false, hostname, port, sha1.New) - }() - time.Sleep(time.Millisecond * 500) - - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port)) - if err != nil { - t.Errorf("failed to dial server: %v", err) - } - client, err := NewClient(conn, hostname) - if err != nil { - t.Errorf("failed to create client: %v", err) - } - if err = client.Hello(hostname); err != nil { - t.Errorf("failed to send HELO: %v", err) - } - if err = client.Auth(ScramSHA1Auth("username", "password")); err != nil { - t.Errorf("failed to authenticate: %v", err) - } +func TestClient_GetTLSConnectionState(t *testing.T) { + t.Run("get state on sane client connection", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250-DSN\r\n250 STARTTLS" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + tlsConfig := getTLSConfig(t) + tlsConfig.MinVersion = tls.VersionTLS12 + tlsConfig.MaxVersion = tls.VersionTLS12 + if err = client.StartTLS(tlsConfig); err != nil { + t.Fatalf("failed to start TLS on client: %s", err) + } + state, err := client.GetTLSConnectionState() + if err != nil { + t.Fatalf("failed to get TLS connection state: %s", err) + } + if state == nil { + t.Fatal("expected TLS connection state to be non-nil") + } + if state.Version != tls.VersionTLS12 { + t.Errorf("expected TLS connection state version to be %d, got: %d", tls.VersionTLS12, state.Version) + } + }) + t.Run("get state on no connection", func(t *testing.T) { + client := &Client{} + _, err := client.GetTLSConnectionState() + if err == nil { + t.Fatal("expected client to have no tls connection state") + } + }) + t.Run("get state on non-tls client connection", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250 DSN" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + _, err = client.GetTLSConnectionState() + if err == nil { + t.Error("expected client to have no tls connection state") + } + }) + t.Run("fail to get state on non-tls connection with tls flag set", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + PortAdder.Add(1) + serverPort := int(TestServerPortBase + PortAdder.Load()) + featureSet := "250 DSN" + go func() { + if err := simpleSMTPServer(ctx, t, &serverProps{ + FeatureSet: featureSet, + ListenPort: serverPort, + }, + ); err != nil { + t.Errorf("failed to start test server: %s", err) + return + } + }() + time.Sleep(time.Millisecond * 30) + client, err := Dial(fmt.Sprintf("%s:%d", TestServerAddr, serverPort)) + if err != nil { + t.Fatalf("failed to dial to test server: %s", err) + } + client.tls = true + t.Cleanup(func() { + if err = client.Close(); err != nil { + t.Errorf("failed to close client: %s", err) + } + }) + _, err = client.GetTLSConnectionState() + if err == nil { + t.Error("expected client to have no tls connection state") + } + }) } -func TestAuthSCRAMSHA256_OK(t *testing.T) { - hostname := "127.0.0.1" - port := "2586" - - go func() { - startSMTPServer(false, hostname, port, sha256.New) - }() - time.Sleep(time.Millisecond * 500) - - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port)) - if err != nil { - t.Errorf("failed to dial server: %v", err) - } - client, err := NewClient(conn, hostname) - if err != nil { - t.Errorf("failed to create client: %v", err) - } - if err = client.Hello(hostname); err != nil { - t.Errorf("failed to send HELO: %v", err) - } - if err = client.Auth(ScramSHA256Auth("username", "password")); err != nil { - t.Errorf("failed to authenticate: %v", err) - } -} - -func TestAuthSCRAMSHA1PLUS_OK(t *testing.T) { - hostname := "127.0.0.1" - port := "2590" - - go func() { - startSMTPServer(true, hostname, port, sha1.New) - }() - time.Sleep(time.Millisecond * 500) - - cert, err := tls.X509KeyPair(localhostCert, localhostKey) - if err != nil { - fmt.Printf("error creating TLS cert: %s", err) - return - } - tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true} - - conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port), &tlsConfig) - if err != nil { - t.Errorf("failed to dial server: %v", err) - } - client, err := NewClient(conn, hostname) - if err != nil { - t.Errorf("failed to create client: %v", err) - } - if err = client.Hello(hostname); err != nil { - t.Errorf("failed to send HELO: %v", err) - } - - tlsConnState := conn.ConnectionState() - if err = client.Auth(ScramSHA1PlusAuth("username", "password", &tlsConnState)); err != nil { - t.Errorf("failed to authenticate: %v", err) - } -} - -func TestAuthSCRAMSHA256PLUS_OK(t *testing.T) { - hostname := "127.0.0.1" - port := "2591" - - go func() { - startSMTPServer(true, hostname, port, sha256.New) - }() - time.Sleep(time.Millisecond * 500) - - cert, err := tls.X509KeyPair(localhostCert, localhostKey) - if err != nil { - fmt.Printf("error creating TLS cert: %s", err) - return - } - tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true} - - conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port), &tlsConfig) - if err != nil { - t.Errorf("failed to dial server: %v", err) - } - client, err := NewClient(conn, hostname) - if err != nil { - t.Errorf("failed to create client: %v", err) - } - if err = client.Hello(hostname); err != nil { - t.Errorf("failed to send HELO: %v", err) - } - - tlsConnState := conn.ConnectionState() - if err = client.Auth(ScramSHA256PlusAuth("username", "password", &tlsConnState)); err != nil { - t.Errorf("failed to authenticate: %v", err) - } -} - -func TestAuthSCRAMSHA1_fail(t *testing.T) { - hostname := "127.0.0.1" - port := "2587" - - go func() { - startSMTPServer(false, hostname, port, sha1.New) - }() - time.Sleep(time.Millisecond * 500) - - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port)) - if err != nil { - t.Errorf("failed to dial server: %v", err) - } - client, err := NewClient(conn, hostname) - if err != nil { - t.Errorf("failed to create client: %v", err) - } - if err = client.Hello(hostname); err != nil { - t.Errorf("failed to send HELO: %v", err) - } - if err = client.Auth(ScramSHA1Auth("username", "invalid")); err == nil { - t.Errorf("expected auth error, got nil") - } -} - -func TestAuthSCRAMSHA256_fail(t *testing.T) { - hostname := "127.0.0.1" - port := "2588" - - go func() { - startSMTPServer(false, hostname, port, sha256.New) - }() - time.Sleep(time.Millisecond * 500) - - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port)) - if err != nil { - t.Errorf("failed to dial server: %v", err) - } - client, err := NewClient(conn, hostname) - if err != nil { - t.Errorf("failed to create client: %v", err) - } - if err = client.Hello(hostname); err != nil { - t.Errorf("failed to send HELO: %v", err) - } - if err = client.Auth(ScramSHA256Auth("username", "invalid")); err == nil { - t.Errorf("expected auth error, got nil") - } -} - -func TestAuthSCRAMSHA1PLUS_fail(t *testing.T) { - hostname := "127.0.0.1" - port := "2592" - - go func() { - startSMTPServer(true, hostname, port, sha1.New) - }() - time.Sleep(time.Millisecond * 500) - - cert, err := tls.X509KeyPair(localhostCert, localhostKey) - if err != nil { - fmt.Printf("error creating TLS cert: %s", err) - return - } - tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true} - - conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port), &tlsConfig) - if err != nil { - t.Errorf("failed to dial server: %v", err) - } - client, err := NewClient(conn, hostname) - if err != nil { - t.Errorf("failed to create client: %v", err) - } - if err = client.Hello(hostname); err != nil { - t.Errorf("failed to send HELO: %v", err) - } - tlsConnState := conn.ConnectionState() - if err = client.Auth(ScramSHA1PlusAuth("username", "invalid", &tlsConnState)); err == nil { - t.Errorf("expected auth error, got nil") - } -} - -func TestAuthSCRAMSHA256PLUS_fail(t *testing.T) { - hostname := "127.0.0.1" - port := "2593" - - go func() { - startSMTPServer(true, hostname, port, sha1.New) - }() - time.Sleep(time.Millisecond * 500) - - cert, err := tls.X509KeyPair(localhostCert, localhostKey) - if err != nil { - fmt.Printf("error creating TLS cert: %s", err) - return - } - tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true} - - conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%s", hostname, port), &tlsConfig) - if err != nil { - t.Errorf("failed to dial server: %v", err) - } - client, err := NewClient(conn, hostname) - if err != nil { - t.Errorf("failed to create client: %v", err) - } - if err = client.Hello(hostname); err != nil { - t.Errorf("failed to send HELO: %v", err) - } - tlsConnState := conn.ConnectionState() - if err = client.Auth(ScramSHA256PlusAuth("username", "invalid", &tlsConnState)); err == nil { - t.Errorf("expected auth error, got nil") - } -} - -// Issue 17794: don't send a trailing space on AUTH command when there's no password. -func TestClientAuthTrimSpace(t *testing.T) { - server := "220 hello world\r\n" + - "200 some more" - var wrote strings.Builder - var fake faker - fake.ReadWriter = struct { - io.Reader - io.Writer - }{ - strings.NewReader(server), - &wrote, - } - c, err := NewClient(fake, "fake.host") - if err != nil { - t.Fatalf("NewClient: %v", err) - } - c.tls = true - c.didHello = true - _ = c.Auth(toServerEmptyAuth{}) - if err := c.Close(); err != nil { - t.Errorf("close failed: %s", err) - } - if got, want := wrote.String(), "AUTH FOOAUTH\r\n*\r\nQUIT\r\n"; got != want { - t.Errorf("wrote %q; want %q", got, want) - } -} - -// toServerEmptyAuth is an implementation of Auth that only implements -// the Start method, and returns "FOOAUTH", nil, nil. Notably, it returns -// zero bytes for "toServer" so we can test that we don't send spaces at -// the end of the line. See TestClientAuthTrimSpace. -type toServerEmptyAuth struct{} - -func (toServerEmptyAuth) Start(_ *ServerInfo) (proto string, toServer []byte, err error) { - return "FOOAUTH", nil, nil -} - -func (toServerEmptyAuth) Next(_ []byte, _ bool) (toServer []byte, err error) { - panic("unexpected call") +func TestClient_debugLog(t *testing.T) { + t.Run("debug log is enabled", func(t *testing.T) { + buffer := bytes.NewBuffer(nil) + logger := log.New(buffer, log.LevelDebug) + client := &Client{logger: logger, debug: true} + client.debugLog(log.DirClientToServer, "%s", "simple string") + client.debugLog(log.DirServerToClient, "%d", 1234) + want := "DEBUG: C --> S: simple string" + if !strings.Contains(buffer.String(), want) { + t.Errorf("expected debug log to contain %q, got: %q", want, buffer.String()) + } + want = "DEBUG: C <-- S: 1234" + if !strings.Contains(buffer.String(), want) { + t.Errorf("expected debug log to contain %q, got: %q", want, buffer.String()) + } + }) + t.Run("debug log is disable", func(t *testing.T) { + buffer := bytes.NewBuffer(nil) + logger := log.New(buffer, log.LevelDebug) + client := &Client{logger: logger, debug: false} + client.debugLog(log.DirClientToServer, "%s", "simple string") + client.debugLog(log.DirServerToClient, "%d", 1234) + if buffer.Len() > 0 { + t.Errorf("expected debug log to be empty, got: %q", buffer.String()) + } + }) } +// faker is a struct embedding io.ReadWriter to simulate network connections for testing purposes. type faker struct { io.ReadWriter + failOnClose bool } -func (f faker) Close() error { return nil } +func (f faker) Close() error { + if f.failOnClose { + return fmt.Errorf("faker: failed to close connection") + } + return nil +} func (f faker) LocalAddr() net.Addr { return nil } func (f faker) RemoteAddr() net.Addr { return nil } func (f faker) SetDeadline(time.Time) error { return nil } func (f faker) SetReadDeadline(time.Time) error { return nil } func (f faker) SetWriteDeadline(time.Time) error { return nil } -func TestBasic(t *testing.T) { - server := strings.Join(strings.Split(basicServer, "\n"), "\r\n") - client := strings.Join(strings.Split(basicClient, "\n"), "\r\n") - - var cmdbuf strings.Builder - bcmdbuf := bufio.NewWriter(&cmdbuf) - var fake faker - fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf) - c := &Client{Text: textproto.NewConn(fake), localName: "localhost"} - - if err := c.helo(); err != nil { - t.Fatalf("HELO failed: %s", err) - } - if err := c.ehlo(); err == nil { - t.Fatalf("Expected first EHLO to fail") - } - if err := c.ehlo(); err != nil { - t.Fatalf("Second EHLO failed: %s", err) - } - - c.didHello = true - if ok, args := c.Extension("aUtH"); !ok || args != "LOGIN PLAIN" { - t.Fatalf("Expected AUTH supported") - } - if ok, _ := c.Extension("DSN"); ok { - t.Fatalf("Shouldn't support DSN") - } - - if err := c.Mail("user@gmail.com"); err == nil { - t.Fatalf("MAIL should require authentication") - } - - if err := c.Verify("user1@gmail.com"); err == nil { - t.Fatalf("First VRFY: expected no verification") - } - if err := c.Verify("user2@gmail.com>\r\nDATA\r\nAnother injected message body\r\n.\r\nQUIT\r\n"); err == nil { - t.Fatalf("VRFY should have failed due to a message injection attempt") - } - if err := c.Verify("user2@gmail.com"); err != nil { - t.Fatalf("Second VRFY: expected verification, got %s", err) - } - - // fake TLS so authentication won't complain - c.tls = true - c.serverName = "smtp.google.com" - if err := c.Auth(PlainAuth("", "user", "pass", "smtp.google.com", false)); err != nil { - t.Fatalf("AUTH failed: %s", err) - } - - if err := c.Rcpt("golang-nuts@googlegroups.com>\r\nDATA\r\nInjected message body\r\n.\r\nQUIT\r\n"); err == nil { - t.Fatalf("RCPT should have failed due to a message injection attempt") - } - if err := c.Mail("user@gmail.com>\r\nDATA\r\nAnother injected message body\r\n.\r\nQUIT\r\n"); err == nil { - t.Fatalf("MAIL should have failed due to a message injection attempt") - } - if err := c.Mail("user@gmail.com"); err != nil { - t.Fatalf("MAIL failed: %s", err) - } - if err := c.Rcpt("golang-nuts@googlegroups.com"); err != nil { - t.Fatalf("RCPT failed: %s", err) - } - msg := `From: user@gmail.com -To: golang-nuts@googlegroups.com -Subject: Hooray for Go - -Line 1 -.Leading dot line . -Goodbye.` - w, err := c.Data() - if err != nil { - t.Fatalf("DATA failed: %s", err) - } - if _, err := w.Write([]byte(msg)); err != nil { - t.Fatalf("Data write failed: %s", err) - } - if err := w.Close(); err != nil { - t.Fatalf("Bad data response: %s", err) - } - - if err := c.Quit(); err != nil { - t.Fatalf("QUIT failed: %s", err) - } - - if err := bcmdbuf.Flush(); err != nil { - t.Errorf("flush failed: %s", err) - } - actualcmds := cmdbuf.String() - if client != actualcmds { - t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client) - } -} - -var basicServer = `250 mx.google.com at your service -502 Unrecognized command. -250-mx.google.com at your service -250-SIZE 35651584 -250-AUTH LOGIN PLAIN -250 8BITMIME -530 Authentication required -252 Send some mail, I'll try my best -250 User is valid -235 Accepted -250 Sender OK -250 Receiver OK -354 Go ahead -250 Data OK -221 OK -` - -var basicClient = `HELO localhost -EHLO localhost -EHLO localhost -MAIL FROM: BODY=8BITMIME -VRFY user1@gmail.com -VRFY user2@gmail.com -AUTH PLAIN AHVzZXIAcGFzcw== -MAIL FROM: BODY=8BITMIME -RCPT TO: -DATA -From: user@gmail.com -To: golang-nuts@googlegroups.com -Subject: Hooray for Go - -Line 1 -..Leading dot line . -Goodbye. -. -QUIT -` - -func TestHELOFailed(t *testing.T) { - serverLines := `502 EH? -502 EH? -221 OK -` - clientLines := `EHLO localhost -HELO localhost -QUIT -` - server := strings.Join(strings.Split(serverLines, "\n"), "\r\n") - client := strings.Join(strings.Split(clientLines, "\n"), "\r\n") - var cmdbuf strings.Builder - bcmdbuf := bufio.NewWriter(&cmdbuf) - var fake faker - fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf) - c := &Client{Text: textproto.NewConn(fake), localName: "localhost"} - if err := c.Hello("localhost"); err == nil { - t.Fatal("expected EHLO to fail") - } - if err := c.Quit(); err != nil { - t.Errorf("QUIT failed: %s", err) - } - _ = bcmdbuf.Flush() - actual := cmdbuf.String() - if client != actual { - t.Errorf("Got:\n%s\nWant:\n%s", actual, client) - } -} - -func TestExtensions(t *testing.T) { - fake := func(server string) (c *Client, bcmdbuf *bufio.Writer, cmdbuf *strings.Builder) { - server = strings.Join(strings.Split(server, "\n"), "\r\n") - - cmdbuf = &strings.Builder{} - bcmdbuf = bufio.NewWriter(cmdbuf) - var fake faker - fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf) - c = &Client{Text: textproto.NewConn(fake), localName: "localhost"} - - return c, bcmdbuf, cmdbuf - } - - t.Run("helo", func(t *testing.T) { - const ( - basicServer = `250 mx.google.com at your service -250 Sender OK -221 Goodbye -` - - basicClient = `HELO localhost -MAIL FROM: -QUIT -` - ) - - c, bcmdbuf, cmdbuf := fake(basicServer) - - if err := c.helo(); err != nil { - t.Fatalf("HELO failed: %s", err) - } - c.didHello = true - if err := c.Mail("user@gmail.com"); err != nil { - t.Fatalf("MAIL FROM failed: %s", err) - } - if err := c.Quit(); err != nil { - t.Fatalf("QUIT failed: %s", err) - } - - if err := bcmdbuf.Flush(); err != nil { - t.Errorf("flush failed: %s", err) - } - actualcmds := cmdbuf.String() - client := strings.Join(strings.Split(basicClient, "\n"), "\r\n") - if client != actualcmds { - t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client) - } - }) - - t.Run("ehlo", func(t *testing.T) { - const ( - basicServer = `250-mx.google.com at your service -250 SIZE 35651584 -250 Sender OK -221 Goodbye -` - - basicClient = `EHLO localhost -MAIL FROM: -QUIT -` - ) - - c, bcmdbuf, cmdbuf := fake(basicServer) - - if err := c.Hello("localhost"); err != nil { - t.Fatalf("EHLO failed: %s", err) - } - if ok, _ := c.Extension("8BITMIME"); ok { - t.Fatalf("Shouldn't support 8BITMIME") - } - if ok, _ := c.Extension("SMTPUTF8"); ok { - t.Fatalf("Shouldn't support SMTPUTF8") - } - if err := c.Mail("user@gmail.com"); err != nil { - t.Fatalf("MAIL FROM failed: %s", err) - } - if err := c.Quit(); err != nil { - t.Fatalf("QUIT failed: %s", err) - } - - if err := bcmdbuf.Flush(); err != nil { - t.Errorf("flush failed: %s", err) - } - actualcmds := cmdbuf.String() - client := strings.Join(strings.Split(basicClient, "\n"), "\r\n") - if client != actualcmds { - t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client) - } - }) - - t.Run("ehlo 8bitmime", func(t *testing.T) { - const ( - basicServer = `250-mx.google.com at your service -250-SIZE 35651584 -250 8BITMIME -250 Sender OK -221 Goodbye -` - - basicClient = `EHLO localhost -MAIL FROM: BODY=8BITMIME -QUIT -` - ) - - c, bcmdbuf, cmdbuf := fake(basicServer) - - if err := c.Hello("localhost"); err != nil { - t.Fatalf("EHLO failed: %s", err) - } - if ok, _ := c.Extension("8BITMIME"); !ok { - t.Fatalf("Should support 8BITMIME") - } - if ok, _ := c.Extension("SMTPUTF8"); ok { - t.Fatalf("Shouldn't support SMTPUTF8") - } - if err := c.Mail("user@gmail.com"); err != nil { - t.Fatalf("MAIL FROM failed: %s", err) - } - if err := c.Quit(); err != nil { - t.Fatalf("QUIT failed: %s", err) - } - - if err := bcmdbuf.Flush(); err != nil { - t.Errorf("failed to flush: %s", err) - } - actualcmds := cmdbuf.String() - client := strings.Join(strings.Split(basicClient, "\n"), "\r\n") - if client != actualcmds { - t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client) - } - }) - - t.Run("ehlo smtputf8", func(t *testing.T) { - const ( - basicServer = `250-mx.google.com at your service -250-SIZE 35651584 -250 SMTPUTF8 -250 Sender OK -221 Goodbye -` - - basicClient = `EHLO localhost -MAIL FROM: SMTPUTF8 -QUIT -` - ) - - c, bcmdbuf, cmdbuf := fake(basicServer) - - if err := c.Hello("localhost"); err != nil { - t.Fatalf("EHLO failed: %s", err) - } - if ok, _ := c.Extension("8BITMIME"); ok { - t.Fatalf("Shouldn't support 8BITMIME") - } - if ok, _ := c.Extension("SMTPUTF8"); !ok { - t.Fatalf("Should support SMTPUTF8") - } - if err := c.Mail("user+📧@gmail.com"); err != nil { - t.Fatalf("MAIL FROM failed: %s", err) - } - if err := c.Quit(); err != nil { - t.Fatalf("QUIT failed: %s", err) - } - - if err := bcmdbuf.Flush(); err != nil { - t.Errorf("failed to flush: %s", err) - } - actualcmds := cmdbuf.String() - client := strings.Join(strings.Split(basicClient, "\n"), "\r\n") - if client != actualcmds { - t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client) - } - }) - - t.Run("ehlo 8bitmime smtputf8", func(t *testing.T) { - const ( - basicServer = `250-mx.google.com at your service -250-SIZE 35651584 -250-8BITMIME -250 SMTPUTF8 -250 Sender OK -221 Goodbye - ` - - basicClient = `EHLO localhost -MAIL FROM: BODY=8BITMIME SMTPUTF8 -QUIT -` - ) - - c, bcmdbuf, cmdbuf := fake(basicServer) - - if err := c.Hello("localhost"); err != nil { - t.Fatalf("EHLO failed: %s", err) - } - c.didHello = true - if ok, _ := c.Extension("8BITMIME"); !ok { - t.Fatalf("Should support 8BITMIME") - } - if ok, _ := c.Extension("SMTPUTF8"); !ok { - t.Fatalf("Should support SMTPUTF8") - } - if err := c.Mail("user+📧@gmail.com"); err != nil { - t.Fatalf("MAIL FROM failed: %s", err) - } - if err := c.Quit(); err != nil { - t.Fatalf("QUIT failed: %s", err) - } - - if err := bcmdbuf.Flush(); err != nil { - t.Errorf("failed to flush: %s", err) - } - actualcmds := cmdbuf.String() - client := strings.Join(strings.Split(basicClient, "\n"), "\r\n") - if client != actualcmds { - t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client) - } - }) -} - -func TestNewClient(t *testing.T) { - server := strings.Join(strings.Split(newClientServer, "\n"), "\r\n") - client := strings.Join(strings.Split(newClientClient, "\n"), "\r\n") - - var cmdbuf strings.Builder - bcmdbuf := bufio.NewWriter(&cmdbuf) - out := func() string { - if err := bcmdbuf.Flush(); err != nil { - t.Errorf("failed to flush: %s", err) - } - return cmdbuf.String() - } - var fake faker - fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf) - c, err := NewClient(fake, "fake.host") - if err != nil { - t.Fatalf("NewClient: %v\n(after %v)", err, out()) - } - defer func() { - _ = c.Close() - }() - if ok, args := c.Extension("aUtH"); !ok || args != "LOGIN PLAIN" { - t.Fatalf("Expected AUTH supported") - } - if ok, _ := c.Extension("DSN"); ok { - t.Fatalf("Shouldn't support DSN") - } - if err := c.Quit(); err != nil { - t.Fatalf("QUIT failed: %s", err) - } - - actualcmds := out() - if client != actualcmds { - t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client) - } -} - -// TestClient_SetDebugLog tests the Client method with the Client.SetDebugLog method -// to enable debug logging -func TestClient_SetDebugLog(t *testing.T) { - server := strings.Join(strings.Split(newClientServer, "\n"), "\r\n") - - var cmdbuf strings.Builder - bcmdbuf := bufio.NewWriter(&cmdbuf) - out := func() string { - if err := bcmdbuf.Flush(); err != nil { - t.Errorf("failed to flush: %s", err) - } - return cmdbuf.String() - } - var fake faker - fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf) - c, err := NewClient(fake, "fake.host") - if err != nil { - t.Fatalf("NewClient: %v\n(after %v)", err, out()) - } - defer func() { - _ = c.Close() - }() - c.SetDebugLog(true) - if !c.debug { - t.Errorf("Expected DebugLog flag to be true but received false") - } -} - -// TestClient_SetLogger tests the Client method with the Client.SetLogger method -// to provide a custom logger -func TestClient_SetLogger(t *testing.T) { - server := strings.Join(strings.Split(newClientServer, "\n"), "\r\n") - - var cmdbuf strings.Builder - bcmdbuf := bufio.NewWriter(&cmdbuf) - out := func() string { - if err := bcmdbuf.Flush(); err != nil { - t.Errorf("failed to flush: %s", err) - } - return cmdbuf.String() - } - var fake faker - fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf) - c, err := NewClient(fake, "fake.host") - if err != nil { - t.Fatalf("NewClient: %v\n(after %v)", err, out()) - } - defer func() { - _ = c.Close() - }() - c.SetLogger(log.New(os.Stderr, log.LevelDebug)) - if c.logger == nil { - t.Errorf("Expected Logger to be set but received nil") - } - c.logger.Debugf(log.Log{Direction: log.DirServerToClient, Format: "%s", Messages: []interface{}{"test"}}) - c.SetLogger(nil) - c.logger.Debugf(log.Log{Direction: log.DirServerToClient, Format: "%s", Messages: []interface{}{"test"}}) -} - -func TestClient_SetLogAuthData(t *testing.T) { - server := strings.Join(strings.Split(newClientServer, "\n"), "\r\n") - - var cmdbuf strings.Builder - bcmdbuf := bufio.NewWriter(&cmdbuf) - out := func() string { - if err := bcmdbuf.Flush(); err != nil { - t.Errorf("failed to flush: %s", err) - } - return cmdbuf.String() - } - var fake faker - fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf) - c, err := NewClient(fake, "fake.host") - if err != nil { - t.Fatalf("NewClient: %v\n(after %v)", err, out()) - } - defer func() { - _ = c.Close() - }() - c.SetLogAuthData() - if !c.logAuthData { - t.Error("Expected logAuthData to be true but received false") - } -} - -var newClientServer = `220 hello world -250-mx.google.com at your service -250-SIZE 35651584 -250-AUTH LOGIN PLAIN -250 8BITMIME -221 OK -` - -var newClientClient = `EHLO localhost -QUIT -` - -func TestNewClient2(t *testing.T) { - server := strings.Join(strings.Split(newClient2Server, "\n"), "\r\n") - client := strings.Join(strings.Split(newClient2Client, "\n"), "\r\n") - - var cmdbuf strings.Builder - bcmdbuf := bufio.NewWriter(&cmdbuf) - var fake faker - fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf) - c, err := NewClient(fake, "fake.host") - if err != nil { - t.Fatalf("NewClient: %v", err) - } - defer func() { - _ = c.Close() - }() - if ok, _ := c.Extension("DSN"); ok { - t.Fatalf("Shouldn't support DSN") - } - if err := c.Quit(); err != nil { - t.Fatalf("QUIT failed: %s", err) - } - - if err := bcmdbuf.Flush(); err != nil { - t.Errorf("flush failed: %s", err) - } - actualcmds := cmdbuf.String() - if client != actualcmds { - t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client) - } -} - -var newClient2Server = `220 hello world -502 EH? -250-mx.google.com at your service -250-SIZE 35651584 -250-AUTH LOGIN PLAIN -250 8BITMIME -221 OK -` - -var newClient2Client = `EHLO localhost -HELO localhost -QUIT -` - -func TestNewClientWithTLS(t *testing.T) { - cert, err := tls.X509KeyPair(localhostCert, localhostKey) - if err != nil { - t.Fatalf("loadcert: %v", err) - } - - config := tls.Config{Certificates: []tls.Certificate{cert}} - - ln, err := tls.Listen("tcp", "127.0.0.1:0", &config) - if err != nil { - ln, err = tls.Listen("tcp", "[::1]:0", &config) - if err != nil { - t.Fatalf("server: listen: %v", err) - } - } - - go func() { - conn, err := ln.Accept() - if err != nil { - t.Errorf("server: accept: %v", err) - return - } - defer func() { - _ = conn.Close() - }() - - _, err = conn.Write([]byte("220 SIGNS\r\n")) - if err != nil { - t.Errorf("server: write: %v", err) - return - } - }() - - config.InsecureSkipVerify = true - conn, err := tls.Dial("tcp", ln.Addr().String(), &config) - if err != nil { - t.Fatalf("client: dial: %v", err) - } - defer func() { - _ = conn.Close() - }() - - client, err := NewClient(conn, ln.Addr().String()) - if err != nil { - t.Fatalf("smtp: newclient: %v", err) - } - if !client.tls { - t.Errorf("client.tls Got: %t Expected: %t", client.tls, true) - } -} - -func TestHello(t *testing.T) { - if len(helloServer) != len(helloClient) { - t.Fatalf("Hello server and client size mismatch") - } - - tf := func(fake faker, i int) error { - c, err := NewClient(fake, "fake.host") - if err != nil { - t.Fatalf("NewClient: %v", err) - } - defer func() { - _ = c.Close() - }() - c.localName = "customhost" - err = nil - - switch i { - case 0: - err = c.Hello("hostinjection>\n\rDATA\r\nInjected message body\r\n.\r\nQUIT\r\n") - if err == nil { - t.Errorf("Expected Hello to be rejected due to a message injection attempt") - } - err = c.Hello("customhost") - case 1: - err = c.StartTLS(nil) - if err.Error() == "502 Not implemented" { - err = nil - } - case 2: - err = c.Verify("test@example.com") - case 3: - c.tls = true - c.serverName = "smtp.google.com" - err = c.Auth(PlainAuth("", "user", "pass", "smtp.google.com", false)) - case 4: - err = c.Mail("test@example.com") - case 5: - ok, _ := c.Extension("feature") - if ok { - t.Errorf("Expected FEATURE not to be supported") - } - case 6: - err = c.Reset() - case 7: - err = c.Quit() - case 8: - err = c.Verify("test@example.com") - if err != nil { - err = c.Hello("customhost") - if err != nil { - t.Errorf("Want error, got none") - } - } - case 9: - err = c.Noop() - default: - t.Fatalf("Unhandled command") - } - - if err != nil { - t.Errorf("Command %d failed: %v", i, err) - } - return nil - } - - for i := 0; i < len(helloServer); i++ { - server := strings.Join(strings.Split(baseHelloServer+helloServer[i], "\n"), "\r\n") - client := strings.Join(strings.Split(baseHelloClient+helloClient[i], "\n"), "\r\n") - var cmdbuf strings.Builder - bcmdbuf := bufio.NewWriter(&cmdbuf) - var fake faker - fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf) - - if err := tf(fake, i); err != nil { - t.Error(err) - } - - if err := bcmdbuf.Flush(); err != nil { - t.Errorf("flush failed: %s", err) - } - actualcmds := cmdbuf.String() - if client != actualcmds { - t.Errorf("Got:\n%s\nExpected:\n%s", actualcmds, client) - } - } -} - -var baseHelloServer = `220 hello world -502 EH? -250-mx.google.com at your service -250 FEATURE -` - -var helloServer = []string{ - "", - "502 Not implemented\n", - "250 User is valid\n", - "235 Accepted\n", - "250 Sender ok\n", - "", - "250 Reset ok\n", - "221 Goodbye\n", - "250 Sender ok\n", - "250 ok\n", -} - -var baseHelloClient = `EHLO customhost -HELO customhost -` - -var helloClient = []string{ - "", - "STARTTLS\n", - "VRFY test@example.com\n", - "AUTH PLAIN AHVzZXIAcGFzcw==\n", - "MAIL FROM:\n", - "", - "RSET\n", - "QUIT\n", - "VRFY test@example.com\n", - "NOOP\n", -} - -func TestSendMail(t *testing.T) { - server := strings.Join(strings.Split(sendMailServer, "\n"), "\r\n") - client := strings.Join(strings.Split(sendMailClient, "\n"), "\r\n") - var cmdbuf strings.Builder - bcmdbuf := bufio.NewWriter(&cmdbuf) - l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("Unable to create listener: %v", err) - } - defer func() { - _ = l.Close() - }() - - // prevent data race on bcmdbuf - done := make(chan struct{}) - go func(data []string) { - defer close(done) - - conn, err := l.Accept() - if err != nil { - t.Errorf("Accept error: %v", err) - return - } - defer func() { - _ = conn.Close() - }() - - tc := textproto.NewConn(conn) - for i := 0; i < len(data) && data[i] != ""; i++ { - if err := tc.PrintfLine("%s", data[i]); err != nil { - t.Errorf("printing to textproto failed: %s", err) - } - for len(data[i]) >= 4 && data[i][3] == '-' { - i++ - if err := tc.PrintfLine("%s", data[i]); err != nil { - t.Errorf("printing to textproto failed: %s", err) - } - } - if data[i] == "221 Goodbye" { - return - } - read := false - for !read || data[i] == "354 Go ahead" { - msg, err := tc.ReadLine() - if _, err := bcmdbuf.Write([]byte(msg + "\r\n")); err != nil { - t.Errorf("write failed: %s", err) - } - read = true - if err != nil { - t.Errorf("Read error: %v", err) - return - } - if data[i] == "354 Go ahead" && msg == "." { - break - } - } - } - }(strings.Split(server, "\r\n")) - - err = SendMail(l.Addr().String(), nil, "test@example.com", []string{"other@example.com>\n\rDATA\r\nInjected message body\r\n.\r\nQUIT\r\n"}, []byte(strings.Replace(`From: test@example.com -To: other@example.com -Subject: SendMail test - -SendMail is working for me. -`, "\n", "\r\n", -1))) - if err == nil { - t.Errorf("Expected SendMail to be rejected due to a message injection attempt") - } - - err = SendMail(l.Addr().String(), nil, "test@example.com", []string{"other@example.com"}, []byte(strings.Replace(`From: test@example.com -To: other@example.com -Subject: SendMail test - -SendMail is working for me. -`, "\n", "\r\n", -1))) - if err != nil { - t.Errorf("%v", err) - } - - <-done - if err := bcmdbuf.Flush(); err != nil { - t.Errorf("flush failed: %s", err) - } - actualcmds := cmdbuf.String() - if client != actualcmds { - t.Errorf("Got:\n%s\nExpected:\n%s", actualcmds, client) - } -} - -var sendMailServer = `220 hello world -502 EH? -250 mx.google.com at your service -250 Sender ok -250 Receiver ok -354 Go ahead -250 Data ok -221 Goodbye -` - -var sendMailClient = `EHLO localhost -HELO localhost -MAIL FROM: -RCPT TO: -DATA -From: test@example.com -To: other@example.com -Subject: SendMail test - -SendMail is working for me. -. -QUIT -` - -func TestSendMailWithAuth(t *testing.T) { - l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("Unable to create listener: %v", err) - } - defer func() { - _ = l.Close() - }() - - errCh := make(chan error) - go func() { - defer close(errCh) - conn, err := l.Accept() - if err != nil { - errCh <- fmt.Errorf("listener Accept: %w", err) - return - } - defer func() { - _ = conn.Close() - }() - - tc := textproto.NewConn(conn) - if err := tc.PrintfLine("220 hello world"); err != nil { - t.Errorf("textproto connetion print failed: %s", err) - } - msg, err := tc.ReadLine() - if err != nil { - errCh <- fmt.Errorf("textproto connection ReadLine error: %w", err) - return - } - const wantMsg = "EHLO localhost" - if msg != wantMsg { - errCh <- fmt.Errorf("unexpected response %q; want %q", msg, wantMsg) - return - } - err = tc.PrintfLine("250 mx.google.com at your service") - if err != nil { - errCh <- fmt.Errorf("textproto connection PrintfLine: %w", err) - return - } - }() - - err = SendMail(l.Addr().String(), PlainAuth("", "user", "pass", "smtp.google.com", false), "test@example.com", []string{"other@example.com"}, []byte(strings.Replace(`From: test@example.com -To: other@example.com -Subject: SendMail test - -SendMail is working for me. -`, "\n", "\r\n", -1))) - if err == nil { - t.Error("SendMail: Server doesn't support AUTH, expected to get an error, but got none ") - return - } - if err.Error() != "smtp: server doesn't support AUTH" { - t.Errorf("Expected: smtp: server doesn't support AUTH, got: %s", err) - } - err = <-errCh - if err != nil { - t.Fatalf("server error: %v", err) - } -} - -func TestAuthFailed(t *testing.T) { - server := strings.Join(strings.Split(authFailedServer, "\n"), "\r\n") - client := strings.Join(strings.Split(authFailedClient, "\n"), "\r\n") - var cmdbuf strings.Builder - bcmdbuf := bufio.NewWriter(&cmdbuf) - var fake faker - fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf) - c, err := NewClient(fake, "fake.host") - if err != nil { - t.Fatalf("NewClient: %v", err) - } - defer func() { - _ = c.Close() - }() - - c.tls = true - c.serverName = "smtp.google.com" - err = c.Auth(PlainAuth("", "user", "pass", "smtp.google.com", false)) - - if err == nil { - t.Error("Auth: expected error; got none") - } else if err.Error() != "535 Invalid credentials\nplease see www.example.com" { - t.Errorf("Auth: got error: %v, want: %s", err, "535 Invalid credentials\nplease see www.example.com") - } - - if err := bcmdbuf.Flush(); err != nil { - t.Errorf("flush failed: %s", err) - } - actualcmds := cmdbuf.String() - if client != actualcmds { - t.Errorf("Got:\n%s\nExpected:\n%s", actualcmds, client) - } -} - -var authFailedServer = `220 hello world -250-mx.google.com at your service -250 AUTH LOGIN PLAIN -535-Invalid credentials -535 please see www.example.com -221 Goodbye -` - -var authFailedClient = `EHLO localhost -AUTH PLAIN AHVzZXIAcGFzcw== -* -QUIT -` - -func TestTLSClient(t *testing.T) { - if runtime.GOOS == "freebsd" || runtime.GOOS == "js" || runtime.GOOS == "wasip1" { - SkipFlaky(t, 19229) - } - ln := newLocalListener(t) - defer func() { - _ = ln.Close() - }() - errc := make(chan error) - go func() { - errc <- sendMail(ln.Addr().String()) - }() - conn, err := ln.Accept() - if err != nil { - t.Fatalf("failed to accept connection: %v", err) - } - defer func() { - _ = conn.Close() - }() - if err := serverHandle(conn, t); err != nil { - t.Fatalf("failed to handle connection: %v", err) - } - if err := <-errc; err != nil { - t.Fatalf("client error: %v", err) - } -} - -func TestTLSConnState(t *testing.T) { - ln := newLocalListener(t) - defer func() { - _ = ln.Close() - }() - clientDone := make(chan bool) - serverDone := make(chan bool) - go func() { - defer close(serverDone) - c, err := ln.Accept() - if err != nil { - t.Errorf("Server accept: %v", err) - return - } - defer func() { - _ = c.Close() - }() - if err := serverHandle(c, t); err != nil { - t.Errorf("server error: %v", err) - } - }() - go func() { - defer close(clientDone) - c, err := Dial(ln.Addr().String()) - if err != nil { - t.Errorf("Client dial: %v", err) - return - } - defer func() { - _ = c.Quit() - }() - cfg := &tls.Config{ServerName: "example.com"} - testHookStartTLS(cfg) // set the RootCAs - if err := c.StartTLS(cfg); err != nil { - t.Errorf("StartTLS: %v", err) - return - } - cs, ok := c.TLSConnectionState() - if !ok { - t.Errorf("TLSConnectionState returned ok == false; want true") - return - } - if cs.Version == 0 || !cs.HandshakeComplete { - t.Errorf("ConnectionState = %#v; expect non-zero Version and HandshakeComplete", cs) - } - }() - <-clientDone - <-serverDone -} - -func TestClient_GetTLSConnectionState(t *testing.T) { - ln := newLocalListener(t) - defer func() { - _ = ln.Close() - }() - clientDone := make(chan bool) - serverDone := make(chan bool) - go func() { - defer close(serverDone) - c, err := ln.Accept() - if err != nil { - t.Errorf("Server accept: %v", err) - return - } - defer func() { - _ = c.Close() - }() - if err := serverHandle(c, t); err != nil { - t.Errorf("server error: %v", err) - } - }() - go func() { - defer close(clientDone) - c, err := Dial(ln.Addr().String()) - if err != nil { - t.Errorf("Client dial: %v", err) - return - } - defer func() { - _ = c.Quit() - }() - cfg := &tls.Config{ServerName: "example.com"} - testHookStartTLS(cfg) // set the RootCAs - if err := c.StartTLS(cfg); err != nil { - t.Errorf("StartTLS: %v", err) - return - } - cs, err := c.GetTLSConnectionState() - if err != nil { - t.Errorf("failed to get TLSConnectionState: %s", err) - return - } - if cs.Version == 0 || !cs.HandshakeComplete { - t.Errorf("ConnectionState = %#v; expect non-zero Version and HandshakeComplete", cs) - } - }() - <-clientDone - <-serverDone -} - -func TestClient_GetTLSConnectionState_noTLS(t *testing.T) { - ln := newLocalListener(t) - defer func() { - _ = ln.Close() - }() - clientDone := make(chan bool) - serverDone := make(chan bool) - go func() { - defer close(serverDone) - c, err := ln.Accept() - if err != nil { - t.Errorf("Server accept: %v", err) - return - } - defer func() { - _ = c.Close() - }() - if err := serverHandle(c, t); err != nil { - t.Errorf("server error: %v", err) - } - }() - go func() { - defer close(clientDone) - c, err := Dial(ln.Addr().String()) - if err != nil { - t.Errorf("Client dial: %v", err) - return - } - defer func() { - _ = c.Quit() - }() - _, err = c.GetTLSConnectionState() - if err == nil { - t.Error("GetTLSConnectionState: expected error; got nil") - return - } - }() - <-clientDone - <-serverDone -} - -func TestClient_GetTLSConnectionState_noConn(t *testing.T) { - ln := newLocalListener(t) - defer func() { - _ = ln.Close() - }() - clientDone := make(chan bool) - serverDone := make(chan bool) - go func() { - defer close(serverDone) - c, err := ln.Accept() - if err != nil { - t.Errorf("Server accept: %v", err) - return - } - defer func() { - _ = c.Close() - }() - if err := serverHandle(c, t); err != nil { - t.Errorf("server error: %v", err) - } - }() - go func() { - defer close(clientDone) - c, err := Dial(ln.Addr().String()) - if err != nil { - t.Errorf("Client dial: %v", err) - return - } - _ = c.Close() - _, err = c.GetTLSConnectionState() - if err == nil { - t.Error("GetTLSConnectionState: expected error; got nil") - return - } - }() - <-clientDone - <-serverDone -} - -func TestClient_GetTLSConnectionState_unableErr(t *testing.T) { - ln := newLocalListener(t) - defer func() { - _ = ln.Close() - }() - clientDone := make(chan bool) - serverDone := make(chan bool) - go func() { - defer close(serverDone) - c, err := ln.Accept() - if err != nil { - t.Errorf("Server accept: %v", err) - return - } - defer func() { - _ = c.Close() - }() - if err := serverHandle(c, t); err != nil { - t.Errorf("server error: %v", err) - } - }() - go func() { - defer close(clientDone) - c, err := Dial(ln.Addr().String()) - if err != nil { - t.Errorf("Client dial: %v", err) - return - } - defer func() { - _ = c.Quit() - }() - c.tls = true - _, err = c.GetTLSConnectionState() - if err == nil { - t.Error("GetTLSConnectionState: expected error; got nil") - return - } - }() - <-clientDone - <-serverDone -} - -func TestClient_HasConnection(t *testing.T) { - ln := newLocalListener(t) - defer func() { - _ = ln.Close() - }() - clientDone := make(chan bool) - serverDone := make(chan bool) - go func() { - defer close(serverDone) - c, err := ln.Accept() - if err != nil { - t.Errorf("Server accept: %v", err) - return - } - defer func() { - _ = c.Close() - }() - if err := serverHandle(c, t); err != nil { - t.Errorf("server error: %v", err) - } - }() - go func() { - defer close(clientDone) - c, err := Dial(ln.Addr().String()) - if err != nil { - t.Errorf("Client dial: %v", err) - return - } - cfg := &tls.Config{ServerName: "example.com"} - testHookStartTLS(cfg) // set the RootCAs - if err := c.StartTLS(cfg); err != nil { - t.Errorf("StartTLS: %v", err) - return - } - if !c.HasConnection() { - t.Error("HasConnection: expected true; got false") - return - } - if err = c.Quit(); err != nil { - t.Errorf("closing connection failed: %s", err) - return - } - if c.HasConnection() { - t.Error("HasConnection: expected false; got true") - } - }() - <-clientDone - <-serverDone -} - -func TestClient_SetDSNMailReturnOption(t *testing.T) { - ln := newLocalListener(t) - defer func() { - _ = ln.Close() - }() - clientDone := make(chan bool) - serverDone := make(chan bool) - go func() { - defer close(serverDone) - c, err := ln.Accept() - if err != nil { - t.Errorf("Server accept: %v", err) - return - } - defer func() { - _ = c.Close() - }() - if err := serverHandle(c, t); err != nil { - t.Errorf("server error: %v", err) - } - }() - go func() { - defer close(clientDone) - c, err := Dial(ln.Addr().String()) - if err != nil { - t.Errorf("Client dial: %v", err) - return - } - defer func() { - _ = c.Quit() - }() - c.SetDSNMailReturnOption("foo") - if c.dsnmrtype != "foo" { - t.Errorf("SetDSNMailReturnOption: expected %s; got %s", "foo", c.dsnrntype) - } - }() - <-clientDone - <-serverDone -} - -func TestClient_SetDSNRcptNotifyOption(t *testing.T) { - ln := newLocalListener(t) - defer func() { - _ = ln.Close() - }() - clientDone := make(chan bool) - serverDone := make(chan bool) - go func() { - defer close(serverDone) - c, err := ln.Accept() - if err != nil { - t.Errorf("Server accept: %v", err) - return - } - defer func() { - _ = c.Close() - }() - if err := serverHandle(c, t); err != nil { - t.Errorf("server error: %v", err) - } - }() - go func() { - defer close(clientDone) - c, err := Dial(ln.Addr().String()) - if err != nil { - t.Errorf("Client dial: %v", err) - return - } - defer func() { - _ = c.Quit() - }() - c.SetDSNRcptNotifyOption("foo") - if c.dsnrntype != "foo" { - t.Errorf("SetDSNMailReturnOption: expected %s; got %s", "foo", c.dsnrntype) - } - }() - <-clientDone - <-serverDone -} - -func TestClient_UpdateDeadline(t *testing.T) { - ln := newLocalListener(t) - defer func() { - _ = ln.Close() - }() - clientDone := make(chan bool) - serverDone := make(chan bool) - go func() { - defer close(serverDone) - c, err := ln.Accept() - if err != nil { - t.Errorf("Server accept: %v", err) - return - } - defer func() { - _ = c.Close() - }() - if err = serverHandle(c, t); err != nil { - t.Errorf("server error: %v", err) - } - }() - go func() { - defer close(clientDone) - c, err := Dial(ln.Addr().String()) - if err != nil { - t.Errorf("Client dial: %v", err) - return - } - defer func() { - _ = c.Close() - }() - if !c.HasConnection() { - t.Error("HasConnection: expected true; got false") - return - } - if err = c.UpdateDeadline(time.Millisecond * 20); err != nil { - t.Errorf("failed to update deadline: %s", err) - return - } - time.Sleep(time.Millisecond * 50) - if !c.HasConnection() { - t.Error("HasConnection: expected true; got false") - return - } - }() - <-clientDone - <-serverDone -} - -func newLocalListener(t *testing.T) net.Listener { - ln, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - ln, err = net.Listen("tcp6", "[::1]:0") - } - if err != nil { - t.Fatal(err) - } - return ln -} - -type smtpSender struct { - w io.Writer -} - -func (s smtpSender) send(f string) { - _, _ = s.w.Write([]byte(f + "\r\n")) -} - -// smtp server, finely tailored to deal with our own client only! -func serverHandle(c net.Conn, t *testing.T) error { - send := smtpSender{c}.send - send("220 127.0.0.1 ESMTP service ready") - s := bufio.NewScanner(c) - tf := func(config *tls.Config) error { - c = tls.Server(c, config) - defer func() { - _ = c.Close() - }() - return serverHandleTLS(c, t) - } - for s.Scan() { - switch s.Text() { - case "EHLO localhost": - send("250-127.0.0.1 ESMTP offers a warm hug of welcome") - send("250-STARTTLS") - send("250 Ok") - case "STARTTLS": - send("220 Go ahead") - keypair, err := tls.X509KeyPair(localhostCert, localhostKey) - if err != nil { - return err - } - config := &tls.Config{Certificates: []tls.Certificate{keypair}} - return tf(config) - case "QUIT": - return nil - default: - t.Fatalf("unrecognized command: %q", s.Text()) - } - } - return s.Err() -} - -func serverHandleTLS(c net.Conn, t *testing.T) error { - send := smtpSender{c}.send - s := bufio.NewScanner(c) - for s.Scan() { - switch s.Text() { - case "EHLO localhost": - send("250 Ok") - case "MAIL FROM:": - send("250 Ok") - case "RCPT TO:": - send("250 Ok") - case "DATA": - send("354 send the mail data, end with .") - send("250 Ok") - case "Subject: test": - case "": - case "howdy!": - case ".": - case "QUIT": - send("221 127.0.0.1 Service closing transmission channel") - return nil - default: - t.Fatalf("unrecognized command during TLS: %q", s.Text()) - } - } - return s.Err() -} - -func init() { - testRootCAs := x509.NewCertPool() - testRootCAs.AppendCertsFromPEM(localhostCert) - testHookStartTLS = func(config *tls.Config) { - config.RootCAs = testRootCAs - } -} - -func sendMail(hostPort string) error { - from := "joe1@example.com" - to := []string{"joe2@example.com"} - return SendMail(hostPort, nil, from, to, []byte("Subject: test\n\nhowdy!")) -} - -// localhostCert is a PEM-encoded TLS cert generated from src/crypto/tls: -// -// go run generate_cert.go --rsa-bits 1024 --host 127.0.0.1,::1,example.com \ -// --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h -var localhostCert = []byte(` ------BEGIN CERTIFICATE----- -MIICFDCCAX2gAwIBAgIRAK0xjnaPuNDSreeXb+z+0u4wDQYJKoZIhvcNAQELBQAw -EjEQMA4GA1UEChMHQWNtZSBDbzAgFw03MDAxMDEwMDAwMDBaGA8yMDg0MDEyOTE2 -MDAwMFowEjEQMA4GA1UEChMHQWNtZSBDbzCBnzANBgkqhkiG9w0BAQEFAAOBjQAw -gYkCgYEA0nFbQQuOWsjbGtejcpWz153OlziZM4bVjJ9jYruNw5n2Ry6uYQAffhqa -JOInCmmcVe2siJglsyH9aRh6vKiobBbIUXXUU1ABd56ebAzlt0LobLlx7pZEMy30 -LqIi9E6zmL3YvdGzpYlkFRnRrqwEtWYbGBf3znO250S56CCWH2UCAwEAAaNoMGYw -DgYDVR0PAQH/BAQDAgKkMBMGA1UdJQQMMAoGCCsGAQUFBwMBMA8GA1UdEwEB/wQF -MAMBAf8wLgYDVR0RBCcwJYILZXhhbXBsZS5jb22HBH8AAAGHEAAAAAAAAAAAAAAA -AAAAAAEwDQYJKoZIhvcNAQELBQADgYEAbZtDS2dVuBYvb+MnolWnCNqvw1w5Gtgi -NmvQQPOMgM3m+oQSCPRTNGSg25e1Qbo7bgQDv8ZTnq8FgOJ/rbkyERw2JckkHpD4 -n4qcK27WkEDBtQFlPihIM8hLIuzWoi/9wygiElTy/tVL3y7fGCvY2/k1KBthtZGF -tN8URjVmyEo= ------END CERTIFICATE-----`) - -// localhostKey is the private key for localhostCert. -var localhostKey = []byte(testingKey(` ------BEGIN RSA TESTING KEY----- -MIICXgIBAAKBgQDScVtBC45ayNsa16NylbPXnc6XOJkzhtWMn2Niu43DmfZHLq5h -AB9+Gpok4icKaZxV7ayImCWzIf1pGHq8qKhsFshRddRTUAF3np5sDOW3QuhsuXHu -lkQzLfQuoiL0TrOYvdi90bOliWQVGdGurAS1ZhsYF/fOc7bnRLnoIJYfZQIDAQAB -AoGBAMst7OgpKyFV6c3JwyI/jWqxDySL3caU+RuTTBaodKAUx2ZEmNJIlx9eudLA -kucHvoxsM/eRxlxkhdFxdBcwU6J+zqooTnhu/FE3jhrT1lPrbhfGhyKnUrB0KKMM -VY3IQZyiehpxaeXAwoAou6TbWoTpl9t8ImAqAMY8hlULCUqlAkEA+9+Ry5FSYK/m -542LujIcCaIGoG1/Te6Sxr3hsPagKC2rH20rDLqXwEedSFOpSS0vpzlPAzy/6Rbb -PHTJUhNdwwJBANXkA+TkMdbJI5do9/mn//U0LfrCR9NkcoYohxfKz8JuhgRQxzF2 -6jpo3q7CdTuuRixLWVfeJzcrAyNrVcBq87cCQFkTCtOMNC7fZnCTPUv+9q1tcJyB -vNjJu3yvoEZeIeuzouX9TJE21/33FaeDdsXbRhQEj23cqR38qFHsF1qAYNMCQQDP -QXLEiJoClkR2orAmqjPLVhR3t2oB3INcnEjLNSq8LHyQEfXyaFfu4U9l5+fRPL2i -jiC0k/9L5dHUsF0XZothAkEA23ddgRs+Id/HxtojqqUT27B8MT/IGNrYsp4DvS/c -qgkeluku4GjxRlDMBuXk94xOBEinUs+p/hwP1Alll80Tpg== ------END RSA TESTING KEY-----`)) - +// testingKey replaces the substring "TESTING KEY" with "PRIVATE KEY" in the given string s. func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") } -var flaky = flag.Bool("flaky", false, "run known-flaky tests too") +// serverProps represents the configuration properties for the SMTP server. +type serverProps struct { + BufferMutex sync.RWMutex + EchoBuffer io.Writer + FailOnAuth bool + FailOnDataInit bool + FailOnDataClose bool + FailOnDial bool + FailOnEhlo bool + FailOnHelo bool + FailOnMailFrom bool + FailOnNoop bool + FailOnQuit bool + FailOnReset bool + FailOnRcptTo bool + FailOnSTARTTLS bool + FailTemp bool + FeatureSet string + ListenPort int + HashFunc func() hash.Hash + IsSCRAMPlus bool + IsTLS bool + SupportDSN bool + SSLListener bool + TestSCRAM bool + VRFYUserUnknown bool +} -func SkipFlaky(t testing.TB, issue int) { +// simpleSMTPServer starts a simple TCP server that resonds to SMTP commands. +// The provided featureSet represents in what the server responds to EHLO command +// failReset controls if a RSET succeeds +func simpleSMTPServer(ctx context.Context, t *testing.T, props *serverProps) error { t.Helper() - if !*flaky { - t.Skipf("skipping known flaky test without the -flaky flag; see golang.org/issue/%d", issue) + if props == nil { + return fmt.Errorf("no server properties provided") } -} -// testSCRAMSMTPServer represents a test server for SCRAM-based SMTP authentication. -// It does not do any acutal computation of the challenges but verifies that the expected -// fields are present. We have actual real authentication tests for all SCRAM modes in the -// go-mail client_test.go -type testSCRAMSMTPServer struct { - authMechanism string - nonce string - hostname string - port string - tlsServer bool - h func() hash.Hash -} + var listener net.Listener + var err error + if props.SSLListener { + keypair, err := tls.X509KeyPair(localhostCert, localhostKey) + if err != nil { + return fmt.Errorf("failed to read TLS keypair: %w", err) + } + tlsConfig := &tls.Config{Certificates: []tls.Certificate{keypair}} + listener, err = tls.Listen(TestServerProto, fmt.Sprintf("%s:%d", TestServerAddr, props.ListenPort), + tlsConfig) + if err != nil { + t.Fatalf("failed to create TLS listener: %s", err) + } + } else { + listener, err = net.Listen(TestServerProto, fmt.Sprintf("%s:%d", TestServerAddr, props.ListenPort)) + } + if err != nil { + return fmt.Errorf("unable to listen on %s://%s: %w (SSL: %t)", TestServerProto, TestServerAddr, err, + props.SSLListener) + } -func (s *testSCRAMSMTPServer) handleConnection(conn net.Conn) { defer func() { - _ = conn.Close() + if err := listener.Close(); err != nil { + t.Logf("failed to close listener: %s", err) + } }() - reader := bufio.NewReader(conn) - writer := bufio.NewWriter(conn) - writeLine := func(data string) error { + for { + select { + case <-ctx.Done(): + return nil + default: + connection, err := listener.Accept() + var opErr *net.OpError + if err != nil { + if errors.As(err, &opErr) && opErr.Temporary() { + continue + } + return fmt.Errorf("unable to accept connection: %w", err) + } + handleTestServerConnection(connection, t, props) + } + } +} + +func handleTestServerConnection(connection net.Conn, t *testing.T, props *serverProps) { + t.Helper() + if !props.IsTLS { + t.Cleanup(func() { + if err := connection.Close(); err != nil { + t.Logf("failed to close connection: %s", err) + } + }) + } + + reader := bufio.NewReader(connection) + writer := bufio.NewWriter(connection) + + writeLine := func(data string) { _, err := writer.WriteString(data + "\r\n") if err != nil { - return fmt.Errorf("unable to write line: %w", err) + t.Logf("failed to write line: %s", err) } - return writer.Flush() + if props.EchoBuffer != nil { + props.BufferMutex.Lock() + if _, berr := props.EchoBuffer.Write([]byte(data + "\r\n")); berr != nil { + t.Errorf("failed write to echo buffer: %s", berr) + } + props.BufferMutex.Unlock() + } + _ = writer.Flush() } writeOK := func() { - _ = writeLine("250 2.0.0 OK") + writeLine("250 2.0.0 OK") } - if err := writeLine("220 go-mail test server ready ESMTP"); err != nil { - return - } - - data, err := reader.ReadString('\n') - if err != nil { - return - } - data = strings.TrimSpace(data) - if strings.HasPrefix(data, "EHLO") { - _ = writeLine(fmt.Sprintf("250-%s", s.hostname)) - _ = writeLine("250-AUTH SCRAM-SHA-1 SCRAM-SHA-256") - writeOK() - } else { - _ = writeLine("500 Invalid command") - return + if !props.IsTLS { + if props.FailOnDial { + writeLine("421 4.4.1 Service not available") + return + } + writeLine("220 go-mail test server ready ESMTP") } for { - data, err = reader.ReadString('\n') + data, err := reader.ReadString('\n') if err != nil { - fmt.Printf("failed to read data: %v", err) + break } - data = strings.TrimSpace(data) - if strings.HasPrefix(data, "AUTH") { - parts := strings.Split(data, " ") - if len(parts) < 2 { - _ = writeLine("500 Syntax error") - return + time.Sleep(time.Millisecond) + if props.EchoBuffer != nil { + props.BufferMutex.Lock() + if _, berr := props.EchoBuffer.Write([]byte(data)); berr != nil { + t.Errorf("failed write to echo buffer: %s", berr) } + props.BufferMutex.Unlock() + } - authMechanism := parts[1] - if authMechanism != "SCRAM-SHA-1" && authMechanism != "SCRAM-SHA-256" && - authMechanism != "SCRAM-SHA-1-PLUS" && authMechanism != "SCRAM-SHA-256-PLUS" { - _ = writeLine("504 Unrecognized authentication mechanism") - return + var datastring string + data = strings.TrimSpace(data) + switch { + case strings.HasPrefix(data, "HELO"): + if len(strings.Split(data, " ")) != 2 { + writeLine("501 Syntax: HELO hostname") + break } - s.authMechanism = authMechanism - _ = writeLine("334 ") - s.handleSCRAMAuth(conn) + if props.FailOnHelo { + writeLine("500 5.5.2 Error: fail on HELO") + break + } + if props.FeatureSet != "" { + writeLine("250-localhost.localdomain\r\n" + props.FeatureSet) + break + } + writeLine("250 localhost.localdomain\r\n") + case strings.HasPrefix(data, "EHLO"): + if len(strings.Split(data, " ")) != 2 { + writeLine("501 Syntax: EHLO hostname") + break + } + if props.FailOnEhlo { + writeLine("500 5.5.2 Error: fail on EHLO") + break + } + if props.FeatureSet != "" { + writeLine("250-localhost.localdomain\r\n" + props.FeatureSet) + break + } + writeLine("250 localhost.localdomain\r\n") + case strings.HasPrefix(data, "MAIL FROM:"): + if props.FailOnMailFrom { + writeLine("500 5.5.2 Error: fail on MAIL FROM") + break + } + from := strings.TrimPrefix(data, "MAIL FROM:") + from = strings.ReplaceAll(from, "BODY=8BITMIME", "") + from = strings.ReplaceAll(from, "SMTPUTF8", "") + if props.SupportDSN { + from = strings.ReplaceAll(from, "RET=FULL", "") + } + from = strings.TrimSpace(from) + if !strings.HasPrefix(from, "") { + writeLine(fmt.Sprintf("503 5.1.2 Invalid from: %s", from)) + break + } + writeOK() + case strings.HasPrefix(data, "RCPT TO:"): + if props.FailOnRcptTo { + writeLine("500 5.5.2 Error: fail on RCPT TO") + break + } + to := strings.TrimPrefix(data, "RCPT TO:") + if props.SupportDSN { + to = strings.ReplaceAll(to, "NOTIFY=FAILURE,SUCCESS", "") + } + to = strings.TrimSpace(to) + if !strings.EqualFold(to, "") { + writeLine(fmt.Sprintf("500 5.1.2 Invalid to: %s", to)) + break + } + writeOK() + case strings.HasPrefix(data, "AUTH"): + if props.FailOnAuth { + writeLine("535 5.7.8 Error: authentication failed") + break + } + if props.TestSCRAM { + parts := strings.Split(data, " ") + authMechanism := parts[1] + if authMechanism != "SCRAM-SHA-1" && authMechanism != "SCRAM-SHA-256" && + authMechanism != "SCRAM-SHA-1-PLUS" && authMechanism != "SCRAM-SHA-256-PLUS" { + writeLine("504 Unrecognized authentication mechanism") + break + } + scram := &testSCRAMSMTP{ + tlsServer: props.IsSCRAMPlus, + h: props.HashFunc, + } + writeLine("334 ") + scram.handleSCRAMAuth(connection) + break + } + writeLine("235 2.7.0 Authentication successful") + case strings.EqualFold(data, "DATA"): + if props.FailOnDataInit { + writeLine("503 5.5.1 Error: fail on DATA init") + break + } + writeLine("354 End data with .") + for { + ddata, derr := reader.ReadString('\n') + if derr != nil { + t.Logf("failed to read data from connection: %s", derr) + break + } + if props.EchoBuffer != nil { + props.BufferMutex.Lock() + if _, berr := props.EchoBuffer.Write([]byte(ddata)); berr != nil { + t.Errorf("failed write to echo buffer: %s", berr) + } + props.BufferMutex.Unlock() + } + ddata = strings.TrimSpace(ddata) + if ddata == "." { + if props.FailOnDataClose { + writeLine("500 5.0.0 Error during DATA transmission") + break + } + if props.FailTemp { + writeLine("451 4.3.0 Error: fail on DATA close") + break + } + writeLine("250 2.0.0 Ok: queued as 1234567890") + break + } + datastring += ddata + "\n" + } + case strings.EqualFold(data, "noop"): + if props.FailOnNoop { + writeLine("500 5.0.0 Error: fail on NOOP") + break + } + writeOK() + case strings.HasPrefix(data, "VRFY"): + if props.VRFYUserUnknown { + writeLine("550 5.1.1 User unknown") + break + } + parts := strings.SplitN(data, " ", 2) + if len(parts) != 2 { + writeLine("500 5.0.0 Error: invalid syntax for VRFY") + break + } + writeLine(fmt.Sprintf("250 2.0.0 Ok: %s OK", parts[1])) + case strings.EqualFold(data, "rset"): + if props.FailOnReset { + writeLine("500 5.1.2 Error: reset failed") + break + } + writeOK() + case strings.EqualFold(data, "quit"): + if props.FailOnQuit { + writeLine("500 5.1.2 Error: quit failed") + break + } + writeLine("221 2.0.0 Bye") return - } else { - _ = writeLine("500 Invalid command") + case strings.EqualFold(data, "starttls"): + if props.FailOnSTARTTLS { + writeLine("500 5.1.2 Error: starttls failed") + break + } + keypair, err := tls.X509KeyPair(localhostCert, localhostKey) + if err != nil { + writeLine("500 5.1.2 Error: starttls failed - " + err.Error()) + break + } + writeLine("220 Ready to start TLS") + tlsConfig := &tls.Config{Certificates: []tls.Certificate{keypair}, ServerName: "example.com"} + connection = tls.Server(connection, tlsConfig) + props.IsTLS = true + handleTestServerConnection(connection, t, props) + default: + writeLine("500 5.5.2 Error: bad syntax - " + data) } } } -func (s *testSCRAMSMTPServer) handleSCRAMAuth(conn net.Conn) { +// testSCRAMSMTP represents a part of the test server for SCRAM-based SMTP authentication. +// It does not do any acutal computation of the challenges but verifies that the expected +// fields are present. We have actual real authentication tests for all SCRAM modes in the +// go-mail client_test.go +type testSCRAMSMTP struct { + nonce string + h func() hash.Hash + tlsServer bool +} + +func (s *testSCRAMSMTP) handleSCRAMAuth(conn net.Conn) { reader := bufio.NewReader(conn) writer := bufio.NewWriter(conn) writeLine := func(data string) error { @@ -2511,7 +4019,7 @@ func (s *testSCRAMSMTPServer) handleSCRAMAuth(conn net.Conn) { _ = writeLine("235 Authentication successful") } -func (s *testSCRAMSMTPServer) extractNonce(message string) string { +func (s *testSCRAMSMTP) extractNonce(message string) string { parts := strings.Split(message, ",") for _, part := range parts { if strings.HasPrefix(part, "r=") { @@ -2521,37 +4029,47 @@ func (s *testSCRAMSMTPServer) extractNonce(message string) string { return "" } -func startSMTPServer(tlsServer bool, hostname, port string, h func() hash.Hash) { - server := &testSCRAMSMTPServer{ - hostname: hostname, - port: port, - tlsServer: tlsServer, - h: h, - } - listener, err := net.Listen("tcp", fmt.Sprintf("%s:%s", hostname, port)) - if err != nil { - fmt.Printf("Failed to start SMTP server: %v", err) - } - defer func() { - _ = listener.Close() - }() +// randReader is type that satisfies the io.Reader interface. It can fail on a specific read +// operations and is therefore useful to test consecutive reads with errors +type randReader struct{} +// Read implements the io.Reader interface for the randReader type +func (r *randReader) Read([]byte) (int, error) { + return 0, errors.New("broken reader") +} + +// toServerEmptyAuth is an implementation of Auth that only implements +// the Start method, and returns "FOOAUTH", nil, nil. Notably, it returns +// zero bytes for "toServer" so we can test that we don't send spaces at +// the end of the line. See TestClientAuthTrimSpace. +type toServerEmptyAuth struct{} + +func (toServerEmptyAuth) Start(_ *ServerInfo) (proto string, toServer []byte, err error) { + return "FOOAUTH", nil, nil +} + +func (toServerEmptyAuth) Next(_ []byte, _ bool) (toServer []byte, err error) { + return nil, fmt.Errorf("unexpected call") +} + +// failWriter is a struct type that implements the io.Writer interface, but always returns an error on Write. +type failWriter struct{} + +func (w *failWriter) Write([]byte) (int, error) { + return 0, errors.New("broken writer") +} + +func getTLSConfig(t *testing.T) *tls.Config { + t.Helper() cert, err := tls.X509KeyPair(localhostCert, localhostKey) if err != nil { - fmt.Printf("error creating TLS cert: %s", err) - return + t.Fatalf("unable to load host certifcate: %s", err) } - tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}} - - for { - conn, err := listener.Accept() - if err != nil { - fmt.Printf("Failed to accept connection: %v", err) - continue - } - if server.tlsServer { - conn = tls.Server(conn, &tlsConfig) - } - go server.handleConnection(conn) + testRootCAs := x509.NewCertPool() + testRootCAs.AppendCertsFromPEM(localhostCert) + return &tls.Config{ + Certificates: []tls.Certificate{cert}, + RootCAs: testRootCAs, + ServerName: "example.com", } } diff --git a/sonar-project.properties b/sonar-project.properties index e228a11..b800d26 100644 --- a/sonar-project.properties +++ b/sonar-project.properties @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: 2022-2023 The go-mail Authors # -# SPDX-License-Identifier: CC0-1.0 +# SPDX-License-Identifier: MIT sonar.projectKey=go-mail sonar.go.coverage.reportPaths=cov.out