diff --git a/address.go b/address.go index 76b7507..84bf735 100644 --- a/address.go +++ b/address.go @@ -6,8 +6,14 @@ import ( ) func parseAddress(src string) (string, error) { - if src[0] != '<' || src[len(src)-1] != '>' || strings.Count(src, "@") != 1 { + + if src[0] != '<' || src[len(src)-1] != '>' { return "", fmt.Errorf("Ill-formatted e-mail address: %s", src) } + + if strings.Count(src, "@") > 1 { + return "", fmt.Errorf("Ill-formatted e-mail address: %s", src) + } + return src[1 : len(src)-1], nil } diff --git a/example_test.go b/example_test.go index f0c669d..a143890 100644 --- a/example_test.go +++ b/example_test.go @@ -27,6 +27,7 @@ func ExampleServer() { }, Handler: func(peer smtpd.Peer, env smtpd.Envelope) error { + return smtp.SendMail( "smtp.gmail.com:587", smtp.PlainAuth( @@ -39,6 +40,7 @@ func ExampleServer() { env.Recipients, env.Data, ) + }, } diff --git a/protocol.go b/protocol.go index 6465b29..4a17c8e 100644 --- a/protocol.go +++ b/protocol.go @@ -6,7 +6,11 @@ import ( "crypto/tls" "encoding/base64" "fmt" + "io" + "io/ioutil" + "net/textproto" "strings" + "time" ) type command struct { @@ -178,6 +182,11 @@ func (session *session) handleRCPT(cmd command) { return } + if len(session.envelope.Recipients) >= session.server.MaxRecipients { + session.reply(550, "Too many recipients") + return + } + addr, err := parseAddress(cmd.params[1]) if err != nil { @@ -219,12 +228,23 @@ func (session *session) handleSTARTTLS(cmd command) { return } + // Reset HeloName as a new EHLO/HELO is required after STARTTLS + session.peer.HeloName = "" + + // Reset deadlines on the underlying connection before I replace it + // with a TLS connection + session.conn.SetDeadline(time.Time{}) + + // Replace connection with a TLS connection session.conn = tlsConn session.reader = bufio.NewReader(tlsConn) session.writer = bufio.NewWriter(tlsConn) session.scanner = bufio.NewScanner(session.reader) session.tls = true + // Flush the connection to set new timeout deadlines + session.flush() + return } @@ -237,46 +257,47 @@ func (session *session) handleDATA(cmd command) { } session.reply(354, "Go ahead. End your data with .") + session.conn.SetDeadline(time.Now().Add(session.server.DataTimeout)) data := &bytes.Buffer{} - done := false + reader := textproto.NewReader(session.reader).DotReader() - for session.scanner.Scan() { + _, err := io.CopyN(data, reader, int64(session.server.MaxMessageSize)) - line := session.scanner.Text() + if err == io.EOF { - if line == "." { - done = true - break + // EOF was reached before MaxMessageSize + // Accept and deliver message + + session.envelope.Data = data.Bytes() + if err := session.deliver(); err != nil { + session.error(err) + } else { + session.reply(250, "Thank you.") } - data.Write([]byte(line)) - data.Write([]byte("\r\n")) - } - if !done { - return - } - - if data.Len() > session.server.MaxMessageSize { - session.reply(550, fmt.Sprintf( - "Message exceeded max message size of %d bytes", - session.server.MaxMessageSize, - )) - return - } - - session.envelope.Data = data.Bytes() - - err := session.deliver() - if err != nil { - session.error(err) - } else { - session.reply(250, "Thank you.") + // Network error, ignore + return } + // Discard the rest and report an error. + _, err = io.Copy(ioutil.Discard, reader) + + if err != nil { + // Network error, ignore + return + } + + session.reply(552, fmt.Sprintf( + "Message exceeded max message size of %d bytes", + session.server.MaxMessageSize, + )) + + return + } func (session *session) handleRSET(cmd command) { @@ -298,7 +319,7 @@ func (session *session) handleQUIT(cmd command) { func (session *session) handleAUTH(cmd command) { - if session.server.Authenticator == nil { + if session.server.Authenticator == nil { session.reply(502, "AUTH not supported.") return } diff --git a/smtpd.go b/smtpd.go index de6d39f..aae2ddb 100644 --- a/smtpd.go +++ b/smtpd.go @@ -18,9 +18,11 @@ type Server struct { ReadTimeout time.Duration // Socket timeout for read operations. (default: 60s) WriteTimeout time.Duration // Socket timeout for write operations. (default: 60s) + DataTimeout time.Duration // Socket timeout for DATA command (default: 5m) - MaxMessageSize int // Max message size in bytes. (default: 10240000) MaxConnections int // Max concurrent connections, use -1 to disable. (default: 100) + MaxMessageSize int // Max message size in bytes. (default: 10240000) + MaxRecipients int // Max RCPT TO calls for each envelope. (default: 100) // New e-mails are handed off to this function. // Can be left empty for a NOOP server. @@ -168,6 +170,10 @@ func (srv *Server) configureDefaults() { srv.MaxConnections = 100 } + if srv.MaxRecipients == 0 { + srv.MaxRecipients = 100 + } + if srv.ReadTimeout == 0 { srv.ReadTimeout = time.Second * 60 } @@ -176,6 +182,10 @@ func (srv *Server) configureDefaults() { srv.WriteTimeout = time.Second * 60 } + if srv.DataTimeout == 0 { + srv.DataTimeout = time.Minute * 5 + } + if srv.ForceTLS && srv.TLSConfig == nil { log.Fatal("Cannot use ForceTLS with no TLSConfig") } @@ -211,7 +221,7 @@ func (session *session) serve() { } func (session *session) reject() { - session.reply(450, "Too busy. Try again later.") + session.reply(421, "Too busy. Try again later.") session.close() } @@ -231,14 +241,14 @@ func (session *session) welcome() { } func (session *session) reply(code int, message string) { - fmt.Fprintf(session.writer, "%d %s\r\n", code, message) + session.flush() +} +func (session *session) flush() { session.conn.SetWriteDeadline(time.Now().Add(session.server.WriteTimeout)) session.writer.Flush() - session.conn.SetReadDeadline(time.Now().Add(session.server.ReadTimeout)) - } func (session *session) error(err error) { @@ -254,6 +264,7 @@ func (session *session) extensions() []string { extensions := []string{ fmt.Sprintf("SIZE %d", session.server.MaxMessageSize), "8BITMIME", + "PIPELINING", } if session.server.TLSConfig != nil && !session.tls { diff --git a/smtpd_test.go b/smtpd_test.go index 76fb440..17e0539 100644 --- a/smtpd_test.go +++ b/smtpd_test.go @@ -533,8 +533,8 @@ func TestHandler(t *testing.T) { if env.Recipients[0] != "recipient@example.net" { t.Fatalf("Unknown recipient: %v", env.Recipients[0]) } - if string(env.Data) != "This is the email body\r\n" { - t.Fatalf("Wrong message body: %v", env.Data) + if string(env.Data) != "This is the email body\n" { + t.Fatalf("Wrong message body: %v", string(env.Data)) } return nil }, @@ -686,6 +686,46 @@ func TestNoMaxConnections(t *testing.T) { c1.Close() } +func TestMaxRecipients(t *testing.T) { + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Listen failed: %v", err) + } + + defer ln.Close() + + server := &smtpd.Server{ + MaxRecipients: 1, + } + + go func() { + server.Serve(ln) + }() + + c, err := smtp.Dial(ln.Addr().String()) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + if err := c.Mail("sender@example.org"); err != nil { + t.Fatalf("MAIL failed: %v", err) + } + + if err := c.Rcpt("recipient@example.net"); err != nil { + t.Fatalf("RCPT failed: %v", err) + } + + if err := c.Rcpt("recipient@example.net"); err == nil { + t.Fatal("RCPT succeeded despite MaxRecipients = 1") + } + + if err := c.Quit(); err != nil { + t.Fatalf("QUIT failed: %v", err) + } + +} + func TestInvalidHelo(t *testing.T) { ln, err := net.Listen("tcp", "127.0.0.1:0") @@ -732,7 +772,7 @@ func TestInvalidSender(t *testing.T) { t.Fatalf("Dial failed: %v", err) } - if err := c.Mail("invalid"); err == nil { + if err := c.Mail("invalid@@example.org"); err == nil { t.Fatal("Unexpected MAIL success") } @@ -762,7 +802,7 @@ func TestInvalidRecipient(t *testing.T) { t.Fatalf("Mail failed: %v", err) } - if err := c.Rcpt("invalid"); err == nil { + if err := c.Rcpt("invalid@@example.org"); err == nil { t.Fatal("Unexpected RCPT success") } @@ -878,3 +918,110 @@ func TestInterruptedDATA(t *testing.T) { c.Close() } + +func TestTimeoutClose(t *testing.T) { + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Listen failed: %v", err) + } + + defer ln.Close() + + server := &smtpd.Server{ + MaxConnections: 1, + ReadTimeout: time.Second, + WriteTimeout: time.Second, + } + + go func() { + server.Serve(ln) + }() + + c1, err := smtp.Dial(ln.Addr().String()) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + time.Sleep(time.Second * 2) + + c2, err := smtp.Dial(ln.Addr().String()) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + if err := c1.Mail("sender@example.org"); err == nil { + t.Fatal("MAIL succeeded despite being timed out.") + } + + if err := c2.Mail("sender@example.org"); err != nil { + t.Fatalf("MAIL failed: %v", err) + } + + if err := c2.Quit(); err != nil { + t.Fatalf("Quit failed: %v", err) + } + + c2.Close() +} + +func TestTLSTimeout(t *testing.T) { + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Listen failed: %v", err) + } + + defer ln.Close() + + cert, err := tls.X509KeyPair(localhostCert, localhostKey) + if err != nil { + t.Fatalf("Cert load failed: %v", err) + } + + server := &smtpd.Server{ + TLSConfig: &tls.Config{ + Certificates: []tls.Certificate{cert}, + }, + ReadTimeout: time.Second * 2, + WriteTimeout: time.Second * 2, + } + + go func() { + server.Serve(ln) + }() + + c, err := smtp.Dial(ln.Addr().String()) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + if err := c.StartTLS(&tls.Config{InsecureSkipVerify: true}); err != nil { + t.Fatalf("STARTTLS failed: %v", err) + } + + time.Sleep(time.Second) + + if err := c.Mail("sender@example.org"); err != nil { + t.Fatalf("MAIL failed: %v", err) + } + + time.Sleep(time.Second) + + if err := c.Rcpt("recipient@example.net"); err != nil { + t.Fatalf("RCPT failed: %v", err) + } + + time.Sleep(time.Second) + + if err := c.Rcpt("recipient@example.net"); err != nil { + t.Fatalf("RCPT failed: %v", err) + } + + time.Sleep(time.Second) + + if err := c.Quit(); err != nil { + t.Fatalf("Quit failed: %v", err) + } + +}