use already defined context for cancellation

This commit is contained in:
Liam Stanley 2017-07-18 18:16:54 -04:00
parent 76843d6b84
commit 7ad774d275
2 changed files with 12 additions and 19 deletions

@ -283,7 +283,7 @@ func (c *Client) Close() {
_ = c.conn.Close() _ = c.conn.Close()
} }
func (c *Client) execLoop(done chan struct{}, wg *sync.WaitGroup) { func (c *Client) execLoop(ctx context.Context, wg *sync.WaitGroup) {
c.debug.Print("starting execLoop") c.debug.Print("starting execLoop")
defer c.debug.Print("closing execLoop") defer c.debug.Print("closing execLoop")
@ -291,7 +291,7 @@ func (c *Client) execLoop(done chan struct{}, wg *sync.WaitGroup) {
for { for {
select { select {
case <-done: case <-ctx.Done():
// We've been told to exit, however we shouldn't bail on the // We've been told to exit, however we shouldn't bail on the
// current events in the queue that should be processed, as one // current events in the queue that should be processed, as one
// may want to handle an ERROR, QUIT, etc. // may want to handle an ERROR, QUIT, etc.

27
conn.go

@ -288,21 +288,15 @@ func (c *Client) internalConnect(mock net.Conn) error {
// Start read loop to process messages from the server. // Start read loop to process messages from the server.
errs := make(chan error, 3) errs := make(chan error, 3)
done := make(chan struct{}, 4)
var wg sync.WaitGroup var wg sync.WaitGroup
// 4 being the number of goroutines we need to finish when this function // 4 being the number of goroutines we need to finish when this function
// returns. // returns.
wg.Add(4) wg.Add(4)
go c.execLoop(done, &wg) go c.execLoop(ctx, &wg)
go c.readLoop(errs, done, &wg) go c.readLoop(errs, ctx, &wg)
go c.sendLoop(errs, done, &wg) go c.sendLoop(errs, ctx, &wg)
go c.pingLoop(errs, ctx, &wg)
if mock == nil {
go c.pingLoop(errs, done, &wg)
} else {
go c.pingLoop(errs, done, &wg)
}
// Passwords first. // Passwords first.
if c.Config.ServerPass != "" { if c.Config.ServerPass != "" {
@ -344,7 +338,6 @@ func (c *Client) internalConnect(mock net.Conn) error {
// Once we have our error/result, let all other functions know we're done. // Once we have our error/result, let all other functions know we're done.
c.debug.Print("waiting for all routines to finish") c.debug.Print("waiting for all routines to finish")
close(done)
// Wait for all goroutines to finish. // Wait for all goroutines to finish.
wg.Wait() wg.Wait()
@ -365,7 +358,7 @@ func (c *Client) internalConnect(mock net.Conn) error {
// readLoop sets a timeout of 300 seconds, and then attempts to read from the // readLoop sets a timeout of 300 seconds, and then attempts to read from the
// IRC server. If there is an error, it calls Reconnect. // IRC server. If there is an error, it calls Reconnect.
func (c *Client) readLoop(errs chan error, done chan struct{}, wg *sync.WaitGroup) { func (c *Client) readLoop(errs chan error, ctx context.Context, wg *sync.WaitGroup) {
c.debug.Print("starting readLoop") c.debug.Print("starting readLoop")
defer c.debug.Print("closing readLoop") defer c.debug.Print("closing readLoop")
@ -374,7 +367,7 @@ func (c *Client) readLoop(errs chan error, done chan struct{}, wg *sync.WaitGrou
for { for {
select { select {
case <-done: case <-ctx.Done():
wg.Done() wg.Done()
return return
default: default:
@ -432,7 +425,7 @@ func (c *ircConn) rate(chars int) time.Duration {
return 0 return 0
} }
func (c *Client) sendLoop(errs chan error, done chan struct{}, wg *sync.WaitGroup) { func (c *Client) sendLoop(errs chan error, ctx context.Context, wg *sync.WaitGroup) {
c.debug.Print("starting sendLoop") c.debug.Print("starting sendLoop")
defer c.debug.Print("closing sendLoop") defer c.debug.Print("closing sendLoop")
@ -495,7 +488,7 @@ func (c *Client) sendLoop(errs chan error, done chan struct{}, wg *sync.WaitGrou
wg.Done() wg.Done()
return return
} }
case <-done: case <-ctx.Done():
wg.Done() wg.Done()
return return
} }
@ -528,7 +521,7 @@ type ErrTimedOut struct {
func (ErrTimedOut) Error() string { return "timed out during ping to server" } func (ErrTimedOut) Error() string { return "timed out during ping to server" }
func (c *Client) pingLoop(errs chan error, done chan struct{}, wg *sync.WaitGroup) { func (c *Client) pingLoop(errs chan error, ctx context.Context, wg *sync.WaitGroup) {
// Don't run the pingLoop if they want to disable it. // Don't run the pingLoop if they want to disable it.
if c.Config.PingDelay <= 0 { if c.Config.PingDelay <= 0 {
wg.Done() wg.Done()
@ -584,7 +577,7 @@ func (c *Client) pingLoop(errs chan error, done chan struct{}, wg *sync.WaitGrou
c.conn.mu.Unlock() c.conn.mu.Unlock()
c.Cmd.Ping(fmt.Sprintf("%d", time.Now().UnixNano())) c.Cmd.Ping(fmt.Sprintf("%d", time.Now().UnixNano()))
case <-done: case <-ctx.Done():
wg.Done() wg.Done()
return return
} }