diff --git a/prompt.go b/prompt.go index 6f72944..ac86452 100644 --- a/prompt.go +++ b/prompt.go @@ -2,6 +2,7 @@ package prompt import ( "bytes" + "context" "io/ioutil" "log" "os" @@ -29,6 +30,8 @@ type Prompt struct { keyBindings []KeyBind ASCIICodeBindings []ASCIICodeBind keyBindMode KeyBindMode + ctx context.Context + cancel context.CancelFunc } // Exec is the struct contains user input context. @@ -53,27 +56,23 @@ func (p *Prompt) Run() { p.renderer.Render(p.buf, p.completion) - bufCh := make(chan []byte, 128) - stopReadBufCh := make(chan struct{}) - go p.readBuffer(bufCh, stopReadBufCh) + bufchan := make(chan []byte, 128) + go p.readBuffer(p.ctx, bufchan) exitCh := make(chan int) - winSizeCh := make(chan *WinSize) - stopHandleSignalCh := make(chan struct{}) - go p.handleSignals(exitCh, winSizeCh, stopHandleSignalCh) + winchan := make(chan *WinSize) + go p.handleSignals(p.ctx, p.cancel, winchan) for { select { - case b := <-bufCh: + case b := <-bufchan: if shouldExit, e := p.feed(b); shouldExit { p.renderer.BreakLine(p.buf) - stopReadBufCh <- struct{}{} - stopHandleSignalCh <- struct{}{} + p.cancel() return } else if e != nil { // Stop goroutine to run readBuffer function - stopReadBufCh <- struct{}{} - stopHandleSignalCh <- struct{}{} + p.cancel() // Unset raw mode // Reset to Blocking mode because returned EAGAIN when still set non-blocking mode. @@ -85,13 +84,16 @@ func (p *Prompt) Run() { // Set raw mode p.in.Setup() - go p.readBuffer(bufCh, stopReadBufCh) - go p.handleSignals(exitCh, winSizeCh, stopHandleSignalCh) + ctx, cancel := context.WithCancel(context.Background()) + p.ctx = ctx + p.cancel = cancel + go p.readBuffer(p.ctx, bufchan) + go p.handleSignals(p.ctx, p.cancel, winchan) } else { p.completion.Update(*p.buf.Document()) p.renderer.Render(p.buf, p.completion) } - case w := <-winSizeCh: + case w := <-winchan: p.renderer.UpdateWinSize(w) p.renderer.Render(p.buf, p.completion) case code := <-exitCh: @@ -233,20 +235,17 @@ func (p *Prompt) Input() string { defer p.tearDown() p.renderer.Render(p.buf, p.completion) - bufCh := make(chan []byte, 128) - stopReadBufCh := make(chan struct{}) - go p.readBuffer(bufCh, stopReadBufCh) + bufchan := make(chan []byte, 128) + go p.readBuffer(p.ctx, bufchan) for { select { - case b := <-bufCh: + case b := <-bufchan: if shouldExit, e := p.feed(b); shouldExit { p.renderer.BreakLine(p.buf) - stopReadBufCh <- struct{}{} + p.cancel() return "" } else if e != nil { - // Stop goroutine to run readBuffer function - stopReadBufCh <- struct{}{} return e.input } else { p.completion.Update(*p.buf.Document()) @@ -258,11 +257,11 @@ func (p *Prompt) Input() string { } } -func (p *Prompt) readBuffer(bufCh chan []byte, stopCh chan struct{}) { +func (p *Prompt) readBuffer(ctx context.Context, bufCh chan []byte) { log.Printf("[INFO] readBuffer start") for { select { - case <-stopCh: + case <-ctx.Done(): log.Print("[INFO] stop readBuffer") return default: @@ -278,9 +277,14 @@ func (p *Prompt) setUp() { p.in.Setup() p.renderer.Setup() p.renderer.UpdateWinSize(p.in.GetWinSize()) + + ctx, cancel := context.WithCancel(context.Background()) + p.ctx = ctx + p.cancel = cancel } func (p *Prompt) tearDown() { + p.cancel() p.in.TearDown() p.renderer.TearDown() } diff --git a/signal_posix.go b/signal_posix.go index cff1327..15020a5 100644 --- a/signal_posix.go +++ b/signal_posix.go @@ -3,17 +3,18 @@ package prompt import ( + "context" "log" "os" "os/signal" "syscall" ) -func (p *Prompt) handleSignals(exitCh chan int, winSizeCh chan *WinSize, stop chan struct{}) { +func (p *Prompt) handleSignals(ctx context.Context, cancel context.CancelFunc, winSizeCh chan *WinSize) { in := p.in - sigCh := make(chan os.Signal, 1) + sigchan := make(chan os.Signal, 1) signal.Notify( - sigCh, + sigchan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT, @@ -22,22 +23,22 @@ func (p *Prompt) handleSignals(exitCh chan int, winSizeCh chan *WinSize, stop ch for { select { - case <-stop: + case <-ctx.Done(): log.Println("[INFO] stop handleSignals") return - case s := <-sigCh: + case s := <-sigchan: switch s { case syscall.SIGINT: // kill -SIGINT XXXX or Ctrl+c log.Println("[SIGNAL] Catch SIGINT") - exitCh <- 0 + cancel() case syscall.SIGTERM: // kill -SIGTERM XXXX log.Println("[SIGNAL] Catch SIGTERM") - exitCh <- 1 + cancel() case syscall.SIGQUIT: // kill -SIGQUIT XXXX log.Println("[SIGNAL] Catch SIGQUIT") - exitCh <- 0 + cancel() case syscall.SIGWINCH: log.Println("[SIGNAL] Catch SIGWINCH") diff --git a/signal_windows.go b/signal_windows.go index 5c34a63..c1d29ef 100644 --- a/signal_windows.go +++ b/signal_windows.go @@ -3,13 +3,14 @@ package prompt import ( + "context" "log" "os" "os/signal" "syscall" ) -func (p *Prompt) handleSignals(exitCh chan int, winSizeCh chan *WinSize, stop chan struct{}) { +func (p *Prompt) handleSignals(tx context.Context, cancel context.CancelFunc, winSizeCh chan *WinSize) { sigCh := make(chan os.Signal, 1) signal.Notify( sigCh, @@ -20,23 +21,18 @@ func (p *Prompt) handleSignals(exitCh chan int, winSizeCh chan *WinSize, stop ch for { select { - case <-stop: - log.Println("[INFO] stop handleSignals") + case <-ctx.Done(): return case s := <-sigCh: switch s { - case syscall.SIGINT: // kill -SIGINT XXXX or Ctrl+c - log.Println("[SIGNAL] Catch SIGINT") - exitCh <- 0 + cancel() case syscall.SIGTERM: // kill -SIGTERM XXXX log.Println("[SIGNAL] Catch SIGTERM") - exitCh <- 1 case syscall.SIGQUIT: // kill -SIGQUIT XXXX - log.Println("[SIGNAL] Catch SIGQUIT") - exitCh <- 0 + cancel() } } }