Better context and connection handling

This commit is contained in:
Winni Neessen 2022-03-10 12:10:27 +01:00
parent 4babc309fb
commit 59a1d14ca7
Signed by: wneessen
GPG key ID: 385AC9889632126E
4 changed files with 83 additions and 34 deletions

View file

@ -4,10 +4,8 @@
<option name="autoReloadType" value="ALL" /> <option name="autoReloadType" value="ALL" />
</component> </component>
<component name="ChangeListManager"> <component name="ChangeListManager">
<list default="true" id="b79e8e7a-d892-4ce4-8bf4-f9e45415b803" name="Changes" comment="Progress"> <list default="true" id="b79e8e7a-d892-4ce4-8bf4-f9e45415b803" name="Changes" comment="Implemented SMTP AUTH">
<change afterPath="$PROJECT_DIR$/auth.go" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" /> <change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/README.md" beforeDir="false" afterPath="$PROJECT_DIR$/README.md" afterDir="false" />
<change beforePath="$PROJECT_DIR$/client.go" beforeDir="false" afterPath="$PROJECT_DIR$/client.go" afterDir="false" /> <change beforePath="$PROJECT_DIR$/client.go" beforeDir="false" afterPath="$PROJECT_DIR$/client.go" afterDir="false" />
<change beforePath="$PROJECT_DIR$/cmd/main.go" beforeDir="false" afterPath="$PROJECT_DIR$/cmd/main.go" afterDir="false" /> <change beforePath="$PROJECT_DIR$/cmd/main.go" beforeDir="false" afterPath="$PROJECT_DIR$/cmd/main.go" afterDir="false" />
<change beforePath="$PROJECT_DIR$/mailmsg.go" beforeDir="false" afterPath="$PROJECT_DIR$/mailmsg.go" afterDir="false" /> <change beforePath="$PROJECT_DIR$/mailmsg.go" beforeDir="false" afterPath="$PROJECT_DIR$/mailmsg.go" afterDir="false" />
@ -95,7 +93,8 @@
</component> </component>
<component name="VcsManagerConfiguration"> <component name="VcsManagerConfiguration">
<MESSAGE value="Progress" /> <MESSAGE value="Progress" />
<option name="LAST_COMMIT_MESSAGE" value="Progress" /> <MESSAGE value="Implemented SMTP AUTH" />
<option name="LAST_COMMIT_MESSAGE" value="Implemented SMTP AUTH" />
</component> </component>
<component name="VgoProject"> <component name="VgoProject">
<integration-enabled>true</integration-enabled> <integration-enabled>true</integration-enabled>

View file

