mirror of
https://github.com/wneessen/go-mail.git
synced 2024-12-18 08:50:39 +01:00
Add SCRAM-SHA support to SMTP authentication
Introduced additional SMTP authentication mechanisms: SCRAM-SHA-1, SCRAM-SHA-1-PLUS, SCRAM-SHA-256, and SCRAM-SHA-256-PLUS. Added corresponding error messages for unsupported authentication types. This enhances security options for SMTP connections.
This commit is contained in:
parent
9069c9cdff
commit
e8fc6cd78f
1 changed files with 280 additions and 0 deletions
280
smtp/auth_scram.go
Normal file
280
smtp/auth_scram.go
Normal file
|
@ -0,0 +1,280 @@
|
|||
// SPDX-FileCopyrightText: Copyright (c) 2023 The go-mail Authors
|
||||
//
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package smtp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/pbkdf2"
|
||||
"golang.org/x/text/secure/precis"
|
||||
)
|
||||
|
||||
type scramAuth struct {
|
||||
username, password, algorithm string
|
||||
firstBareMsg, nonce, saltedPwd, authMessage []byte
|
||||
iterations int
|
||||
h func() hash.Hash
|
||||
isPlus bool
|
||||
tlsConnState *tls.ConnectionState
|
||||
bindData []byte
|
||||
}
|
||||
|
||||
func ScramSHA256Auth(username, password string) Auth {
|
||||
return &scramAuth{
|
||||
username: username,
|
||||
password: password,
|
||||
algorithm: "SCRAM-SHA-256",
|
||||
h: sha256.New,
|
||||
}
|
||||
}
|
||||
|
||||
func ScramSHA256PlusAuth(username, password string, tlsConnState *tls.ConnectionState) Auth {
|
||||
return &scramAuth{
|
||||
username: username,
|
||||
password: password,
|
||||
algorithm: "SCRAM-SHA-256-PLUS",
|
||||
h: sha256.New,
|
||||
isPlus: true,
|
||||
tlsConnState: tlsConnState,
|
||||
}
|
||||
}
|
||||
|
||||
func ScramSHA1Auth(username, password string) Auth {
|
||||
return &scramAuth{
|
||||
username: username,
|
||||
password: password,
|
||||
algorithm: "SCRAM-SHA-1",
|
||||
h: sha1.New,
|
||||
}
|
||||
}
|
||||
|
||||
func ScramSHA1PlusAuth(username, password string, tlsConnState *tls.ConnectionState) Auth {
|
||||
return &scramAuth{
|
||||
username: username,
|
||||
password: password,
|
||||
algorithm: "SCRAM-SHA-1-PLUS",
|
||||
h: sha1.New,
|
||||
isPlus: true,
|
||||
tlsConnState: tlsConnState,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *scramAuth) Start(_ *ServerInfo) (string, []byte, error) {
|
||||
fmt.Printf("algo: %s\n", a.algorithm)
|
||||
return a.algorithm, nil, nil
|
||||
}
|
||||
|
||||
func (a *scramAuth) Next(fromServer []byte, more bool) ([]byte, error) {
|
||||
if more {
|
||||
if len(fromServer) == 0 {
|
||||
a.reset()
|
||||
return a.initialClientMessage()
|
||||
}
|
||||
switch {
|
||||
case bytes.HasPrefix(fromServer, []byte("r=")):
|
||||
resp, err := a.handleServerFirstResponse(fromServer)
|
||||
if err != nil {
|
||||
a.reset()
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
case bytes.HasPrefix(fromServer, []byte("v=")):
|
||||
resp, err := a.handleServerValidationMessage(fromServer)
|
||||
if err != nil {
|
||||
a.reset()
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
default:
|
||||
a.reset()
|
||||
return nil, errors.New("unexpected server response")
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *scramAuth) reset() {
|
||||
a.nonce = nil
|
||||
a.firstBareMsg = nil
|
||||
a.saltedPwd = nil
|
||||
a.authMessage = nil
|
||||
a.iterations = 0
|
||||
}
|
||||
|
||||
func (a *scramAuth) initialClientMessage() ([]byte, error) {
|
||||
username, err := a.normalizeUsername()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("username normalization failed: %w", err)
|
||||
}
|
||||
|
||||
nonceBuffer := make([]byte, 24)
|
||||
if _, err := io.ReadFull(rand.Reader, nonceBuffer); err != nil {
|
||||
return nil, fmt.Errorf("unable to generate client secret: %w", err)
|
||||
}
|
||||
a.nonce = make([]byte, base64.StdEncoding.EncodedLen(len(nonceBuffer)))
|
||||
base64.StdEncoding.Encode(a.nonce, nonceBuffer)
|
||||
|
||||
a.firstBareMsg = []byte("n=" + username + ",r=" + string(a.nonce))
|
||||
returnBytes := []byte("n,," + string(a.firstBareMsg))
|
||||
|
||||
if a.isPlus {
|
||||
bindType := "tls-unique"
|
||||
connState := a.tlsConnState
|
||||
bindData := connState.TLSUnique
|
||||
if connState.Version == tls.VersionTLS13 {
|
||||
bindType = "tls-exporter"
|
||||
bindData, err = connState.ExportKeyingMaterial("EXPORTER-Channel-Binding", []byte{}, 32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to export keying material: %w", err)
|
||||
}
|
||||
}
|
||||
bindData = []byte("p=" + bindType + ",," + string(bindData))
|
||||
a.bindData = make([]byte, base64.StdEncoding.EncodedLen(len(bindData)))
|
||||
base64.StdEncoding.Encode(a.bindData, bindData)
|
||||
returnBytes = []byte("p=" + bindType + ",," + string(a.firstBareMsg))
|
||||
}
|
||||
|
||||
return returnBytes, nil
|
||||
}
|
||||
|
||||
func (a *scramAuth) handleServerFirstResponse(fromServer []byte) ([]byte, error) {
|
||||
parts := bytes.Split(fromServer, []byte(","))
|
||||
if len(parts) < 3 {
|
||||
return nil, errors.New("not enough fields in the first server response")
|
||||
}
|
||||
if !bytes.HasPrefix(parts[0], []byte("r=")) {
|
||||
return nil, errors.New("first part of the server response does not start with r=")
|
||||
}
|
||||
if !bytes.HasPrefix(parts[1], []byte("s=")) {
|
||||
return nil, errors.New("second part of the server response does not start with s=")
|
||||
}
|
||||
if !bytes.HasPrefix(parts[2], []byte("i=")) {
|
||||
return nil, errors.New("third part of the server response does not start with i=")
|
||||
}
|
||||
|
||||
combinedNonce := parts[0][2:]
|
||||
if len(a.nonce) == 0 || !bytes.HasPrefix(combinedNonce, a.nonce) {
|
||||
return nil, errors.New("server nonce does not start with our nonce")
|
||||
}
|
||||
a.nonce = combinedNonce
|
||||
|
||||
encodedSalt := parts[1][2:]
|
||||
salt := make([]byte, base64.StdEncoding.DecodedLen(len(encodedSalt)))
|
||||
n, err := base64.StdEncoding.Decode(salt, encodedSalt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid encoded salt: %w", err)
|
||||
}
|
||||
salt = salt[:n]
|
||||
|
||||
iterations, err := strconv.Atoi(string(parts[2][2:]))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid iterations: %w", err)
|
||||
}
|
||||
a.iterations = iterations
|
||||
|
||||
password, err := a.normalizeString(a.password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to normalize password: %w", err)
|
||||
}
|
||||
|
||||
a.saltedPwd = pbkdf2.Key([]byte(password), salt, a.iterations, a.h().Size(), a.h)
|
||||
|
||||
msgWithoutProof := []byte("c=biws,r=" + string(a.nonce))
|
||||
if a.isPlus {
|
||||
msgWithoutProof = []byte("c=" + string(a.bindData) + ",r=" + string(a.nonce))
|
||||
}
|
||||
a.authMessage = []byte(string(a.firstBareMsg) + "," + string(fromServer) + "," + string(msgWithoutProof))
|
||||
|
||||
clientProof := a.computeClientProof()
|
||||
|
||||
return []byte(string(msgWithoutProof) + ",p=" + string(clientProof)), nil
|
||||
}
|
||||
|
||||
func (a *scramAuth) handleServerValidationMessage(fromServer []byte) ([]byte, error) {
|
||||
serverSignature := fromServer[2:]
|
||||
computedServerSignature := a.computeServerSignature()
|
||||
|
||||
if !hmac.Equal(serverSignature, computedServerSignature) {
|
||||
return nil, errors.New("invalid server signature")
|
||||
}
|
||||
return []byte(""), nil
|
||||
}
|
||||
|
||||
func (a *scramAuth) computeHMAC(key, msg []byte) []byte {
|
||||
mac := hmac.New(a.h, key)
|
||||
mac.Write(msg)
|
||||
return mac.Sum(nil)
|
||||
}
|
||||
|
||||
func (a *scramAuth) computeHash(key []byte) []byte {
|
||||
hasher := a.h()
|
||||
hasher.Write(key)
|
||||
return hasher.Sum(nil)
|
||||
}
|
||||
|
||||
func (a *scramAuth) computeClientProof() []byte {
|
||||
clientKey := a.computeHMAC(a.saltedPwd, []byte("Client Key"))
|
||||
storedKey := a.computeHash(clientKey)
|
||||
clientSignature := a.computeHMAC(storedKey[:], a.authMessage)
|
||||
clientProof := make([]byte, len(clientSignature))
|
||||
for i := 0; i < len(clientSignature); i++ {
|
||||
clientProof[i] = clientKey[i] ^ clientSignature[i]
|
||||
}
|
||||
buf := make([]byte, base64.StdEncoding.EncodedLen(len(clientProof)))
|
||||
base64.StdEncoding.Encode(buf, clientProof)
|
||||
return buf
|
||||
}
|
||||
|
||||
func (a *scramAuth) computeServerSignature() []byte {
|
||||
serverKey := a.computeHMAC(a.saltedPwd, []byte("Server Key"))
|
||||
serverSignature := a.computeHMAC(serverKey, a.authMessage)
|
||||
buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature)))
|
||||
base64.StdEncoding.Encode(buf, serverSignature)
|
||||
return buf
|
||||
}
|
||||
|
||||
func (a *scramAuth) normalizeUsername() (string, error) {
|
||||
// RFC 5802 section 5.1: the characters ',' or '=' in usernames are
|
||||
// sent as '=2C' and '=3D' respectively.
|
||||
replacer := strings.NewReplacer("=", "=3D", ",", "=2C")
|
||||
username := replacer.Replace(a.username)
|
||||
// RFC 5802 section 5.1: before sending the username to the server,
|
||||
// the client SHOULD prepare the username using the "SASLprep"
|
||||
// profile [RFC4013] of the "stringprep" algorithm [RFC3454]
|
||||
// treating it as a query string (i.e., unassigned Unicode code
|
||||
// points are allowed). If the preparation of the username fails or
|
||||
// results in an empty string, the client SHOULD abort the
|
||||
// authentication exchange.
|
||||
//
|
||||
// Since RFC 8265 obsoletes RFC 4013 we use it instead.
|
||||
username, err := a.normalizeString(username)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unable to normalize username: %w", err)
|
||||
}
|
||||
return username, nil
|
||||
}
|
||||
|
||||
func (a *scramAuth) normalizeString(s string) (string, error) {
|
||||
s, err := precis.OpaqueString.String(s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if s == "" {
|
||||
return "", errors.New("normalized string is empty")
|
||||
}
|
||||
return s, nil
|
||||
}
|
Loading…
Reference in a new issue