From 66f94a07ae82e456979b1e842ff256877cd03bd7 Mon Sep 17 00:00:00 2001 From: Craig Date: Fri, 12 Feb 2016 01:26:50 +0000 Subject: [PATCH] Fix several panics on invalid input --- .hgignore | 3 + LICENSE | 20 + README.md | 19 + address.go | 19 + envelope.go | 54 ++ example_test.go | 46 ++ examples/dkim-proxy/main.go | 80 +++ protocol.go | 582 ++++++++++++++++++ smtpd.go | 324 ++++++++++ smtpd_test.go | 1162 +++++++++++++++++++++++++++++++++++ wrap.go | 22 + wrap_test.go | 24 + 12 files changed, 2355 insertions(+) create mode 100644 .hgignore create mode 100644 LICENSE create mode 100644 README.md create mode 100644 address.go create mode 100644 envelope.go create mode 100644 example_test.go create mode 100644 examples/dkim-proxy/main.go create mode 100644 protocol.go create mode 100644 smtpd.go create mode 100644 smtpd_test.go create mode 100644 wrap.go create mode 100644 wrap_test.go diff --git a/.hgignore b/.hgignore new file mode 100644 index 0000000..758d190 --- /dev/null +++ b/.hgignore @@ -0,0 +1,3 @@ +syntax: glob + +*.orig diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..f8c34e2 --- /dev/null +++ b/LICENSE @@ -0,0 +1,20 @@ +Copyright (c) 2014 Christian Joergensen (christian@technobabble.dk) + +MIT License + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..a428fea --- /dev/null +++ b/README.md @@ -0,0 +1,19 @@ +Go smtpd [![GoDoc](https://godoc.org/bitbucket.org/chrj/smtpd?status.png)](https://godoc.org/bitbucket.org/chrj/smtpd) +======== + +Package smtpd implements an SMTP server in golang. + +Features +-------- + +* STARTTLS (using `crypto/tls`) +* Authentication (PLAIN/LOGIN, only after STARTTLS) +* XCLIENT (for running behind a proxy) +* Connection, HELO, sender and recipient checks for rejecting e-mails using callbacks +* Configurable limits for: connection count, message size and recipient count +* Hands incoming e-mail off to a configured callback function + +Feedback +-------- + +If you end up using this package or have any feedback, I'd very much like to hear about it. You can reach me by [email](mailto:christian@technobabble.dk). diff --git a/address.go b/address.go new file mode 100644 index 0000000..68ae8e6 --- /dev/null +++ b/address.go @@ -0,0 +1,19 @@ +package smtpd + +import ( + "fmt" + "strings" +) + +func parseAddress(src string) (string, error) { + + if len(src) == 0 || src[0] != '<' || src[len(src)-1] != '>' { + return "", fmt.Errorf("Ill-formatted e-mail address: %s", src) + } + + if strings.Count(src, "@") > 1 { + return "", fmt.Errorf("Ill-formatted e-mail address: %s", src) + } + + return src[1 : len(src)-1], nil +} diff --git a/envelope.go b/envelope.go new file mode 100644 index 0000000..0e0f1bc --- /dev/null +++ b/envelope.go @@ -0,0 +1,54 @@ +package smtpd + +import ( + "crypto/tls" + "fmt" + "strings" + "time" +) + +// Envelope holds a message +type Envelope struct { + Sender string + Recipients []string + Data []byte +} + +// AddReceivedLine prepends a Received header to the Data +func (env *Envelope) AddReceivedLine(peer Peer) { + + tlsDetails := "" + + tlsVersions := map[uint16]string{ + tls.VersionSSL30: "SSL3.0", + tls.VersionTLS10: "TLS1.0", + tls.VersionTLS11: "TLS1.1", + tls.VersionTLS12: "TLS1.2", + } + + if peer.TLS != nil { + tlsDetails = fmt.Sprintf( + "\r\n\t(version=%s cipher=0x%x);", + tlsVersions[peer.TLS.Version], + peer.TLS.CipherSuite, + ) + } + + line := wrap([]byte(fmt.Sprintf( + "Received: from %s [%s] by %s with %s;%s\r\n\t%s\r\n", + peer.HeloName, + strings.Split(peer.Addr.String(), ":")[0], + peer.ServerName, + peer.Protocol, + tlsDetails, + time.Now().Format("Mon Jan 2 15:04:05 -0700 2006"), + ))) + + env.Data = append(env.Data, line...) + + // Move the new Received line up front + + copy(env.Data[len(line):], env.Data[0:len(env.Data)-len(line)]) + copy(env.Data, line) + +} diff --git a/example_test.go b/example_test.go new file mode 100644 index 0000000..bdfe3f5 --- /dev/null +++ b/example_test.go @@ -0,0 +1,46 @@ +package smtpd_test + +import ( + "bitbucket.org/chrj/smtpd" + "errors" + "net/smtp" + "strings" +) + +func ExampleServer() { + var server *smtpd.Server + + // No-op server. Accepts and discards + server = &smtpd.Server{} + server.ListenAndServe("127.0.0.1:10025") + + // Relay server. Accepts only from single IP address and forwards using the Gmail smtp + server = &smtpd.Server{ + + HeloChecker: func(peer smtpd.Peer, name string) error { + if !strings.HasPrefix(peer.Addr.String(), "42.42.42.42:") { + return errors.New("Denied") + } + return nil + }, + + Handler: func(peer smtpd.Peer, env smtpd.Envelope) error { + + return smtp.SendMail( + "smtp.gmail.com:587", + smtp.PlainAuth( + "", + "username@gmail.com", + "password", + "smtp.gmail.com", + ), + env.Sender, + env.Recipients, + env.Data, + ) + + }, + } + + server.ListenAndServe("127.0.0.1:10025") +} diff --git a/examples/dkim-proxy/main.go b/examples/dkim-proxy/main.go new file mode 100644 index 0000000..b9a8f45 --- /dev/null +++ b/examples/dkim-proxy/main.go @@ -0,0 +1,80 @@ +// Command dkim-proxy implements a simple SMTP proxy that DKIM signs incoming e-mail and relays to another SMTP server for delivery +package main + +import ( + "bytes" + "flag" + "io/ioutil" + "log" + "net/smtp" + + "bitbucket.org/chrj/smtpd" + "github.com/eaigner/dkim" +) + +var ( + welcomeMsg = flag.String("welcome", "DKIM-proxy ESMTP ready.", "Welcome message for SMTP session") + inAddr = flag.String("inaddr", "localhost:10025", "Address to listen for incoming SMTP on") + outAddr = flag.String("outaddr", "localhost:25", "Address to deliver outgoing SMTP on") + privKeyFile = flag.String("key", "", "Private key file.") + dkimS = flag.String("s", "default", "DKIM selector") + dkimD = flag.String("d", "", "DKIM domain") + + dkimConf dkim.Conf + privKey []byte +) + +func handler(peer smtpd.Peer, env smtpd.Envelope) error { + + d, err := dkim.New(dkimConf, privKey) + if err != nil { + log.Printf("DKIM error: %v", err) + return smtpd.Error{450, "Internal server error"} + } + + // The dkim package expects \r\n newlines, so replace to that + data, err := d.Sign(bytes.Replace(env.Data, []byte("\n"), []byte("\r\n"), -1)) + if err != nil { + log.Printf("DKIM signing error: %v", err) + return smtpd.Error{450, "Internal server error"} + } + + return smtp.SendMail( + *outAddr, + nil, + env.Sender, + env.Recipients, + data, + ) + +} + +func main() { + + flag.Parse() + + var err error + + dkimConf, err = dkim.NewConf(*dkimD, *dkimS) + if err != nil { + log.Fatalf("DKIM configuration error: %v", err) + } + + privKey, err = ioutil.ReadFile(*privKeyFile) + if err != nil { + log.Fatalf("Couldn't read private key: %v", err) + } + + _, err = dkim.New(dkimConf, privKey) + if err != nil { + log.Fatalf("DKIM error: %v", err) + } + + server := &smtpd.Server{ + WelcomeMessage: *welcomeMsg, + Handler: handler, + } + + server.ListenAndServe(*inAddr) + +} diff --git a/protocol.go b/protocol.go new file mode 100644 index 0000000..6dfea11 --- /dev/null +++ b/protocol.go @@ -0,0 +1,582 @@ +package smtpd + +import ( + "bufio" + "bytes" + "crypto/tls" + "encoding/base64" + "fmt" + "io" + "io/ioutil" + "net" + "net/textproto" + "strconv" + "strings" + "time" +) + +type command struct { + line string + action string + fields []string + params []string +} + +func parseLine(line string) (cmd command) { + + cmd.line = line + cmd.fields = strings.Fields(line) + + if len(cmd.fields) > 0 { + cmd.action = strings.ToUpper(cmd.fields[0]) + if len(cmd.fields) > 1 { + cmd.params = strings.Split(cmd.fields[1], ":") + } + } + + return + +} + +func (session *session) handle(line string) { + + cmd := parseLine(line) + + // Commands are dispatched to the appropriate handler functions. + // If a network error occurs during handling, the handler should + // just return and let the error be handled on the next read. + + switch cmd.action { + + case "HELO": + session.handleHELO(cmd) + return + + case "EHLO": + session.handleEHLO(cmd) + return + + case "MAIL": + session.handleMAIL(cmd) + return + + case "RCPT": + session.handleRCPT(cmd) + return + + case "STARTTLS": + session.handleSTARTTLS(cmd) + return + + case "DATA": + session.handleDATA(cmd) + return + + case "RSET": + session.handleRSET(cmd) + return + + case "NOOP": + session.handleNOOP(cmd) + return + + case "QUIT": + session.handleQUIT(cmd) + return + + case "AUTH": + session.handleAUTH(cmd) + return + + case "XCLIENT": + session.handleXCLIENT(cmd) + return + + } + + session.reply(502, "Unsupported command.") + +} + +func (session *session) handleHELO(cmd command) { + + if len(cmd.fields) < 2 { + session.reply(502, "Missing parameter") + return + } + + if session.peer.HeloName != "" { + // Reset envelope in case of duplicate HELO + session.reset() + } + + if session.server.HeloChecker != nil { + err := session.server.HeloChecker(session.peer, cmd.fields[1]) + if err != nil { + session.error(err) + return + } + } + + session.peer.HeloName = cmd.fields[1] + session.peer.Protocol = SMTP + session.reply(250, "Go ahead") + + return + +} + +func (session *session) handleEHLO(cmd command) { + + if len(cmd.fields) < 2 { + session.reply(502, "Missing parameter") + return + } + + if session.peer.HeloName != "" { + // Reset envelope in case of duplicate EHLO + session.reset() + } + + if session.server.HeloChecker != nil { + err := session.server.HeloChecker(session.peer, cmd.fields[1]) + if err != nil { + session.error(err) + return + } + } + + session.peer.HeloName = cmd.fields[1] + session.peer.Protocol = ESMTP + + fmt.Fprintf(session.writer, "250-%s\r\n", session.server.Hostname) + + extensions := session.extensions() + + if len(extensions) > 1 { + for _, ext := range extensions[:len(extensions)-1] { + fmt.Fprintf(session.writer, "250-%s\r\n", ext) + } + } + + session.reply(250, extensions[len(extensions)-1]) + + return + +} + +func (session *session) handleMAIL(cmd command) { + if len(cmd.params) != 2 || strings.ToUpper(cmd.params[0]) != "FROM" { + session.reply(502, "Syntax error") + return + } + + if session.peer.HeloName == "" { + session.reply(502, "Please introduce yourself first") + return + } + + if !session.tls && session.server.ForceTLS { + session.reply(502, "Please turn on TLS by issuing a STARTTLS command") + return + } + + if session.envelope != nil { + session.reply(502, "Duplicate MAIL") + return + } + + addr, err := parseAddress(cmd.params[1]) + + if err != nil { + session.reply(502, "Ill-formatted e-mail address") + return + } + + if session.server.SenderChecker != nil { + err = session.server.SenderChecker(session.peer, addr) + if err != nil { + session.error(err) + return + } + } + + session.envelope = &Envelope{ + Sender: addr, + } + + session.reply(250, "Go ahead") + + return + +} + +func (session *session) handleRCPT(cmd command) { + if len(cmd.params) != 2 || strings.ToUpper(cmd.params[0]) != "TO" { + session.reply(502, "Syntax error") + return + } + + if session.envelope == nil { + session.reply(502, "Missing MAIL FROM command.") + return + } + + if len(session.envelope.Recipients) >= session.server.MaxRecipients { + session.reply(452, "Too many recipients") + return + } + + addr, err := parseAddress(cmd.params[1]) + + if err != nil { + session.reply(502, "Ill-formatted e-mail address") + return + } + + if session.server.RecipientChecker != nil { + err = session.server.RecipientChecker(session.peer, addr) + if err != nil { + session.error(err) + return + } + } + + session.envelope.Recipients = append(session.envelope.Recipients, addr) + + session.reply(250, "Go ahead") + + return + +} + +func (session *session) handleSTARTTLS(cmd command) { + + if session.tls { + session.reply(502, "Already running in TLS") + return + } + + if session.server.TLSConfig == nil { + session.reply(502, "TLS not supported") + return + } + + tlsConn := tls.Server(session.conn, session.server.TLSConfig) + session.reply(220, "Go ahead") + + if err := tlsConn.Handshake(); err != nil { + session.reply(550, "Handshake error") + return + } + + // Reset envelope as a new EHLO/HELO is required after STARTTLS + session.reset() + + // Reset deadlines on the underlying connection before I replace it + // with a TLS connection + session.conn.SetDeadline(time.Time{}) + + // Replace connection with a TLS connection + session.conn = tlsConn + session.reader = bufio.NewReader(tlsConn) + session.writer = bufio.NewWriter(tlsConn) + session.scanner = bufio.NewScanner(session.reader) + session.tls = true + + // Save connection state on peer + state := tlsConn.ConnectionState() + session.peer.TLS = &state + + // Flush the connection to set new timeout deadlines + session.flush() + + return + +} + +func (session *session) handleDATA(cmd command) { + + if session.envelope == nil || len(session.envelope.Recipients) == 0 { + session.reply(502, "Missing RCPT TO command.") + return + } + + session.reply(354, "Go ahead. End your data with .") + session.conn.SetDeadline(time.Now().Add(session.server.DataTimeout)) + + data := &bytes.Buffer{} + reader := textproto.NewReader(session.reader).DotReader() + + _, err := io.CopyN(data, reader, int64(session.server.MaxMessageSize)) + + if err == io.EOF { + + // EOF was reached before MaxMessageSize + // Accept and deliver message + + session.envelope.Data = data.Bytes() + + if err := session.deliver(); err != nil { + session.error(err) + } else { + session.reply(250, "Thank you.") + } + + session.reset() + + } + + if err != nil { + // Network error, ignore + return + } + + // Discard the rest and report an error. + _, err = io.Copy(ioutil.Discard, reader) + + if err != nil { + // Network error, ignore + return + } + + session.reply(552, fmt.Sprintf( + "Message exceeded max message size of %d bytes", + session.server.MaxMessageSize, + )) + + session.reset() + + return + +} + +func (session *session) handleRSET(cmd command) { + session.reset() + session.reply(250, "Go ahead") + return +} + +func (session *session) handleNOOP(cmd command) { + session.reply(250, "Go ahead") + return +} + +func (session *session) handleQUIT(cmd command) { + session.reply(221, "OK, bye") + session.close() + return +} + +func (session *session) handleAUTH(cmd command) { + if len(cmd.fields) < 2 { + session.reply(502, "Invalid syntax.") + return + } + + if session.server.Authenticator == nil { + session.reply(502, "AUTH not supported.") + return + } + + if session.peer.HeloName == "" { + session.reply(502, "Please introduce yourself first.") + return + } + + if !session.tls { + session.reply(502, "Cannot AUTH in plain text mode. Use STARTTLS.") + return + } + + mechanism := strings.ToUpper(cmd.fields[1]) + + username := "" + password := "" + + switch mechanism { + + case "PLAIN": + + auth := "" + + if len(cmd.fields) < 3 { + session.reply(334, "Give me your credentials") + if !session.scanner.Scan() { + return + } + auth = session.scanner.Text() + } else { + auth = cmd.fields[2] + } + + data, err := base64.StdEncoding.DecodeString(auth) + + if err != nil { + session.reply(502, "Couldn't decode your credentials") + return + } + + parts := bytes.Split(data, []byte{0}) + + if len(parts) != 3 { + session.reply(502, "Couldn't decode your credentials") + return + } + + username = string(parts[1]) + password = string(parts[2]) + + case "LOGIN": + + session.reply(334, "VXNlcm5hbWU6") + + if !session.scanner.Scan() { + return + } + + byteUsername, err := base64.StdEncoding.DecodeString(session.scanner.Text()) + + if err != nil { + session.reply(502, "Couldn't decode your credentials") + return + } + + session.reply(334, "UGFzc3dvcmQ6") + + if !session.scanner.Scan() { + return + } + + bytePassword, err := base64.StdEncoding.DecodeString(session.scanner.Text()) + + if err != nil { + session.reply(502, "Couldn't decode your credentials") + return + } + + username = string(byteUsername) + password = string(bytePassword) + + default: + + session.reply(502, "Unknown authentication mechanism") + return + + } + + err := session.server.Authenticator(session.peer, username, password) + if err != nil { + session.error(err) + return + } + + session.peer.Username = username + session.peer.Password = password + + session.reply(235, "OK, you are now authenticated") + +} + +func (session *session) handleXCLIENT(cmd command) { + if len(cmd.fields) < 2 { + session.reply(502, "Invalid syntax.") + return + } + + if !session.server.EnableXCLIENT { + session.reply(550, "XCLIENT not enabled") + return + } + + var ( + newHeloName = "" + newAddr net.IP = nil + newTCPPort uint64 = 0 + newUsername = "" + newProto Protocol = "" + ) + + for _, item := range cmd.fields[1:] { + + parts := strings.Split(item, "=") + + if len(parts) != 2 { + session.reply(502, "Couldn't decode the command.") + return + } + + name := parts[0] + value := parts[1] + + switch name { + + case "NAME": + // Unused in smtpd package + continue + + case "HELO": + newHeloName = value + continue + + case "ADDR": + newAddr = net.ParseIP(value) + continue + + case "PORT": + var err error + newTCPPort, err = strconv.ParseUint(value, 10, 16) + if err != nil { + session.reply(502, "Couldn't decode the command.") + return + } + continue + + case "LOGIN": + newUsername = value + continue + + case "PROTO": + if value == "SMTP" { + newProto = SMTP + } else if value == "ESMTP" { + newProto = ESMTP + } + continue + + default: + session.reply(502, "Couldn't decode the command.") + return + } + + } + + tcpAddr, ok := session.peer.Addr.(*net.TCPAddr) + if !ok { + session.reply(502, "Unsupported network connection") + return + } + + if newHeloName != "" { + session.peer.HeloName = newHeloName + } + + if newAddr != nil { + tcpAddr.IP = newAddr + } + + if newTCPPort != 0 { + tcpAddr.Port = int(newTCPPort) + } + + if newUsername != "" { + session.peer.Username = newUsername + } + + if newProto != "" { + session.peer.Protocol = newProto + } + + session.welcome() + +} diff --git a/smtpd.go b/smtpd.go new file mode 100644 index 0000000..08f84e0 --- /dev/null +++ b/smtpd.go @@ -0,0 +1,324 @@ +// Package smtpd implements an SMTP server with support for STARTTLS, authentication (PLAIN/LOGIN), XCLIENT and optional restrictions on the different stages of the SMTP session. +package smtpd + +import ( + "bufio" + "crypto/tls" + "fmt" + "log" + "net" + "time" +) + +// Server defines the parameters for running the SMTP server +type Server struct { + Hostname string // Server hostname. (default: "localhost.localdomain") + WelcomeMessage string // Initial server banner. (default: " ESMTP ready.") + + ReadTimeout time.Duration // Socket timeout for read operations. (default: 60s) + WriteTimeout time.Duration // Socket timeout for write operations. (default: 60s) + DataTimeout time.Duration // Socket timeout for DATA command (default: 5m) + + MaxConnections int // Max concurrent connections, use -1 to disable. (default: 100) + MaxMessageSize int // Max message size in bytes. (default: 10240000) + MaxRecipients int // Max RCPT TO calls for each envelope. (default: 100) + + // New e-mails are handed off to this function. + // Can be left empty for a NOOP server. + // If an error is returned, it will be reported in the SMTP session. + Handler func(peer Peer, env Envelope) error + + // Enable various checks during the SMTP session. + // Can be left empty for no restrictions. + // If an error is returned, it will be reported in the SMTP session. + // Use the Error struct for access to error codes. + ConnectionChecker func(peer Peer) error // Called upon new connection. + HeloChecker func(peer Peer, name string) error // Called after HELO/EHLO. + SenderChecker func(peer Peer, addr string) error // Called after MAIL FROM. + RecipientChecker func(peer Peer, addr string) error // Called after each RCPT TO. + + // Enable PLAIN/LOGIN authentication, only available after STARTTLS. + // Can be left empty for no authentication support. + Authenticator func(peer Peer, username, password string) error + + EnableXCLIENT bool // Enable XCLIENT support (default: false) + + TLSConfig *tls.Config // Enable STARTTLS support. + ForceTLS bool // Force STARTTLS usage. +} + +// Protocol represents the protocol used in the SMTP session +type Protocol string + +const ( + SMTP Protocol = "SMTP" + ESMTP = "ESMTP" +) + +// Peer represents the client connecting to the server +type Peer struct { + HeloName string // Server name used in HELO/EHLO command + Username string // Username from authentication, if authenticated + Password string // Password from authentication, if authenticated + Protocol Protocol // Protocol used, SMTP or ESMTP + ServerName string // A copy of Server.Hostname + Addr net.Addr // Network address + TLS *tls.ConnectionState // TLS Connection details, if on TLS +} + +// Error represents an Error reported in the SMTP session. +type Error struct { + Code int // The integer error code + Message string // The error message +} + +// Error returns a string representation of the SMTP error +func (e Error) Error() string { return fmt.Sprintf("%d %s", e.Code, e.Message) } + +type session struct { + server *Server + + peer Peer + envelope *Envelope + + conn net.Conn + + reader *bufio.Reader + writer *bufio.Writer + scanner *bufio.Scanner + + tls bool +} + +func (srv *Server) newSession(c net.Conn) (s *session) { + + s = &session{ + server: srv, + conn: c, + reader: bufio.NewReader(c), + writer: bufio.NewWriter(c), + peer: Peer{ + Addr: c.RemoteAddr(), + ServerName: srv.Hostname, + }, + } + + s.scanner = bufio.NewScanner(s.reader) + + return + +} + +// ListenAndServe starts the SMTP server and listens on the address provided +func (srv *Server) ListenAndServe(addr string) error { + + srv.configureDefaults() + + l, err := net.Listen("tcp", addr) + if err != nil { + return err + } + + return srv.Serve(l) +} + +// Serve starts the SMTP server and listens on the Listener provided +func (srv *Server) Serve(l net.Listener) error { + + srv.configureDefaults() + + defer l.Close() + + var limiter chan struct{} + + if srv.MaxConnections > 0 { + limiter = make(chan struct{}, srv.MaxConnections) + } else { + limiter = nil + } + + for { + + conn, e := l.Accept() + if e != nil { + if ne, ok := e.(net.Error); ok && ne.Temporary() { + time.Sleep(time.Second) + continue + } + return e + } + + session := srv.newSession(conn) + + if limiter != nil { + go func() { + select { + case limiter <- struct{}{}: + session.serve() + <-limiter + default: + session.reject() + } + }() + } else { + go session.serve() + } + + } + +} + +func (srv *Server) configureDefaults() { + + if srv.MaxMessageSize == 0 { + srv.MaxMessageSize = 10240000 + } + + if srv.MaxConnections == 0 { + srv.MaxConnections = 100 + } + + if srv.MaxRecipients == 0 { + srv.MaxRecipients = 100 + } + + if srv.ReadTimeout == 0 { + srv.ReadTimeout = time.Second * 60 + } + + if srv.WriteTimeout == 0 { + srv.WriteTimeout = time.Second * 60 + } + + if srv.DataTimeout == 0 { + srv.DataTimeout = time.Minute * 5 + } + + if srv.ForceTLS && srv.TLSConfig == nil { + log.Fatal("Cannot use ForceTLS with no TLSConfig") + } + + if srv.Hostname == "" { + srv.Hostname = "localhost.localdomain" + } + + if srv.WelcomeMessage == "" { + srv.WelcomeMessage = fmt.Sprintf("%s ESMTP ready.", srv.Hostname) + } + +} + +func (session *session) serve() { + + defer session.close() + + session.welcome() + + for { + + for session.scanner.Scan() { + session.handle(session.scanner.Text()) + } + + err := session.scanner.Err() + + if err == bufio.ErrTooLong { + + session.reply(500, "Line too long") + + // Advance reader to the next newline + + session.reader.ReadString('\n') + session.scanner = bufio.NewScanner(session.reader) + + // Reset and have the client start over. + + session.reset() + + continue + } + + break + } + +} + +func (session *session) reject() { + session.reply(421, "Too busy. Try again later.") + session.close() +} + +func (session *session) reset() { + session.envelope = nil +} + +func (session *session) welcome() { + + if session.server.ConnectionChecker != nil { + err := session.server.ConnectionChecker(session.peer) + if err != nil { + session.error(err) + session.close() + return + } + } + + session.reply(220, session.server.WelcomeMessage) + +} + +func (session *session) reply(code int, message string) { + fmt.Fprintf(session.writer, "%d %s\r\n", code, message) + session.flush() +} + +func (session *session) flush() { + session.conn.SetWriteDeadline(time.Now().Add(session.server.WriteTimeout)) + session.writer.Flush() + session.conn.SetReadDeadline(time.Now().Add(session.server.ReadTimeout)) +} + +func (session *session) error(err error) { + if smtpdError, ok := err.(Error); ok { + session.reply(smtpdError.Code, smtpdError.Message) + } else { + session.reply(502, fmt.Sprintf("%s", err)) + } +} + +func (session *session) extensions() []string { + + extensions := []string{ + fmt.Sprintf("SIZE %d", session.server.MaxMessageSize), + "8BITMIME", + "PIPELINING", + } + + if session.server.EnableXCLIENT { + extensions = append(extensions, "XCLIENT") + } + + if session.server.TLSConfig != nil && !session.tls { + extensions = append(extensions, "STARTTLS") + } + + if session.server.Authenticator != nil && session.tls { + extensions = append(extensions, "AUTH PLAIN LOGIN") + } + + return extensions + +} + +func (session *session) deliver() error { + if session.server.Handler != nil { + return session.server.Handler(session.peer, *session.envelope) + } + return nil +} + +func (session *session) close() { + session.writer.Flush() + time.Sleep(200 * time.Millisecond) + session.conn.Close() +} diff --git a/smtpd_test.go b/smtpd_test.go new file mode 100644 index 0000000..3f6a2a0 --- /dev/null +++ b/smtpd_test.go @@ -0,0 +1,1162 @@ +package smtpd_test + +import ( + "bytes" + "crypto/tls" + "errors" + "fmt" + "net" + "net/smtp" + "net/textproto" + "strings" + "testing" + "time" + + "bitbucket.org/chrj/smtpd" +) + +var localhostCert = []byte(`-----BEGIN CERTIFICATE----- +MIIBkzCCAT+gAwIBAgIQf4LO8+QzcbXRHJUo6MvX7zALBgkqhkiG9w0BAQswEjEQ +MA4GA1UEChMHQWNtZSBDbzAeFw03MDAxMDEwMDAwMDBaFw04MTA1MjkxNjAwMDBa +MBIxEDAOBgNVBAoTB0FjbWUgQ28wXDANBgkqhkiG9w0BAQEFAANLADBIAkEAx2Uj +2nl0ESnMMrdUOwQnpnIPQzQBX9MIYT87VxhHzImOukWcq5DrmN1ZB//diyrgiCLv +D0udX3YXNHMn1Ki8awIDAQABo3MwcTAOBgNVHQ8BAf8EBAMCAKQwEwYDVR0lBAww +CgYIKwYBBQUHAwEwDwYDVR0TAQH/BAUwAwEB/zA5BgNVHREEMjAwggtleGFtcGxl +LmNvbYIJbG9jYWxob3N0hwR/AAABhxAAAAAAAAAAAAAAAAAAAAABMAsGCSqGSIb3 +DQEBCwNBAGcaB2Il0TIXFcJOdOLGPa6F8qZH1ZHBtVlCBnaJn4vZJGzID+V36Gn0 +hA1AYfGAaF0c43oQofvv+XqQlTe4a+M= +-----END CERTIFICATE-----`) + +var localhostKey = []byte(`-----BEGIN RSA PRIVATE KEY----- +MIIBPAIBAAJBAMdlI9p5dBEpzDK3VDsEJ6ZyD0M0AV/TCGE/O1cYR8yJjrpFnKuQ +65jdWQf/3Ysq4Igi7w9LnV92FzRzJ9SovGsCAwEAAQJAVaFw2VWJbAmIQUuMJ+Ar +6wZW2aSO5okpsyHFqSyrQQIcAj/QOq8P83F8J10IreFWNlBlywJU9c7IlJtn/lqq +AQIhAOxHXOxrKPxqTIdIcNnWye/HRQ+5VD54QQr1+M77+bEBAiEA2AmsNNqj2fKj +j2xk+4vnBSY0vrb4q/O3WZ46oorawWsCIQDWdpfzx/i11E6OZMR6FinJSNh4w0Gi +SkjPiCBE0BX+AQIhAI/TiLk7YmBkQG3ovSYW0vvDntPlXpKj08ovJFw4U0D3AiEA +lGjGna4oaauI0CWI6pG0wg4zklTnrDWK7w9h/S/T4e0= +-----END RSA PRIVATE KEY-----`) + +func cmd(c *textproto.Conn, expectedCode int, format string, args ...interface{}) error { + id, err := c.Cmd(format, args...) + if err != nil { + return err + } + + c.StartResponse(id) + _, _, err = c.ReadResponse(expectedCode) + c.EndResponse(id) + + return err +} + +func runserver(t *testing.T, server *smtpd.Server) (addr string, closer func()) { + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Listen failed: %v", err) + } + + go func() { + server.Serve(ln) + }() + + done := make(chan bool) + + go func() { + <-done + ln.Close() + }() + + return ln.Addr().String(), func() { + done <- true + } + +} + +func runsslserver(t *testing.T, server *smtpd.Server) (addr string, closer func()) { + + cert, err := tls.X509KeyPair(localhostCert, localhostKey) + if err != nil { + t.Fatalf("Cert load failed: %v", err) + } + + server.TLSConfig = &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + + return runserver(t, server) + +} + +func TestSMTP(t *testing.T) { + + addr, closer := runserver(t, &smtpd.Server{}) + defer closer() + + c, err := smtp.Dial(addr) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + if err := c.Hello("localhost"); err != nil { + t.Fatalf("HELO failed: %v", err) + } + + if supported, _ := c.Extension("AUTH"); supported { + t.Fatal("AUTH supported before TLS") + } + + if supported, _ := c.Extension("8BITMIME"); !supported { + t.Fatal("8BITMIME not supported") + } + + if supported, _ := c.Extension("STARTTLS"); supported { + t.Fatal("STARTTLS supported") + } + + if err := c.Mail("sender@example.org"); err != nil { + t.Fatalf("Mail failed: %v", err) + } + + if err := c.Rcpt("recipient@example.net"); err != nil { + t.Fatalf("Rcpt failed: %v", err) + } + + if err := c.Rcpt("recipient2@example.net"); err != nil { + t.Fatalf("Rcpt2 failed: %v", err) + } + + wc, err := c.Data() + if err != nil { + t.Fatalf("Data failed: %v", err) + } + + _, err = fmt.Fprintf(wc, "This is the email body") + if err != nil { + t.Fatalf("Data body failed: %v", err) + } + + err = wc.Close() + if err != nil { + t.Fatalf("Data close failed: %v", err) + } + + if err := c.Reset(); err != nil { + t.Fatalf("Reset failed: %v", err) + } + + if err := c.Verify("foobar@example.net"); err == nil { + t.Fatal("Unexpected support for VRFY") + } + + if err := cmd(c.Text, 250, "NOOP"); err != nil { + t.Fatalf("NOOP failed: %v", err) + } + + if err := c.Quit(); err != nil { + t.Fatalf("Quit failed: %v", err) + } +} + +func TestListenAndServe(t *testing.T) { + + addr, closer := runserver(t, &smtpd.Server{}) + closer() + + server := &smtpd.Server{} + + go func() { + server.ListenAndServe(addr) + }() + + time.Sleep(100 * time.Millisecond) + + c, err := smtp.Dial(addr) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + if err := c.Quit(); err != nil { + t.Fatalf("Quit failed: %v", err) + } + +} + +func TestSTARTTLS(t *testing.T) { + + addr, closer := runsslserver(t, &smtpd.Server{ + Authenticator: func(peer smtpd.Peer, username, password string) error { return nil }, + ForceTLS: true, + }) + + defer closer() + + c, err := smtp.Dial(addr) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + if supported, _ := c.Extension("AUTH"); supported { + t.Fatal("AUTH supported before TLS") + } + + if err := c.Mail("sender@example.org"); err == nil { + t.Fatal("Mail workded before TLS with ForceTLS") + } + + if err := cmd(c.Text, 220, "STARTTLS"); err != nil { + t.Fatalf("STARTTLS failed: %v", err) + } + + if err := cmd(c.Text, 250, "foobar"); err == nil { + t.Fatal("STARTTLS didn't fail with invalid handshake") + } + + if err := c.StartTLS(&tls.Config{InsecureSkipVerify: true}); err != nil { + t.Fatalf("STARTTLS failed: %v", err) + } + + if err := c.StartTLS(&tls.Config{InsecureSkipVerify: true}); err == nil { + t.Fatal("STARTTLS worked twice") + } + + if supported, _ := c.Extension("AUTH"); !supported { + t.Fatal("AUTH not supported after TLS") + } + + if _, mechs := c.Extension("AUTH"); !strings.Contains(mechs, "PLAIN") { + t.Fatal("PLAIN AUTH not supported after TLS") + } + + if _, mechs := c.Extension("AUTH"); !strings.Contains(mechs, "LOGIN") { + t.Fatal("LOGIN AUTH not supported after TLS") + } + + if err := c.Auth(smtp.PlainAuth("foo", "foo", "bar", "127.0.0.1")); err != nil { + t.Fatalf("Auth failed: %v", err) + } + + if err := c.Mail("sender@example.org"); err != nil { + t.Fatalf("Mail failed: %v", err) + } + + if err := c.Rcpt("recipient@example.net"); err != nil { + t.Fatalf("Rcpt failed: %v", err) + } + + if err := c.Rcpt("recipient2@example.net"); err != nil { + t.Fatalf("Rcpt2 failed: %v", err) + } + + wc, err := c.Data() + if err != nil { + t.Fatalf("Data failed: %v", err) + } + + _, err = fmt.Fprintf(wc, "This is the email body") + if err != nil { + t.Fatalf("Data body failed: %v", err) + } + + err = wc.Close() + if err != nil { + t.Fatalf("Data close failed: %v", err) + } + + if err := c.Quit(); err != nil { + t.Fatalf("Quit failed: %v", err) + } +} + +func TestAuthRejection(t *testing.T) { + + addr, closer := runsslserver(t, &smtpd.Server{ + Authenticator: func(peer smtpd.Peer, username, password string) error { + return smtpd.Error{Code: 550, Message: "Denied"} + }, + ForceTLS: true, + }) + + defer closer() + + c, err := smtp.Dial(addr) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + if err := c.StartTLS(&tls.Config{InsecureSkipVerify: true}); err != nil { + t.Fatalf("STARTTLS failed: %v", err) + } + + if err := c.Auth(smtp.PlainAuth("foo", "foo", "bar", "127.0.0.1")); err == nil { + t.Fatal("Auth worked despite rejection") + } + +} + +func TestAuthNotSupported(t *testing.T) { + + addr, closer := runsslserver(t, &smtpd.Server{ + ForceTLS: true, + }) + + defer closer() + + c, err := smtp.Dial(addr) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + if err := c.StartTLS(&tls.Config{InsecureSkipVerify: true}); err != nil { + t.Fatalf("STARTTLS failed: %v", err) + } + + if err := c.Auth(smtp.PlainAuth("foo", "foo", "bar", "127.0.0.1")); err == nil { + t.Fatal("Auth worked despite no authenticator") + } + +} + +func TestConnectionCheck(t *testing.T) { + + addr, closer := runserver(t, &smtpd.Server{ + ConnectionChecker: func(peer smtpd.Peer) error { + return smtpd.Error{Code: 552, Message: "Denied"} + }, + }) + + defer closer() + + if _, err := smtp.Dial(addr); err == nil { + t.Fatal("Dial succeeded despite ConnectionCheck") + } + +} + +func TestConnectionCheckSimpleError(t *testing.T) { + + addr, closer := runserver(t, &smtpd.Server{ + ConnectionChecker: func(peer smtpd.Peer) error { + return errors.New("Denied") + }, + }) + + defer closer() + + if _, err := smtp.Dial(addr); err == nil { + t.Fatal("Dial succeeded despite ConnectionCheck") + } + +} + +func TestHELOCheck(t *testing.T) { + + addr, closer := runserver(t, &smtpd.Server{ + HeloChecker: func(peer smtpd.Peer, name string) error { + if name != "foobar.local" { + t.Fatal("Wrong HELO name") + } + return smtpd.Error{Code: 552, Message: "Denied"} + }, + }) + + defer closer() + + c, err := smtp.Dial(addr) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + if err := c.Hello("foobar.local"); err == nil { + t.Fatal("Unexpected HELO success") + } + +} + +func TestSenderCheck(t *testing.T) { + + addr, closer := runserver(t, &smtpd.Server{ + SenderChecker: func(peer smtpd.Peer, addr string) error { + return smtpd.Error{Code: 552, Message: "Denied"} + }, + }) + + defer closer() + + c, err := smtp.Dial(addr) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + if err := c.Mail("sender@example.org"); err == nil { + t.Fatal("Unexpected MAIL success") + } + +} + +func TestRecipientCheck(t *testing.T) { + + addr, closer := runserver(t, &smtpd.Server{ + RecipientChecker: func(peer smtpd.Peer, addr string) error { + return smtpd.Error{Code: 552, Message: "Denied"} + }, + }) + + defer closer() + + c, err := smtp.Dial(addr) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + if err := c.Mail("sender@example.org"); err != nil { + t.Fatalf("Mail failed: %v", err) + } + + if err := c.Rcpt("recipient@example.net"); err == nil { + t.Fatal("Unexpected RCPT success") + } + +} + +func TestMaxMessageSize(t *testing.T) { + + addr, closer := runserver(t, &smtpd.Server{ + MaxMessageSize: 5, + }) + + defer closer() + + c, err := smtp.Dial(addr) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + if err := c.Mail("sender@example.org"); err != nil { + t.Fatalf("MAIL failed: %v", err) + } + + if err := c.Rcpt("recipient@example.net"); err != nil { + t.Fatalf("RCPT failed: %v", err) + } + + wc, err := c.Data() + if err != nil { + t.Fatalf("Data failed: %v", err) + } + + _, err = fmt.Fprintf(wc, "This is the email body") + if err != nil { + t.Fatalf("Data body failed: %v", err) + } + + err = wc.Close() + if err == nil { + t.Fatal("Allowed message larger than 5 bytes to pass.") + } + + if err := c.Quit(); err != nil { + t.Fatalf("QUIT failed: %v", err) + } + +} + +func TestHandler(t *testing.T) { + + addr, closer := runserver(t, &smtpd.Server{ + Handler: func(peer smtpd.Peer, env smtpd.Envelope) error { + if env.Sender != "sender@example.org" { + t.Fatalf("Unknown sender: %v", env.Sender) + } + if len(env.Recipients) != 1 { + t.Fatalf("Too many recipients: %d", len(env.Recipients)) + } + if env.Recipients[0] != "recipient@example.net" { + t.Fatalf("Unknown recipient: %v", env.Recipients[0]) + } + if string(env.Data) != "This is the email body\n" { + t.Fatalf("Wrong message body: %v", string(env.Data)) + } + return nil + }, + }) + + defer closer() + + c, err := smtp.Dial(addr) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + if err := c.Mail("sender@example.org"); err != nil { + t.Fatalf("MAIL failed: %v", err) + } + + if err := c.Rcpt("recipient@example.net"); err != nil { + t.Fatalf("RCPT failed: %v", err) + } + + wc, err := c.Data() + if err != nil { + t.Fatalf("Data failed: %v", err) + } + + _, err = fmt.Fprintf(wc, "This is the email body") + if err != nil { + t.Fatalf("Data body failed: %v", err) + } + + err = wc.Close() + if err != nil { + t.Fatalf("Data close failed: %v", err) + } + + if err := c.Quit(); err != nil { + t.Fatalf("QUIT failed: %v", err) + } + +} + +func TestRejectHandler(t *testing.T) { + + addr, closer := runserver(t, &smtpd.Server{ + Handler: func(peer smtpd.Peer, env smtpd.Envelope) error { + return smtpd.Error{Code: 550, Message: "Rejected"} + }, + }) + + defer closer() + + c, err := smtp.Dial(addr) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + if err := c.Mail("sender@example.org"); err != nil { + t.Fatalf("MAIL failed: %v", err) + } + + if err := c.Rcpt("recipient@example.net"); err != nil { + t.Fatalf("RCPT failed: %v", err) + } + + wc, err := c.Data() + if err != nil { + t.Fatalf("Data failed: %v", err) + } + + _, err = fmt.Fprintf(wc, "This is the email body") + if err != nil { + t.Fatalf("Data body failed: %v", err) + } + + err = wc.Close() + if err == nil { + t.Fatal("Unexpected accept of data") + } + + if err := c.Quit(); err != nil { + t.Fatalf("QUIT failed: %v", err) + } + +} + +func TestMaxConnections(t *testing.T) { + + addr, closer := runserver(t, &smtpd.Server{ + MaxConnections: 1, + }) + + defer closer() + + c1, err := smtp.Dial(addr) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + _, err = smtp.Dial(addr) + if err == nil { + t.Fatal("Dial succeeded despite MaxConnections = 1") + } + + c1.Close() +} + +func TestNoMaxConnections(t *testing.T) { + + addr, closer := runserver(t, &smtpd.Server{ + MaxConnections: -1, + }) + + defer closer() + + c1, err := smtp.Dial(addr) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + c1.Close() +} + +func TestMaxRecipients(t *testing.T) { + + addr, closer := runserver(t, &smtpd.Server{ + MaxRecipients: 1, + }) + + defer closer() + + c, err := smtp.Dial(addr) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + if err := c.Mail("sender@example.org"); err != nil { + t.Fatalf("MAIL failed: %v", err) + } + + if err := c.Rcpt("recipient@example.net"); err != nil { + t.Fatalf("RCPT failed: %v", err) + } + + if err := c.Rcpt("recipient@example.net"); err == nil { + t.Fatal("RCPT succeeded despite MaxRecipients = 1") + } + + if err := c.Quit(); err != nil { + t.Fatalf("QUIT failed: %v", err) + } + +} + +func TestInvalidHelo(t *testing.T) { + + addr, closer := runserver(t, &smtpd.Server{}) + + defer closer() + + c, err := smtp.Dial(addr) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + if err := c.Hello(""); err == nil { + t.Fatal("Unexpected HELO success") + } + +} + +func TestInvalidSender(t *testing.T) { + + addr, closer := runserver(t, &smtpd.Server{}) + + defer closer() + + c, err := smtp.Dial(addr) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + if err := c.Mail("invalid@@example.org"); err == nil { + t.Fatal("Unexpected MAIL success") + } + +} + +func TestInvalidRecipient(t *testing.T) { + + addr, closer := runserver(t, &smtpd.Server{}) + + defer closer() + + c, err := smtp.Dial(addr) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + if err := c.Mail("sender@example.org"); err != nil { + t.Fatalf("Mail failed: %v", err) + } + + if err := c.Rcpt("invalid@@example.org"); err == nil { + t.Fatal("Unexpected RCPT success") + } + +} + +func TestRCPTbeforeMAIL(t *testing.T) { + + addr, closer := runserver(t, &smtpd.Server{}) + + defer closer() + + c, err := smtp.Dial(addr) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + if err := c.Rcpt("recipient@example.net"); err == nil { + t.Fatal("Unexpected RCPT success") + } + +} + +func TestDATAbeforeRCPT(t *testing.T) { + + addr, closer := runserver(t, &smtpd.Server{}) + + defer closer() + + c, err := smtp.Dial(addr) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + if err := c.Mail("sender@example.org"); err != nil { + t.Fatalf("MAIL failed: %v", err) + } + + if _, err := c.Data(); err == nil { + t.Fatal("Data accepted despite no recipients") + } + + if err := c.Quit(); err != nil { + t.Fatalf("QUIT failed: %v", err) + } + +} + +func TestInterruptedDATA(t *testing.T) { + + addr, closer := runserver(t, &smtpd.Server{ + Handler: func(peer smtpd.Peer, env smtpd.Envelope) error { + t.Fatal("Accepted DATA despite disconnection") + return nil + }, + }) + + defer closer() + + c, err := smtp.Dial(addr) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + if err := c.Mail("sender@example.org"); err != nil { + t.Fatalf("MAIL failed: %v", err) + } + + if err := c.Rcpt("recipient@example.net"); err != nil { + t.Fatalf("RCPT failed: %v", err) + } + + wc, err := c.Data() + if err != nil { + t.Fatalf("Data failed: %v", err) + } + + _, err = fmt.Fprintf(wc, "This is the email body") + if err != nil { + t.Fatalf("Data body failed: %v", err) + } + + c.Close() + +} + +func TestTimeoutClose(t *testing.T) { + + addr, closer := runserver(t, &smtpd.Server{ + MaxConnections: 1, + ReadTimeout: time.Second, + WriteTimeout: time.Second, + }) + + defer closer() + + c1, err := smtp.Dial(addr) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + time.Sleep(time.Second * 2) + + c2, err := smtp.Dial(addr) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + if err := c1.Mail("sender@example.org"); err == nil { + t.Fatal("MAIL succeeded despite being timed out.") + } + + if err := c2.Mail("sender@example.org"); err != nil { + t.Fatalf("MAIL failed: %v", err) + } + + if err := c2.Quit(); err != nil { + t.Fatalf("Quit failed: %v", err) + } + + c2.Close() +} + +func TestTLSTimeout(t *testing.T) { + + addr, closer := runsslserver(t, &smtpd.Server{ + ReadTimeout: time.Second * 2, + WriteTimeout: time.Second * 2, + }) + + defer closer() + + c, err := smtp.Dial(addr) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + if err := c.StartTLS(&tls.Config{InsecureSkipVerify: true}); err != nil { + t.Fatalf("STARTTLS failed: %v", err) + } + + time.Sleep(time.Second) + + if err := c.Mail("sender@example.org"); err != nil { + t.Fatalf("MAIL failed: %v", err) + } + + time.Sleep(time.Second) + + if err := c.Rcpt("recipient@example.net"); err != nil { + t.Fatalf("RCPT failed: %v", err) + } + + time.Sleep(time.Second) + + if err := c.Rcpt("recipient@example.net"); err != nil { + t.Fatalf("RCPT failed: %v", err) + } + + time.Sleep(time.Second) + + if err := c.Quit(); err != nil { + t.Fatalf("Quit failed: %v", err) + } + +} + +func TestLongLine(t *testing.T) { + + addr, closer := runserver(t, &smtpd.Server{}) + + defer closer() + + c, err := smtp.Dial(addr) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + if err := c.Mail(fmt.Sprintf("%s@example.org", strings.Repeat("x", 65*1024))); err == nil { + t.Fatalf("MAIL failed: %v", err) + } + + if err := c.Quit(); err != nil { + t.Fatalf("Quit failed: %v", err) + } + +} + +func TestXCLIENT(t *testing.T) { + + addr, closer := runserver(t, &smtpd.Server{ + EnableXCLIENT: true, + SenderChecker: func(peer smtpd.Peer, addr string) error { + if peer.HeloName != "new.example.net" { + t.Fatalf("Didn't override HELO name: %v", peer.HeloName) + } + if peer.Addr.String() != "42.42.42.42:4242" { + t.Fatalf("Didn't override IP/Port: %v", peer.Addr) + } + if peer.Username != "newusername" { + t.Fatalf("Didn't override username: %v", peer.Username) + } + if peer.Protocol != smtpd.SMTP { + t.Fatalf("Didn't override protocol: %v", peer.Protocol) + } + return nil + }, + }) + + defer closer() + + c, err := smtp.Dial(addr) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + if supported, _ := c.Extension("XCLIENT"); !supported { + t.Fatal("XCLIENT not supported") + } + + err = cmd(c.Text, 220, "XCLIENT NAME=ignored ADDR=42.42.42.42 PORT=4242 PROTO=SMTP HELO=new.example.net LOGIN=newusername") + if err != nil { + t.Fatalf("XCLIENT failed: %v", err) + } + + if err := c.Mail("sender@example.org"); err != nil { + t.Fatalf("Mail failed: %v", err) + } + + if err := c.Rcpt("recipient@example.net"); err != nil { + t.Fatalf("Rcpt failed: %v", err) + } + + if err := c.Rcpt("recipient2@example.net"); err != nil { + t.Fatalf("Rcpt2 failed: %v", err) + } + + wc, err := c.Data() + if err != nil { + t.Fatalf("Data failed: %v", err) + } + + _, err = fmt.Fprintf(wc, "This is the email body") + if err != nil { + t.Fatalf("Data body failed: %v", err) + } + + err = wc.Close() + if err != nil { + t.Fatalf("Data close failed: %v", err) + } + + if err := c.Quit(); err != nil { + t.Fatalf("Quit failed: %v", err) + } + +} + +func TestEnvelopeReceived(t *testing.T) { + + addr, closer := runsslserver(t, &smtpd.Server{ + Hostname: "foobar.example.net", + Handler: func(peer smtpd.Peer, env smtpd.Envelope) error { + env.AddReceivedLine(peer) + if !bytes.HasPrefix(env.Data, []byte("Received: from localhost [127.0.0.1] by foobar.example.net with ESMTP;")) { + t.Fatal("Wrong received line.") + } + return nil + }, + ForceTLS: true, + }) + + defer closer() + + c, err := smtp.Dial(addr) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + if err := c.StartTLS(&tls.Config{InsecureSkipVerify: true}); err != nil { + t.Fatalf("STARTTLS failed: %v", err) + } + + if err := c.Mail("sender@example.org"); err != nil { + t.Fatalf("MAIL failed: %v", err) + } + + if err := c.Rcpt("recipient@example.net"); err != nil { + t.Fatalf("RCPT failed: %v", err) + } + + wc, err := c.Data() + if err != nil { + t.Fatalf("Data failed: %v", err) + } + + _, err = fmt.Fprintf(wc, "This is the email body") + if err != nil { + t.Fatalf("Data body failed: %v", err) + } + + err = wc.Close() + if err != nil { + t.Fatalf("Data close failed: %v", err) + } + + if err := c.Quit(); err != nil { + t.Fatalf("QUIT failed: %v", err) + } + +} + +func TestHELO(t *testing.T) { + + addr, closer := runserver(t, &smtpd.Server{}) + + defer closer() + + c, err := smtp.Dial(addr) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + if err := cmd(c.Text, 502, "MAIL FROM:"); err != nil { + t.Fatalf("MAIL didn't fail: %v", err) + } + + if err := cmd(c.Text, 250, "HELO localhost"); err != nil { + t.Fatalf("HELO failed: %v", err) + } + + if err := cmd(c.Text, 502, "MAIL FROM:christian@technobabble.dk"); err != nil { + t.Fatalf("MAIL didn't fail: %v", err) + } + + if err := cmd(c.Text, 250, "HELO localhost"); err != nil { + t.Fatalf("HELO failed: %v", err) + } + + if err := c.Quit(); err != nil { + t.Fatalf("Quit failed: %v", err) + } + +} + +func TestLOGINAuth(t *testing.T) { + + addr, closer := runsslserver(t, &smtpd.Server{ + Authenticator: func(peer smtpd.Peer, username, password string) error { return nil }, + }) + + defer closer() + + c, err := smtp.Dial(addr) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + if err := c.StartTLS(&tls.Config{InsecureSkipVerify: true}); err != nil { + t.Fatalf("STARTTLS failed: %v", err) + } + + if err := cmd(c.Text, 334, "AUTH LOGIN"); err != nil { + t.Fatalf("AUTH didn't work: %v", err) + } + + if err := cmd(c.Text, 502, "foo"); err != nil { + t.Fatalf("AUTH didn't fail: %v", err) + } + + if err := cmd(c.Text, 334, "AUTH LOGIN"); err != nil { + t.Fatalf("AUTH didn't work: %v", err) + } + + if err := cmd(c.Text, 334, "Zm9v"); err != nil { + t.Fatalf("AUTH didn't work: %v", err) + } + + if err := cmd(c.Text, 502, "foo"); err != nil { + t.Fatalf("AUTH didn't fail: %v", err) + } + + if err := cmd(c.Text, 334, "AUTH LOGIN"); err != nil { + t.Fatalf("AUTH didn't work: %v", err) + } + + if err := cmd(c.Text, 334, "Zm9v"); err != nil { + t.Fatalf("AUTH didn't work: %v", err) + } + + if err := cmd(c.Text, 235, "Zm9v"); err != nil { + t.Fatalf("AUTH didn't work: %v", err) + } + + if err := c.Quit(); err != nil { + t.Fatalf("Quit failed: %v", err) + } + +} + +func TestErrors(t *testing.T) { + + cert, err := tls.X509KeyPair(localhostCert, localhostKey) + if err != nil { + t.Fatalf("Cert load failed: %v", err) + } + + server := &smtpd.Server{ + Authenticator: func(peer smtpd.Peer, username, password string) error { return nil }, + } + + addr, closer := runserver(t, server) + + defer closer() + + c, err := smtp.Dial(addr) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + if err := cmd(c.Text, 502, "AUTH PLAIN foobar"); err != nil { + t.Fatalf("AUTH didn't fail: %v", err) + } + + if err := c.Hello("localhost"); err != nil { + t.Fatalf("HELO failed: %v", err) + } + + if err := cmd(c.Text, 502, "AUTH PLAIN foobar"); err != nil { + t.Fatalf("AUTH didn't fail: %v", err) + } + + if err := cmd(c.Text, 502, "MAIL FROM:christian@technobabble.dk"); err != nil { + t.Fatalf("MAIL didn't fail: %v", err) + } + + if err := c.Mail("sender@example.org"); err != nil { + t.Fatalf("MAIL failed: %v", err) + } + + if err := c.Mail("sender@example.org"); err == nil { + t.Fatal("Duplicate MAIL didn't fail") + } + + if err := cmd(c.Text, 502, "STARTTLS"); err != nil { + t.Fatalf("STARTTLS didn't fail: %v", err) + } + + server.TLSConfig = &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + + if err := c.StartTLS(&tls.Config{InsecureSkipVerify: true}); err != nil { + t.Fatalf("STARTTLS failed: %v", err) + } + + if err := cmd(c.Text, 502, "AUTH UNKNOWN"); err != nil { + t.Fatalf("AUTH didn't fail: %v", err) + } + + if err := cmd(c.Text, 502, "AUTH PLAIN foobar"); err != nil { + t.Fatalf("AUTH didn't fail: %v", err) + } + + if err := cmd(c.Text, 502, "AUTH PLAIN Zm9vAGJhcg=="); err != nil { + t.Fatalf("AUTH didn't fail: %v", err) + } + + if err := cmd(c.Text, 334, "AUTH PLAIN"); err != nil { + t.Fatalf("AUTH didn't work: %v", err) + } + + if err := cmd(c.Text, 235, "Zm9vAGJhcgBxdXV4"); err != nil { + t.Fatalf("AUTH didn't work: %v", err) + } + + if err := c.Quit(); err != nil { + t.Fatalf("Quit failed: %v", err) + } + +} diff --git a/wrap.go b/wrap.go new file mode 100644 index 0000000..91a6cce --- /dev/null +++ b/wrap.go @@ -0,0 +1,22 @@ +package smtpd + +// Wrap a byte slice paragraph for use in SMTP header +func wrap(sl []byte) []byte { + length := 0 + for i := 0; i < len(sl); i++ { + if length > 76 && sl[i] == ' ' { + sl = append(sl, 0, 0) + copy(sl[i+2:], sl[i:]) + sl[i] = '\r' + sl[i+1] = '\n' + sl[i+2] = '\t' + i += 2 + length = 0 + } + if sl[i] == '\n' { + length = 0 + } + length++ + } + return sl +} diff --git a/wrap_test.go b/wrap_test.go new file mode 100644 index 0000000..a8b65de --- /dev/null +++ b/wrap_test.go @@ -0,0 +1,24 @@ +package smtpd + +import ( + "testing" +) + +func TestWrap(t *testing.T) { + + cases := map[string]string{ + "foobar": "foobar", + "foobar quux": "foobar quux", + "foobar\r\n": "foobar\r\n", + "foobar\r\nquux": "foobar\r\nquux", + "foobar quux foobar quux foobar quux foobar quux foobar quux foobar quux foobar quux foobar quux": "foobar quux foobar quux foobar quux foobar quux foobar quux foobar quux foobar\r\n\tquux foobar quux", + "foobar quux foobar quux foobar quux foobar quux foobar quux foobar\r\n\tquux foobar quux foobar quux": "foobar quux foobar quux foobar quux foobar quux foobar quux foobar\r\n\tquux foobar quux foobar quux", + } + + for k, v := range cases { + if string(wrap([]byte(k))) != v { + t.Fatal("Didn't wrap correctly.") + } + } + +}