@ -53,10 +53,13 @@ type Client struct {
// satype represents the authentication type for SMTP AUTH // satype represents the authentication type for SMTP AUTH
satype SMTPAuthType satype SMTPAuthType
// smtpauth is a pointer to smtp.Auth // co is the net.Conn that the smtp.Client is based on
co net.Conn
// sa is a pointer to smtp.Auth
sa smtp.Auth sa smtp.Auth
// The SMTP client that is set up when using the Dial*() methods // sc is the smtp.Client that is set up when using the Dial*() methods
sc *smtp.Client sc *smtp.Client
} }
@ -66,6 +69,13 @@ type Option func(*Client)
var ( var (
// ErrNoHostname should be used if a Client has no hostname set // ErrNoHostname should be used if a Client has no hostname set
ErrNoHostname = errors.New("hostname for client cannot be empty") ErrNoHostname = errors.New("hostname for client cannot be empty")
// ErrDeadlineExtendFailed should be used if the extension of the connection deadline fails
ErrDeadlineExtendFailed = errors.New("connection deadline extension failed")
// ErrNoActiveConnection should be used when a method is used that requies a server connection
// but is not yet connected
ErrNoActiveConnection = errors.New("not connected to SMTP server")
) )
// NewClient returns a new Session client object // NewClient returns a new Session client object
@ -190,9 +200,24 @@ func (c *Client) SetTLSConfig(co *tls.Config) {
c.tlsconfig = co c.tlsconfig = co
} }
// Send sends out the mail message // SetUsername overrides the current username string with the given value
func (c *Client) Send() error { func (c *Client) SetUsername(u string) {
return nil c.user = u
}
// SetPassword overrides the current password string with the given value
func (c *Client) SetPassword(p string) {
c.pass = p
}
// SetSMTPAuth overrides the current SMTP AUTH type setting with the given value
func (c *Client) SetSMTPAuth(a SMTPAuthType) {
c.satype = a
}
// SetSMTPAuthCustom overrides the current SMTP AUTH setting with the given custom smtp.Auth
func (c *Client) SetSMTPAuthCustom(sa smtp.Auth) {
c.sa = sa
} }
// Close closes the connection cto the SMTP server // Close closes the connection cto the SMTP server
@ -210,33 +235,26 @@ func (c *Client) setDefaultHelo() error {
return nil return nil
} }
// Dial establishes a connection cto the SMTP server with a default context.Background
func (c *Client) Dial() error {
ctx := context.Background()
return c.DialWithContext(ctx)
}
// DialWithContext establishes a connection cto the SMTP server with a given context.Context // DialWithContext establishes a connection cto the SMTP server with a given context.Context
func (c *Client) DialWithContext(uctx context.Context) error { func (c *Client) DialWithContext(pc context.Context) error {
ctx, cfn := context.WithTimeout(uctx, c.cto) ctx, cfn := context.WithDeadline(pc, time.Now().Add(c.cto))
defer cfn() defer cfn()
nd := net.Dialer{} nd := net.Dialer{}
td := tls.Dialer{} td := tls.Dialer{}
var co net.Conn
var err error var err error
if c.ssl { if c.ssl {
c.enc = true c.enc = true
co, err = td.DialContext(ctx, "tcp", c.ServerAddr()) c.co, err = td.DialContext(ctx, "tcp", c.ServerAddr())
} }
if !c.ssl { if !c.ssl {
co, err = nd.DialContext(ctx, "tcp", c.ServerAddr()) c.co, err = nd.DialContext(ctx, "tcp", c.ServerAddr())
} }
if err != nil { if err != nil {
return err return err
} }
c.sc, err = smtp.NewClient(co, c.host) c.sc, err = smtp.NewClient(c.co, c.host)
if err != nil { if err != nil {
return err return err
} }
@ -274,8 +292,45 @@ func (c *Client) DialWithContext(uctx context.Context) error {
return nil return nil
} }
// Send sends out the mail message
func (c *Client) Send() error {
if err := c.checkConn(); err != nil {
return fmt.Errorf("failed to send mail: %w", err)
}
return nil
}
// DialAndSend establishes a connection to the SMTP server with a
// default context.Background and sends the mail
func (c *Client) DialAndSend() error {
ctx := context.Background()
if err := c.DialWithContext(ctx); err != nil {
return fmt.Errorf("dial failed: %w", err)
}
if err := c.Send(); err != nil {
return fmt.Errorf("send failed: %w", err)
}
return nil
}
// checkConn makes sure that a required server connection is available and extends the
// connection deadline
func (c *Client) checkConn() error {
if c.co == nil {
return ErrNoActiveConnection
}
if err := c.co.SetDeadline(time.Now().Add(c.cto)); err != nil {
return ErrDeadlineExtendFailed
}
return nil
}
// auth will try to perform SMTP AUTH if requested // auth will try to perform SMTP AUTH if requested
func (c *Client) auth() error { func (c *Client) auth() error {
if err := c.checkConn(); err != nil {
return fmt.Errorf("failed to authenticate: %w", err)
}
if c.sa == nil && c.satype != "" { if c.sa == nil && c.satype != "" {
sa, sat := c.sc.Extension("AUTH") sa, sat := c.sc.Extension("AUTH")
if !sa { if !sa {

View file

@ -1,7 +1,6 @@
package main package main
import ( import (
"context"
"fmt" "fmt"
"github.com/wneessen/go-mail" "github.com/wneessen/go-mail"
"os" "os"
@ -14,24 +13,15 @@ func main() {
fmt.Printf("$TEST_HOST env variable cannot be empty\n") fmt.Printf("$TEST_HOST env variable cannot be empty\n")
os.Exit(1) os.Exit(1)
} }
tu := os.Getenv("TEST_USER") tu := os.Getenv("TEST_USER")
tp := os.Getenv("TEST_PASS") tp := os.Getenv("TEST_PASS")
c, err := mail.NewClient(th, mail.WithTimeout(time.Millisecond*500), mail.WithTLSPolicy(mail.TLSMandatory), c, err := mail.NewClient(th, mail.WithTimeout(time.Millisecond*500), mail.WithTLSPolicy(mail.TLSMandatory),
mail.WithSMTPAuth(mail.SMTPAuthDigestMD5), mail.WithUsername(tu), mail.WithPassword(tp)) mail.WithSMTPAuth(mail.SMTPAuthDigestMD5), mail.WithUsername(tu), mail.WithPassword(tp))
if err != nil { if err != nil {
fmt.Printf("failed to create new client: %s\n", err) fmt.Printf("failed to create new client: %s\n", err)
os.Exit(1) os.Exit(1)
} }
//c.SetTLSPolicy(mail.TLSMandatory)
ctx, cfn := context.WithCancel(context.Background())
defer cfn()
if err := c.DialWithContext(ctx); err != nil {
fmt.Printf("failed to dial: %s\n", err)
os.Exit(1)
}
m := mail.NewMsg() m := mail.NewMsg()
m.From("wn@neessen.net") m.From("wn@neessen.net")
@ -41,4 +31,8 @@ func main() {
m.SetBulk() m.SetBulk()
m.Header() m.Header()
if err := c.DialAndSend(); err != nil {
fmt.Printf("failed to dial: %s\n", err)
os.Exit(1)
}
} }

View file

@ -73,14 +73,15 @@ func (m *Msg) SetMessageIDWithValue(v string) {
m.SetHeader(HeaderMessageID, v) m.SetHeader(HeaderMessageID, v)
} }
// SetBulk sets the "Precedense: bulk" header which is recommended for // SetBulk sets the "Precedence: bulk" header which is recommended for
// automated mails like OOO replies // automated mails like OOO replies
// See: https://www.rfc-editor.org/rfc/rfc2076#section-3.9 // See: https://www.rfc-editor.org/rfc/rfc2076#section-3.9
func (m *Msg) SetBulk() { func (m *Msg) SetBulk() {
m.SetHeader(HeaderPrecedence, "bulk") m.SetHeader(HeaderPrecedence, "bulk")
} }
// Header FixMe // Header does something
// FIXME: This is only here to quickly show the set headers for debugging purpose. Remove me later
func (m *Msg) Header() { func (m *Msg) Header() {
fmt.Printf("%+v\n", m.header) fmt.Printf("%+v\n", m.header)