diff --git a/conn.go b/conn.go index 85e212f..d12b8eb 100644 --- a/conn.go +++ b/conn.go @@ -9,10 +9,8 @@ import ( "crypto/tls" "errors" "fmt" - "io" "net" "net/url" - "sync" "time" "golang.org/x/net/context" @@ -28,9 +26,8 @@ var endline = []byte("\r\n") // ircConn represents an IRC network protocol connection, it consists of an // Encoder and Decoder to manage i/o. type ircConn struct { - ircEncoder - ircDecoder - lconn net.Conn + io *bufio.ReadWriter + sock net.Conn // lastWrite is used ot keep track of when we last wrote to the server. lastWrite time.Time @@ -108,106 +105,43 @@ func newConn(conf Config, addr string) (*ircConn, error) { } if conf.SSL { - var sslConf *tls.Config - - if conf.TLSConfig == nil { - sslConf = &tls.Config{ServerName: conf.Server} - } else { - sslConf = conf.TLSConfig + var tlsConn net.Conn + tlsConn, err = tlsHandshake(conn, conf.TLSConfig, conf.Server, true) + if err != nil { + return nil, err } - tlsConn := tls.Client(conn, sslConf) - if err = tlsConn.Handshake(); err != nil { - return nil, fmt.Errorf("failed handshake during tls conn to %q: %s", addr, err) - } conn = tlsConn } ctime := time.Now() - return &ircConn{ - ircEncoder: ircEncoder{writer: conn}, - ircDecoder: ircDecoder{reader: bufio.NewReader(conn)}, - lconn: conn, - connTime: &ctime, - connected: true, - }, nil + c := &ircConn{ + sock: conn, + connTime: &ctime, + connected: true, + } + c.newReadWriter() + + return c, nil } -// Close closes the underlying ReadWriteCloser. +func (c *ircConn) newReadWriter() { + c.io = bufio.NewReadWriter(bufio.NewReader(c.sock), bufio.NewWriter(c.sock)) +} + +func tlsHandshake(conn net.Conn, conf *tls.Config, server string, validate bool) (net.Conn, error) { + if conf == nil { + conf = &tls.Config{ServerName: server, InsecureSkipVerify: !validate} + } + + tlsConn := tls.Client(conn, conf) + return net.Conn(tlsConn), nil +} + +// Close closes the underlying socket. func (c *ircConn) Close() error { - return c.lconn.Close() -} - -// setTimeout applies a deadline that the connection must respond back with, -// within the specified time. -func (c *ircConn) setTimeout(timeout time.Duration) { - c.lconn.SetDeadline(time.Now().Add(timeout)) -} - -// rate allows limiting events based on how frequent the event is being sent, -// as well as how many characters each event has. -func (c *ircConn) rate(chars int) time.Duration { - _time := time.Second + ((time.Duration(chars) * time.Second) / 100) - if c.writeDelay += _time - time.Now().Sub(c.lastWrite); c.writeDelay < 0 { - c.writeDelay = 0 - } - - if c.writeDelay > (8 * time.Second) { - return _time - } - - return 0 -} - -// ircDecoder reads Event objects from an input stream. -type ircDecoder struct { - reader *bufio.Reader - line string - mu sync.Mutex -} - -// Decode attempts to read a single Event from the stream, returns non-nil -// error if read failed. event may be nil if unparseable. -func (dec *ircDecoder) Decode() (event *Event, err error) { - dec.mu.Lock() - dec.line, err = dec.reader.ReadString(delim) - dec.mu.Unlock() - - if err != nil { - return nil, err - } - - return ParseEvent(dec.line), nil -} - -// ircEncoder writes Event objects to an output stream. -type ircEncoder struct { - writer io.Writer - mu sync.Mutex -} - -// Encode writes the IRC encoding of m to the stream. Goroutine safe. -// returns non-nil error if the write to the underlying stream stopped early. -func (enc *ircEncoder) Encode(e *Event) (err error) { - _, err = enc.Write(e.Bytes()) - - return -} - -// Write writes len(p) bytes from p followed by CR+LF. Goroutine safe. -func (enc *ircEncoder) Write(p []byte) (n int, err error) { - enc.mu.Lock() - defer enc.mu.Unlock() - - n, err = enc.writer.Write(p) - if err != nil { - return - } - - _, err = enc.writer.Write(endline) - - return + return c.sock.Close() } // Connect attempts to connect to the given IRC server @@ -341,6 +275,7 @@ func (c *Client) disconnectHandler(err error) { // IRC server. If there is an error, it calls Reconnect. func (c *Client) readLoop(ctx context.Context) { var event *Event + var line string var err error for { @@ -348,8 +283,8 @@ func (c *Client) readLoop(ctx context.Context) { case <-ctx.Done(): return default: - c.conn.setTimeout(300 * time.Second) - event, err = c.conn.Decode() + c.conn.sock.SetDeadline(time.Now().Add(300 * time.Second)) + line, err = c.conn.io.ReadString(delim) if err != nil { // Attempt a reconnect (if applicable). If it fails, send // the error to c.Config.HandleError to be dealt with, if @@ -359,6 +294,7 @@ func (c *Client) readLoop(ctx context.Context) { return } + event = ParseEvent(line) if event == nil { continue } @@ -384,7 +320,24 @@ func (c *Client) write(event *Event) { c.tx <- event } +// rate allows limiting events based on how frequent the event is being sent, +// as well as how many characters each event has. +func (c *ircConn) rate(chars int) time.Duration { + _time := time.Second + ((time.Duration(chars) * time.Second) / 100) + if c.writeDelay += _time - time.Now().Sub(c.lastWrite); c.writeDelay < 0 { + c.writeDelay = 0 + } + + if c.writeDelay > (8 * time.Second) { + return _time + } + + return 0 +} + func (c *Client) sendLoop(ctx context.Context) { + var err error + for { select { case <-ctx.Done(): @@ -402,7 +355,17 @@ func (c *Client) sendLoop(ctx context.Context) { c.conn.lastWrite = time.Now() - err := c.conn.Encode(event) + // Write the raw line. + _, err = c.conn.io.Write(event.Bytes()) + if err == nil { + // And the \r\n. + _, err = c.conn.io.Write(endline) + if err == nil { + // Lastly, flush everything to the socket. + err = c.conn.io.Flush() + } + } + if err != nil { c.disconnectHandler(err) }