From be033068608df109ac85f74ff70e686a32e3d645 Mon Sep 17 00:00:00 2001 From: Christian Joergensen Date: Mon, 14 Jul 2014 19:44:10 +0200 Subject: [PATCH] Test cases, fixes. --- protocol.go | 74 +++++++++++++++++-- smtpd.go | 58 +-------------- smtpd_test.go | 195 ++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 265 insertions(+), 62 deletions(-) create mode 100644 smtpd_test.go diff --git a/protocol.go b/protocol.go index 4c07fad..348b6e7 100644 --- a/protocol.go +++ b/protocol.go @@ -6,7 +6,6 @@ import ( "crypto/tls" "encoding/base64" "fmt" - "log" "strings" ) @@ -31,6 +30,58 @@ func parseLine(line string) (cmd command) { } +func (session *session) handle(line string) { + + cmd := parseLine(line) + + switch cmd.action { + + case "HELO": + session.handleHELO(cmd) + return + + case "EHLO": + session.handleEHLO(cmd) + return + + case "MAIL": + session.handleMAIL(cmd) + return + + case "RCPT": + session.handleRCPT(cmd) + return + + case "STARTTLS": + session.handleSTARTTLS(cmd) + return + + case "DATA": + session.handleDATA(cmd) + return + + case "RSET": + session.handleRSET(cmd) + return + + case "NOOP": + session.handleNOOP(cmd) + return + + case "QUIT": + session.handleQUIT(cmd) + return + + case "AUTH": + session.handleAUTH(cmd) + return + + } + + session.reply(502, "Unsupported command.") + +} + func (session *session) handleHELO(cmd command) { if len(cmd.fields) < 2 { @@ -163,10 +214,9 @@ func (session *session) handleSTARTTLS(cmd command) { } tlsConn := tls.Server(session.conn, session.server.TLSConfig) - session.reply(250, "Go ahead") + session.reply(220, "Go ahead") if err := tlsConn.Handshake(); err != nil { - log.Printf("TLS Handshake error:", err) session.reply(550, "Handshake error") return } @@ -188,7 +238,7 @@ func (session *session) handleDATA(cmd command) { return } - session.reply(250, "Go ahead. End your data with .") + session.reply(354, "Go ahead. End your data with .") data := &bytes.Buffer{} done := false @@ -226,7 +276,7 @@ func (session *session) handleDATA(cmd command) { if err != nil { session.error(err) } else { - session.reply(200, "Thank you.") + session.reply(250, "Thank you.") } } @@ -243,13 +293,23 @@ func (session *session) handleNOOP(cmd command) { } func (session *session) handleQUIT(cmd command) { - session.reply(250, "OK, bye") + session.reply(221, "OK, bye") session.close() return } func (session *session) handleAUTH(cmd command) { + if session.peer.HeloName == "" { + session.reply(502, "Please introduce yourself first.") + return + } + + if !session.tls { + session.reply(502, "Cannot AUTH in plain text mode. Use STARTTLS.") + return + } + mechanism := strings.ToUpper(cmd.fields[1]) username := "" @@ -335,6 +395,6 @@ func (session *session) handleAUTH(cmd command) { session.peer.Username = username session.peer.Password = password - session.reply(250, "OK, you are now authenticated") + session.reply(235, "OK, you are now authenticated") } diff --git a/smtpd.go b/smtpd.go index c2b0194..6d8fed3 100644 --- a/smtpd.go +++ b/smtpd.go @@ -73,8 +73,6 @@ type session struct { func (srv *Server) newSession(c net.Conn) (s *session, err error) { - log.Printf("New connection from: %s", c.RemoteAddr()) - s = &session{ server: srv, conn: c, @@ -98,7 +96,7 @@ func (srv *Server) ListenAndServe() error { if err != nil { return err } - log.Printf("Listening on: %s", srv.Addr) + return srv.Serve(l) } @@ -169,63 +167,12 @@ func (srv *Server) configureDefaults() { func (session *session) serve() { - log.Print("Serving") - defer session.close() session.reply(220, session.server.WelcomeMessage) for session.scanner.Scan() { - - line := session.scanner.Text() - cmd := parseLine(line) - - switch cmd.action { - - case "HELO": - session.handleHELO(cmd) - continue - - case "EHLO": - session.handleEHLO(cmd) - continue - - case "MAIL": - session.handleMAIL(cmd) - continue - - case "RCPT": - session.handleRCPT(cmd) - continue - - case "STARTTLS": - session.handleSTARTTLS(cmd) - continue - - case "DATA": - session.handleDATA(cmd) - continue - - case "RSET": - session.handleRSET(cmd) - continue - - case "NOOP": - session.handleNOOP(cmd) - continue - - case "QUIT": - session.handleQUIT(cmd) - continue - - case "AUTH": - session.handleAUTH(cmd) - continue - - } - - session.reply(502, "Unsupported command.") - + session.handle(session.scanner.Text()) } } @@ -249,6 +196,7 @@ func (session *session) extensions() []string { extensions := []string{ fmt.Sprintf("SIZE %d", session.server.MaxMessageSize), + "8BITMIME", } if session.server.TLSConfig != nil && !session.tls { diff --git a/smtpd_test.go b/smtpd_test.go new file mode 100644 index 0000000..16d036e --- /dev/null +++ b/smtpd_test.go @@ -0,0 +1,195 @@ +package smtpd + +import ( + "crypto/tls" + "fmt" + "net" + "net/smtp" + "strings" + "testing" +) + +var localhostCert = []byte(`-----BEGIN CERTIFICATE----- +MIIBkzCCAT+gAwIBAgIQf4LO8+QzcbXRHJUo6MvX7zALBgkqhkiG9w0BAQswEjEQ +MA4GA1UEChMHQWNtZSBDbzAeFw03MDAxMDEwMDAwMDBaFw04MTA1MjkxNjAwMDBa +MBIxEDAOBgNVBAoTB0FjbWUgQ28wXDANBgkqhkiG9w0BAQEFAANLADBIAkEAx2Uj +2nl0ESnMMrdUOwQnpnIPQzQBX9MIYT87VxhHzImOukWcq5DrmN1ZB//diyrgiCLv +D0udX3YXNHMn1Ki8awIDAQABo3MwcTAOBgNVHQ8BAf8EBAMCAKQwEwYDVR0lBAww +CgYIKwYBBQUHAwEwDwYDVR0TAQH/BAUwAwEB/zA5BgNVHREEMjAwggtleGFtcGxl +LmNvbYIJbG9jYWxob3N0hwR/AAABhxAAAAAAAAAAAAAAAAAAAAABMAsGCSqGSIb3 +DQEBCwNBAGcaB2Il0TIXFcJOdOLGPa6F8qZH1ZHBtVlCBnaJn4vZJGzID+V36Gn0 +hA1AYfGAaF0c43oQofvv+XqQlTe4a+M= +-----END CERTIFICATE-----`) + +var localhostKey = []byte(`-----BEGIN RSA PRIVATE KEY----- +MIIBPAIBAAJBAMdlI9p5dBEpzDK3VDsEJ6ZyD0M0AV/TCGE/O1cYR8yJjrpFnKuQ +65jdWQf/3Ysq4Igi7w9LnV92FzRzJ9SovGsCAwEAAQJAVaFw2VWJbAmIQUuMJ+Ar +6wZW2aSO5okpsyHFqSyrQQIcAj/QOq8P83F8J10IreFWNlBlywJU9c7IlJtn/lqq +AQIhAOxHXOxrKPxqTIdIcNnWye/HRQ+5VD54QQr1+M77+bEBAiEA2AmsNNqj2fKj +j2xk+4vnBSY0vrb4q/O3WZ46oorawWsCIQDWdpfzx/i11E6OZMR6FinJSNh4w0Gi +SkjPiCBE0BX+AQIhAI/TiLk7YmBkQG3ovSYW0vvDntPlXpKj08ovJFw4U0D3AiEA +lGjGna4oaauI0CWI6pG0wg4zklTnrDWK7w9h/S/T4e0= +-----END RSA PRIVATE KEY-----`) + +func TestSMTP(t *testing.T) { + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Listen failed: %v", err) + } + + defer ln.Close() + + server := &Server{} + + go func() { + server.Serve(ln) + }() + + c, err := smtp.Dial(ln.Addr().String()) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + if supported, _ := c.Extension("AUTH"); supported { + t.Fatal("AUTH supported before TLS") + } + + if supported, _ := c.Extension("8BITMIME"); !supported { + t.Fatal("8BITMIME not supported") + } + + if supported, _ := c.Extension("STARTTLS"); supported { + t.Fatal("STARTTLS supported") + } + + if err := c.Mail("sender@example.org"); err != nil { + t.Fatalf("Mail failed: %v", err) + } + + if err := c.Rcpt("recipient@example.net"); err != nil { + t.Fatalf("Rcpt failed: %v", err) + } + + if err := c.Rcpt("recipient2@example.net"); err != nil { + t.Fatalf("Rcpt2 failed: %v", err) + } + + wc, err := c.Data() + if err != nil { + t.Fatalf("Data failed: %v", err) + } + + _, err = fmt.Fprintf(wc, "This is the email body") + if err != nil { + t.Fatalf("Data body failed: %v", err) + } + + err = wc.Close() + if err != nil { + t.Fatalf("Data close failed: %v", err) + } + + if err := c.Reset(); err != nil { + t.Fatalf("Reset failed: %v", err) + } + + if err := c.Verify("foobar@example.net"); err == nil { + t.Fatal("Unexpected support for VRFY") + } + + if err := c.Quit(); err != nil { + t.Fatalf("Quit failed: %v", err) + } +} + +func TestSTARTTLS(t *testing.T) { + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Listen failed: %v", err) + } + + defer ln.Close() + + cert, err := tls.X509KeyPair(localhostCert, localhostKey) + if err != nil { + t.Fatalf("Cert load failed: %v", err) + } + + server := &Server{ + Authenticator: func(peer Peer, username, password string) error { return nil }, + TLSConfig: &tls.Config{ + Certificates: []tls.Certificate{cert}, + }, + ForceTLS: true, + } + + go func() { + server.Serve(ln) + }() + + c, err := smtp.Dial(ln.Addr().String()) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + + if supported, _ := c.Extension("AUTH"); supported { + t.Fatal("AUTH supported before TLS") + } + + if err := c.Mail("sender@example.org"); err == nil { + t.Fatal("Mail workded before TLS with ForceTLS") + } + + if err := c.StartTLS(&tls.Config{InsecureSkipVerify: true}); err != nil { + t.Fatalf("STARTTLS failed: %v", err) + } + + if supported, _ := c.Extension("AUTH"); !supported { + t.Fatal("AUTH not supported after TLS") + } + + if _, mechs := c.Extension("AUTH"); !strings.Contains(mechs, "PLAIN") { + t.Fatal("PLAIN AUTH not supported after TLS") + } + + if _, mechs := c.Extension("AUTH"); !strings.Contains(mechs, "LOGIN") { + t.Fatal("LOGIN AUTH not supported after TLS") + } + + if err := c.Auth(smtp.PlainAuth("foo", "foo", "bar", "127.0.0.1")); err != nil { + t.Fatalf("Auth failed: %v", err) + } + + if err := c.Mail("sender@example.org"); err != nil { + t.Fatalf("Mail failed: %v", err) + } + + if err := c.Rcpt("recipient@example.net"); err != nil { + t.Fatalf("Rcpt failed: %v", err) + } + + if err := c.Rcpt("recipient2@example.net"); err != nil { + t.Fatalf("Rcpt2 failed: %v", err) + } + + wc, err := c.Data() + if err != nil { + t.Fatalf("Data failed: %v", err) + } + + _, err = fmt.Fprintf(wc, "This is the email body") + if err != nil { + t.Fatalf("Data body failed: %v", err) + } + + err = wc.Close() + if err != nil { + t.Fatalf("Data close failed: %v", err) + } + + if err := c.Quit(); err != nil { + t.Fatalf("Quit failed: %v", err) + } +}