diff --git a/client.go b/client.go index ade285e..190e54b 100644 --- a/client.go +++ b/client.go @@ -12,13 +12,9 @@ import ( "io" "io/ioutil" "log" - "net" - "net/url" "strings" "sync" "time" - - "golang.org/x/net/proxy" ) // Client contains all of the information necessary to run a single IRC @@ -190,19 +186,6 @@ func New(config Config) *Client { // Connect attempts to connect to the given IRC server func (c *Client) Connect() error { - // Sanity check a few options. - if c.Config.Server == "" { - return errors.New("invalid server specified") - } - - if c.Config.Port < 21 || c.Config.Port > 65535 { - return errors.New("invalid port (21-65535)") - } - - if !IsValidNick(c.Config.Nick) || !IsValidUser(c.Config.User) { - return errors.New("invalid nickname or user") - } - // Clean up any old running stuff. c.cleanup(false) @@ -213,71 +196,17 @@ func (c *Client) Connect() error { // Reset the state. c.state = newState() + // Validate info, and actually make the connection. c.debug.Printf("connecting to %s...", c.Server()) - - var conn net.Conn - var err error - - dialer := &net.Dialer{Timeout: 5 * time.Second} - - if c.Config.Bind != "" { - var local *net.TCPAddr - local, err = net.ResolveTCPAddr("tcp", c.Config.Bind+":0") - if err != nil { - return fmt.Errorf("unable to resolve bind address %s: %s", c.Config.Bind, err) - } - - dialer.LocalAddr = local - } - - if c.Config.Proxy != "" { - var proxyUri *url.URL - var proxyDialer proxy.Dialer - - proxyUri, err = url.Parse(c.Config.Proxy) - if err != nil { - return fmt.Errorf("unable to use proxy %q: %s", c.Config.Proxy, err) - } - - proxyDialer, err = proxy.FromURL(proxyUri, dialer) - if err != nil { - return fmt.Errorf("unable to use proxy %q: %s", c.Config.Proxy, err) - } - - conn, err = proxyDialer.Dial("tcp", c.Server()) - if err != nil { - return fmt.Errorf("unable to use proxy %q: %s", c.Config.Proxy, err) - } - } else { - conn, err = dialer.Dial("tcp", c.Server()) - if err != nil { - return fmt.Errorf("unable to connect to %q: %s", c.Server(), err) - } - } - - if c.Config.SSL { - var sslConf *tls.Config - - if c.Config.TLSConfig == nil { - sslConf = &tls.Config{ServerName: c.Config.Server} - } else { - sslConf = c.Config.TLSConfig - } - - tlsConn := tls.Client(conn, sslConf) - if err = tlsConn.Handshake(); err != nil { - return fmt.Errorf("failed handshake during tls conn to %q: %s", c.Server(), err) - } - conn = tlsConn + conn, err := newConn(c.Config, c.Server()) + if err != nil { + return err } c.state.mu.Lock() c.state.conn = conn c.state.mu.Unlock() - c.state.reader = newDecoder(c.state.conn) - c.state.writer = newEncoder(c.state.conn) - // Send a virtual event allowing hooks for successful socket connection. c.Events <- &Event{Command: INITIALIZED, Trailing: c.Server()} @@ -451,8 +380,8 @@ func (c *Client) readLoop(ctx context.Context) { case <-ctx.Done(): return default: - c.state.conn.SetDeadline(time.Now().Add(300 * time.Second)) - event, err = c.state.reader.Decode() + c.state.conn.lconn.SetDeadline(time.Now().Add(300 * time.Second)) + event, err = c.state.conn.Decode() if err != nil { // Attempt a reconnect (if applicable). If it fails, send // the error to c.Config.HandleError to be dealt with, if @@ -569,7 +498,7 @@ func (c *Client) write(event *Event) error { c.debug.Print("> ", StripRaw(event.String())) } - return c.state.writer.Encode(event) + return c.state.conn.Encode(event) } // Uptime is the time at which the client successfully connected to the diff --git a/conn.go b/conn.go index cac1a18..cdb34de 100644 --- a/conn.go +++ b/conn.go @@ -6,8 +6,16 @@ package girc import ( "bufio" + "crypto/tls" + "errors" + "fmt" "io" + "net" + "net/url" "sync" + "time" + + "golang.org/x/net/proxy" ) // Messages are delimited with CR and LF line endings, we're using the last @@ -22,12 +30,89 @@ type ircConn struct { ircEncoder ircDecoder - c io.ReadWriteCloser + lconn net.Conn +} + +func newConn(conf Config, addr string) (*ircConn, error) { + // Sanity check a few options. + if conf.Server == "" { + return nil, errors.New("invalid server specified") + } + + if conf.Port < 21 || conf.Port > 65535 { + return nil, errors.New("invalid port (21-65535)") + } + + if !IsValidNick(conf.Nick) || !IsValidUser(conf.User) { + return nil, errors.New("invalid nickname or user") + } + + var conn net.Conn + var err error + + dialer := &net.Dialer{Timeout: 5 * time.Second} + + if conf.Bind != "" { + var local *net.TCPAddr + local, err = net.ResolveTCPAddr("tcp", conf.Bind+":0") + if err != nil { + return nil, fmt.Errorf("unable to resolve bind address %s: %s", conf.Bind, err) + } + + dialer.LocalAddr = local + } + + if conf.Proxy != "" { + var proxyUri *url.URL + var proxyDialer proxy.Dialer + + proxyUri, err = url.Parse(conf.Proxy) + if err != nil { + return nil, fmt.Errorf("unable to use proxy %q: %s", conf.Proxy, err) + } + + proxyDialer, err = proxy.FromURL(proxyUri, dialer) + if err != nil { + return nil, fmt.Errorf("unable to use proxy %q: %s", conf.Proxy, err) + } + + conn, err = proxyDialer.Dial("tcp", addr) + if err != nil { + return nil, fmt.Errorf("unable to connect to proxy %q: %s", conf.Proxy, err) + } + } else { + conn, err = dialer.Dial("tcp", addr) + if err != nil { + return nil, fmt.Errorf("unable to connect to %q: %s", addr, err) + } + } + + if conf.SSL { + var sslConf *tls.Config + + if conf.TLSConfig == nil { + sslConf = &tls.Config{ServerName: conf.Server} + } else { + sslConf = conf.TLSConfig + } + + tlsConn := tls.Client(conn, sslConf) + if err = tlsConn.Handshake(); err != nil { + return nil, fmt.Errorf("failed handshake during tls conn to %q: %s", addr, err) + } + conn = tlsConn + } + + return &ircConn{ + ircEncoder: ircEncoder{writer: conn}, + ircDecoder: ircDecoder{reader: bufio.NewReader(conn)}, + lconn: conn, + }, nil } // Close closes the underlying ReadWriteCloser. func (c *ircConn) Close() error { - return c.c.Close() + return c.lconn.Close() } // ircDecoder reads Event objects from an input stream. @@ -37,11 +122,6 @@ type ircDecoder struct { mu sync.Mutex } -// newDecoder returns a new Decoder that reads from r. -func newDecoder(r io.Reader) *ircDecoder { - return &ircDecoder{reader: bufio.NewReader(r)} -} - // Decode attempts to read a single Event from the stream, returns non-nil // error if read failed. event may be nil if unparseable. func (dec *ircDecoder) Decode() (event *Event, err error) { @@ -62,11 +142,6 @@ type ircEncoder struct { mu sync.Mutex } -// newEncoder returns a new Encoder that writes to w. -func newEncoder(w io.Writer) *ircEncoder { - return &ircEncoder{writer: w} -} - // Encode writes the IRC encoding of m to the stream. Goroutine safe. // returns non-nil error if the write to the underlying stream stopped early. func (enc *ircEncoder) Encode(e *Event) (err error) { diff --git a/state.go b/state.go index be7b962..f167a88 100644 --- a/state.go +++ b/state.go @@ -6,7 +6,6 @@ package girc import ( "fmt" - "net" "strings" "sync" "time" @@ -19,12 +18,8 @@ type state struct { // corruption. mu sync.RWMutex - // reader is the socket buffer reader from the IRC server. - reader *ircDecoder - // reader is the socket buffer write to the IRC server. - writer *ircEncoder // conn is a net.Conn reference to the IRC server. - conn net.Conn + conn *ircConn // lastWrite is used ot keep track of when we last wrote to the server. lastWrite time.Time