This commit is contained in:
Christian Joergensen 2014-07-15 11:16:34 +02:00
parent cda2908ec8
commit b2f59a653e
2 changed files with 77 additions and 17 deletions

@ -13,14 +13,14 @@ import (
// Server defines the parameters for running the SMTP server
type Server struct {
Addr string // Address to listen on when using ListenAndServe (default: "127.0.0.1:10025")
WelcomeMessage string // Initial server banner (default: "<hostname> ESMTP ready.")
Addr string // Address to listen on when using ListenAndServe. (default: "127.0.0.1:10025")
WelcomeMessage string // Initial server banner. (default: "<hostname> ESMTP ready.")
ReadTimeout time.Duration // Socket timeout for read operations (default: 60s)
WriteTimeout time.Duration // Socket timeout for write operations (default: 60s)
ReadTimeout time.Duration // Socket timeout for read operations. (default: 60s)
WriteTimeout time.Duration // Socket timeout for write operations. (default: 60s)
MaxMessageSize int // Max message size in bytes (default: 10240000)
MaxConnections int // Max concurrent connections, use -1 to disable (default: 100)
MaxMessageSize int // Max message size in bytes. (default: 10240000)
MaxConnections int // Max concurrent connections, use -1 to disable. (default: 100)
// New e-mails are handed off to this function.
// Can be left empty for a NOOP server.
@ -30,16 +30,18 @@ type Server struct {
// Enable various checks during the SMTP session.
// Can be left empty for no restrictions.
// If an error is returned, it will be reported in the SMTP session.
HeloChecker func(peer Peer) error // Called after HELO/EHLO.
SenderChecker func(peer Peer, addr string) error // Called after MAIL FROM.
RecipientChecker func(peer Peer, addr string) error // Called after each RCPT TO.
// Use the Error struct for access to error codes.
ConnectionChecker func(peer Peer) error // Called upon new connection.
HeloChecker func(peer Peer) error // Called after HELO/EHLO.
SenderChecker func(peer Peer, addr string) error // Called after MAIL FROM.
RecipientChecker func(peer Peer, addr string) error // Called after each RCPT TO.
// Enable PLAIN/LOGIN authentication, only available after STARTTLS.
// Can be left empty for no authentication support.
Authenticator func(peer Peer, username, password string) error
TLSConfig *tls.Config // Enable STARTTLS support
ForceTLS bool // Force STARTTLS usage
TLSConfig *tls.Config // Enable STARTTLS support.
ForceTLS bool // Force STARTTLS usage.
}
// Peer represents the client connecting to the server
@ -57,6 +59,15 @@ type Envelope struct {
Data []byte
}
// Error represents an Error reported in the SMTP session.
type Error struct {
Code int // The integer error code
Message string // The error message
}
// Error returns a string representation of the SMTP error
func (e Error) Error() string { return fmt.Sprintf("%d %s", e.Code, e.Message) }
type session struct {
server *Server
@ -191,7 +202,7 @@ func (session *session) serve() {
defer session.close()
session.reply(220, session.server.WelcomeMessage)
session.welcome()
for session.scanner.Scan() {
session.handle(session.scanner.Text())
@ -204,6 +215,21 @@ func (session *session) reject() {
session.close()
}
func (session *session) welcome() {
if session.server.ConnectionChecker != nil {
err := session.server.ConnectionChecker(session.peer)
if err != nil {
session.error(err)
session.close()
return
}
}
session.reply(220, session.server.WelcomeMessage)
}
func (session *session) reply(code int, message string) {
fmt.Fprintf(session.writer, "%d %s\r\n", code, message)
@ -216,7 +242,11 @@ func (session *session) reply(code int, message string) {
}
func (session *session) error(err error) {
session.reply(502, fmt.Sprintf("%s", err))
if smtpdError, ok := err.(Error); ok {
session.reply(smtpdError.Code, smtpdError.Message)
} else {
session.reply(502, fmt.Sprintf("%s", err))
}
}
func (session *session) extensions() []string {

@ -2,7 +2,6 @@ package smtpd
import (
"crypto/tls"
"errors"
"fmt"
"net"
"net/smtp"
@ -195,6 +194,31 @@ func TestSTARTTLS(t *testing.T) {
}
}
func TestConnectionCheck(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Listen failed: %v", err)
}
defer ln.Close()
server := &Server{
ConnectionChecker: func(peer Peer) error {
return Error{Code: 552, Message: "Denied"}
},
}
go func() {
server.Serve(ln)
}()
if _, err := smtp.Dial(ln.Addr().String()); err == nil {
t.Fatal("Dial succeeded despite ConnectionCheck")
}
}
func TestHELOCheck(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:0")
@ -205,7 +229,9 @@ func TestHELOCheck(t *testing.T) {
defer ln.Close()
server := &Server{
HeloChecker: func(peer Peer) error { return errors.New("Denied") },
HeloChecker: func(peer Peer) error {
return Error{Code: 552, Message: "Denied"}
},
}
go func() {
@ -233,7 +259,9 @@ func TestSenderCheck(t *testing.T) {
defer ln.Close()
server := &Server{
SenderChecker: func(peer Peer, addr string) error { return errors.New("Denied") },
SenderChecker: func(peer Peer, addr string) error {
return Error{Code: 552, Message: "Denied"}
},
}
go func() {
@ -261,7 +289,9 @@ func TestRecipientCheck(t *testing.T) {
defer ln.Close()
server := &Server{
RecipientChecker: func(peer Peer, addr string) error { return errors.New("Denied") },
RecipientChecker: func(peer Peer, addr string) error {
return Error{Code: 552, Message: "Denied"}
},
}
go func() {