From 3127bd4ed806d0a77f4210d390c6ce2ef5e8073b Mon Sep 17 00:00:00 2001 From: Christian Joergensen Date: Mon, 14 Jul 2014 13:55:41 +0200 Subject: [PATCH] Refactor. --- address.go | 15 +++ cmd/smtpd/main.go | 4 +- protocol.go | 236 +++++++++++++++++++++++++++++++++++ smtpd.go | 309 ++++++++++++++++------------------------------ 4 files changed, 358 insertions(+), 206 deletions(-) create mode 100644 address.go create mode 100644 protocol.go diff --git a/address.go b/address.go new file mode 100644 index 0000000..b663537 --- /dev/null +++ b/address.go @@ -0,0 +1,15 @@ +package smtpd + +import ( + "strings" + "fmt" +) + +type MailAddress string + +func parseMailAddress(src string) (MailAddress, error) { + if src[0] != '<' || src[len(src)-1] != '>' || strings.Count(src, "@") != 1 { + return MailAddress(""), fmt.Errorf("Ill-formatted e-mail address: %s", src) + } + return MailAddress(src[1 : len(src)-1]), nil +} diff --git a/cmd/smtpd/main.go b/cmd/smtpd/main.go index a39068f..350f6f3 100644 --- a/cmd/smtpd/main.go +++ b/cmd/smtpd/main.go @@ -8,7 +8,7 @@ import ( ) func dumpMessage(peer smtpd.Peer, env smtpd.Envelope) error { - log.Printf("New mail from: %s", env.MailFrom) + log.Printf("New mail from: %s", env.Sender) return nil } @@ -32,8 +32,6 @@ func main() { } server := &smtpd.Server{ - Addr: "127.0.0.1:10025", - WelcomeMessage: "localhost ESMTP ready.", Handler: dumpMessage, TLSConfig: tlsConfig, ForceTLS: true, diff --git a/protocol.go b/protocol.go new file mode 100644 index 0000000..a647d9f --- /dev/null +++ b/protocol.go @@ -0,0 +1,236 @@ +package smtpd + +import ( + "fmt" + "strings" + "crypto/tls" + "bufio" + "log" + "bytes" +) + +type command struct { + line string + action string + fields []string + params []string +} + +func parseLine(line string) (cmd command) { + + cmd.line = line + cmd.fields = strings.Fields(line) + cmd.action = strings.ToUpper(cmd.fields[0]) + + if len(cmd.fields) > 1 { + cmd.params = strings.Split(cmd.fields[1], ":") + } + + return + +} + +func (session *session) handleHELO(cmd command) { + + if len(cmd.fields) < 2 { + session.reply(502, "Missing parameter") + return + } + + session.peer.HeloName = cmd.fields[1] + + if session.server.HeloChecker != nil { + err := session.server.HeloChecker(session.peer) + if err != nil { + session.error(err) + session.close() + return + } + } + + session.reply(250, "Go ahead") + + return + +} + +func (session *session) handleEHLO(cmd command) { + + if len(cmd.fields) < 2 { + session.reply(502, "Missing parameter") + return + } + + session.peer.HeloName = cmd.fields[1] + + if session.server.HeloChecker != nil { + err := session.server.HeloChecker(session.peer) + if err != nil { + session.error(err) + session.close() + return + } + } + + extensions := session.extensions() + + if len(extensions) > 1 { + for _, ext := range extensions[:len(extensions)-1] { + fmt.Fprintf(session.writer, "250-%s\r\n", ext) + } + } + + session.reply(250, extensions[len(extensions)-1]) + + return + +} + +func (session *session) handleMAIL(cmd command) { + + if session.peer.HeloName == "" { + session.reply(502, "Please introduce yourself first.") + return + } + + addr, err := parseMailAddress(cmd.params[1]) + + if err != nil { + session.reply(502, "Ill-formatted e-mail address") + return + } + + if session.server.SenderChecker != nil { + err = session.server.SenderChecker(session.peer, addr) + session.error(err) + return + } + + session.envelope = &Envelope{ + Sender: addr, + } + + session.reply(250, "Go ahead") + + return + +} + +func (session *session) handleRCPT(cmd command) { + + if session.envelope == nil { + session.reply(502, "Missing MAIL FROM command.") + return + } + + addr, err := parseMailAddress(cmd.params[1]) + + if err != nil { + session.reply(502, "Ill-formatted e-mail address") + return + } + + if session.server.RecipientChecker != nil { + err = session.server.RecipientChecker(session.peer, addr) + session.error(err) + return + } + + session.envelope.Recipients = append(session.envelope.Recipients, addr) + + session.reply(250, "Go ahead") + + return + +} + + +func (session *session) handleSTARTTLS(cmd command) { + + if session.tls { + session.reply(502, "Already running in TLS") + return + } + + if session.server.TLSConfig == nil { + session.reply(502, "TLS not supported") + return + } + + tls_conn := tls.Server(session.conn, session.server.TLSConfig) + session.reply(250, "Go ahead") + + if err := tls_conn.Handshake(); err != nil { + log.Printf("TLS Handshake error:", err) + session.reply(550, "Handshake error") + return + } + + session.conn = tls_conn + session.reader = bufio.NewReader(tls_conn) + session.writer = bufio.NewWriter(tls_conn) + session.scanner = bufio.NewScanner(session.reader) + session.tls = true + + return + +} + +func (session *session) handleDATA(cmd command) { + + if session.envelope == nil || len(session.envelope.Recipients) == 0 { + session.reply(502, "Missing RCPT TO command.") + return + } + + session.reply(250, "Go ahead. End your data with .") + + data := &bytes.Buffer{} + done := false + + for session.scanner.Scan() { + + line := session.scanner.Text() + + if line == "." { + done = true + break + } + + data.Write([]byte(line)) + data.Write([]byte("\r\n")) + + } + + if !done { + return + } + + session.envelope.Data = data.Bytes() + + err := session.deliver() + + if err != nil { + session.error(err) + } else { + session.reply(200, "Thank you.") + } + +} + +func (session *session) handleRSET(cmd command) { + session.envelope = nil + session.reply(250, "Go ahead") + return +} + +func (session *session) handleNOOP(cmd command) { + session.reply(250, "Go ahead") + return +} + +func (session *session) handleQUIT(cmd command) { + session.reply(250, "OK, bye") + session.close() + return +} diff --git a/smtpd.go b/smtpd.go index 73994e7..da6ee80 100644 --- a/smtpd.go +++ b/smtpd.go @@ -1,28 +1,38 @@ +// Package smtpd implements a SMTP server with support for STARTTLS, authentication and restrictions on the different stages of the SMTP session. package smtpd import ( "bufio" - "bytes" "crypto/tls" "fmt" "log" "net" - "strings" "time" + "os" ) type Server struct { - Addr string // Address to listen on - WelcomeMessage string // Initial server banner + + Addr string // Address to listen on when using ListenAndServe (default: "127.0.0.1:10025") + WelcomeMessage string // Initial server banner (default: " ESMTP ready.") ReadTimeout time.Duration // Socket timeout for read operations (default: 60s) WriteTimeout time.Duration // Socket timeout for write operations (default: 60s) // New e-mails are handed off to this function. - // If an error is returned, it will be reported in the SMTP session + // Can be left empty for a NOOP server. + // If an error is returned, it will be reported in the SMTP session. Handler func(peer Peer, env Envelope) error - // Enable PLAIN/LOGIN authentication + // Enable various checks during the SMTP session. + // Can be left empty for no restrictions. + // If an error is returned, it will be reported in the SMTP session. + HeloChecker func(peer Peer) error // Called after HELO/EHLO. + SenderChecker func(peer Peer, addr MailAddress) error // Called after MAIL FROM. + RecipientChecker func(peer Peer, addr MailAddress) error // Called after each RCPT TO. + + // Enable PLAIN/LOGIN authentication, only available after STARTTLS. + // Can be left empty for no authentication support. Authenticator func(peer Peer, username, password string) error TLSConfig *tls.Config // Enable STARTTLS support @@ -31,42 +41,36 @@ type Server struct { MaxMessageSize int // Max message size in bytes (default: 10240000) } -type sessionState int - -const ( - _STATE_HELO sessionState = iota - _STATE_AUTH - _STATE_MAIL - _STATE_RCPT - _STATE_DATA -) - -type session struct { - server *Server - conn net.Conn - reader *bufio.Reader - writer *bufio.Writer - peer Peer - state sessionState - tls bool -} - type Peer struct { HeloName string // Server name used in HELO/EHLO command - UserName string // Username from authentication + Username string // Username from authentication + Password string // Password from authentication Addr net.Addr // Network address } -type MailAddress string - type Envelope struct { - MailFrom MailAddress + Sender MailAddress Recipients []MailAddress Data []byte - Peer *Peer } -func (srv *Server) newConnection(c net.Conn) (s *session, err error) { +type session struct { + + server *Server + + peer Peer + envelope *Envelope + + conn net.Conn + + reader *bufio.Reader + writer *bufio.Writer + scanner *bufio.Scanner + + tls bool +} + +func (srv *Server) newSession(c net.Conn) (s *session, err error) { log.Printf("New connection from: %s", c.RemoteAddr()) @@ -77,16 +81,22 @@ func (srv *Server) newConnection(c net.Conn) (s *session, err error) { writer: bufio.NewWriter(c), peer: Peer{Addr: c.RemoteAddr()}, } + + s.scanner = bufio.NewScanner(s.reader) return s, nil } func (srv *Server) ListenAndServe() error { + + srv.configureDefaults() + l, err := net.Listen("tcp", srv.Addr) if err != nil { return err } + log.Printf("Listening on: %s", srv.Addr) return srv.Serve(l) } @@ -107,13 +117,11 @@ func (srv *Server) Serve(l net.Listener) error { return e } - session, err := srv.newConnection(conn) + session, err := srv.newSession(conn) if err != nil { continue } - session.state = _STATE_HELO - go session.serve() } @@ -138,181 +146,79 @@ func (srv *Server) configureDefaults() { log.Fatal("Cannot use ForceTLS with no TLSConfig") } + if srv.Addr == "" { + srv.Addr = "127.0.0.1:10025" + } + + if srv.WelcomeMessage == "" { + + hostname, err := os.Hostname() + + if err != nil { + log.Fatal("Couldn't determine hostname: %s", err) + } + + srv.WelcomeMessage = fmt.Sprintf("%s ESMTP ready.", hostname) + + } + } func (session *session) serve() { log.Print("Serving") - defer func() { - session.writer.Flush() - session.conn.Close() - }() + defer session.close() session.reply(220, session.server.WelcomeMessage) - scanner := bufio.NewScanner(session.reader) + for session.scanner.Scan() { - var env Envelope - var data *bytes.Buffer + line := session.scanner.Text() + cmd := parseLine(line) - for scanner.Scan() { - - line := scanner.Text() - command := "" - fields := []string{} - params := []string{} - - if session.state != _STATE_DATA { - fields = strings.Fields(line) - command = strings.ToUpper(fields[0]) - if len(fields) > 1 { - params = strings.Split(fields[1], ":") - } - } - - log.Printf("Line: %s, fields: %#v, params: %#v", line, fields, params) - - if command == "QUIT" { - session.reply(250, "Ok, bye") - return - } - - switch session.state { - - case _STATE_HELO: - - if command == "HELO" || command == "EHLO" { - if len(fields) < 2 { - session.reply(502, "Missing parameter") - continue - } else { - session.peer.HeloName = fields[1] - } - } else { - session.reply(502, "Command not recognized, expected HELO/EHLO") - continue - } - - if command == "EHLO" { - session.WriteExtensions() - } else { - session.reply(250, "Go ahead") - } - - if session.server.Authenticator == nil { - session.state = _STATE_MAIL - } else { - session.state = _STATE_AUTH - } + switch cmd.action { + case "HELO": + session.handleHELO(cmd) continue - case _STATE_MAIL: - - if !session.tls && command == "STARTTLS" && session.server.TLSConfig != nil { - - tls_conn := tls.Server(session.conn, session.server.TLSConfig) - session.reply(250, "Go ahead") - - if err := tls_conn.Handshake(); err != nil { - log.Printf("TLS Handshake error:", err) - session.reply(550, "Handshake error") - continue - } - - session.conn = tls_conn - - session.reader = bufio.NewReader(tls_conn) - session.writer = bufio.NewWriter(tls_conn) - - scanner = bufio.NewScanner(session.reader) - - session.tls = true - session.state = _STATE_HELO - - continue - - } - - if !session.tls && session.server.ForceTLS { - session.reply(550, "Must run STARTTLS first") - continue - } - - if command == "MAIL" && strings.ToUpper(params[0]) == "FROM" { - - addr, err := parseMailAddress(params[1]) - - if err != nil { - session.reply(502, "Ill-formatted e-mail address") - continue - } - - env = Envelope{ - Peer: &session.peer, - MailFrom: addr, - } - - session.reply(250, "Go ahead") - session.state = _STATE_RCPT - continue - - } else { - session.reply(502, "Command not recognized, expected MAIL FROM") - continue - } - - case _STATE_RCPT: - - if command == "RCPT" && strings.ToUpper(params[0]) == "TO" { - - addr, err := parseMailAddress(params[1]) - - if err != nil { - session.reply(502, "Ill-formatted e-mail address") - continue - } - - env.Recipients = append(env.Recipients, addr) - - session.reply(250, "Go ahead") - continue - - } else if command == "DATA" && len(env.Recipients) > 0 { - session.reply(250, "Go ahead. End your data with .") - data = &bytes.Buffer{} - session.state = _STATE_DATA - continue - } - - if len(env.Recipients) == 0 { - session.reply(502, "Command not recognized, expected RCPT") - } else { - session.reply(502, "Command not recognized, expected RCPT or DATA") - } - + case "EHLO": + session.handleEHLO(cmd) continue - case _STATE_DATA: + case "MAIL": + session.handleMAIL(cmd) + continue - if line == "." { - env.Data = data.Bytes() - data.Reset() - err := session.handle(env) + case "RCPT": + session.handleRCPT(cmd) + continue - if err != nil { - session.reply(502, fmt.Sprintf("%s", err)) - } else { - session.reply(200, "Thank you.") - } + case "STARTTLS": + session.handleSTARTTLS(cmd) + continue - session.state = _STATE_MAIL - continue - } + case "DATA": + session.handleDATA(cmd) + continue + + case "RSET": + session.handleRSET(cmd) + continue + + case "NOOP": + session.handleNOOP(cmd) + continue + + case "QUIT": + session.handleQUIT(cmd) + continue } + session.reply(502, "Unsupported command.") + } } @@ -328,7 +234,12 @@ func (session *session) reply(code int, message string) { } -func (session *session) WriteExtensions() { +func (session *session) error(err error) { + session.reply(502, fmt.Sprintf("%s", err)) +} + + +func (session *session) extensions() []string { extensions := []string{ "SIZE 10240000", @@ -342,27 +253,19 @@ func (session *session) WriteExtensions() { extensions = append(extensions, "AUTH PLAIN LOGIN") } - if len(extensions) > 1 { - for _, ext := range extensions[:len(extensions)-1] { - fmt.Fprintf(session.writer, "250-%s\r\n", ext) - } - } - - session.reply(250, extensions[len(extensions)-1]) + return extensions } -func (session *session) handle(env Envelope) error { +func (session *session) deliver() error { if session.server.Handler != nil { - return session.server.Handler(session.peer, env) + return session.server.Handler(session.peer, *session.envelope) } else { return nil } } -func parseMailAddress(src string) (MailAddress, error) { - if src[0] != '<' || src[len(src)-1] != '>' || strings.Count(src, "@") != 1 { - return MailAddress(""), fmt.Errorf("Ill-formatted e-mail address: %s", src) - } - return MailAddress(src[1 : len(src)-1]), nil +func (session *session) close() { + session.writer.Flush() + session.conn.Close() }