From a8c73f7baf6ca9596fa96e8aa636d34a886c4641 Mon Sep 17 00:00:00 2001 From: bsdelf Date: Mon, 26 Aug 2019 16:54:05 +0800 Subject: [PATCH] Ensure WaitGroup.Done() is always called --- pkg/safe/routine.go | 19 +++++++++--- pkg/safe/routine_test.go | 67 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 4 deletions(-) diff --git a/pkg/safe/routine.go b/pkg/safe/routine.go index a95c0ecf8..83f33b872 100644 --- a/pkg/safe/routine.go +++ b/pkg/safe/routine.go @@ -59,12 +59,23 @@ func (p *Pool) GoCtx(goroutine routineCtx) { p.routinesCtx = append(p.routinesCtx, goroutine) p.waitGroup.Add(1) Go(func() { + defer p.waitGroup.Done() goroutine(p.ctx) - p.waitGroup.Done() }) p.lock.Unlock() } +// addGo adds a recoverable goroutine, and can be stopped with stop chan +func (p *Pool) addGo(goroutine func(stop chan bool)) { + p.lock.Lock() + newRoutine := routine{ + goroutine: goroutine, + stop: make(chan bool, 1), + } + p.routines = append(p.routines, newRoutine) + p.lock.Unlock() +} + // Go starts a recoverable goroutine, and can be stopped with stop chan func (p *Pool) Go(goroutine func(stop chan bool)) { p.lock.Lock() @@ -75,8 +86,8 @@ func (p *Pool) Go(goroutine func(stop chan bool)) { p.routines = append(p.routines, newRoutine) p.waitGroup.Add(1) Go(func() { + defer p.waitGroup.Done() goroutine(newRoutine.stop) - p.waitGroup.Done() }) p.lock.Unlock() } @@ -112,16 +123,16 @@ func (p *Pool) Start() { p.waitGroup.Add(1) p.routines[i].stop = make(chan bool, 1) Go(func() { + defer p.waitGroup.Done() p.routines[i].goroutine(p.routines[i].stop) - p.waitGroup.Done() }) } for _, routine := range p.routinesCtx { p.waitGroup.Add(1) Go(func() { + defer p.waitGroup.Done() routine(p.ctx) - p.waitGroup.Done() }) } } diff --git a/pkg/safe/routine_test.go b/pkg/safe/routine_test.go index 20d28d4f2..d44d03f58 100644 --- a/pkg/safe/routine_test.go +++ b/pkg/safe/routine_test.go @@ -173,6 +173,73 @@ func TestPoolStartWithStopChan(t *testing.T) { } } +func TestPoolCleanupWithGoPanicking(t *testing.T) { + testRoutine := func(stop chan bool) { + panic("BOOM") + } + + testCtxRoutine := func(ctx context.Context) { + panic("BOOM") + } + + testCases := []struct { + desc string + fn func(*Pool) + }{ + { + desc: "Go()", + fn: func(p *Pool) { + p.Go(testRoutine) + }, + }, + { + desc: "addGo() and Start()", + fn: func(p *Pool) { + p.addGo(testRoutine) + p.Start() + }, + }, + { + desc: "GoCtx()", + fn: func(p *Pool) { + p.GoCtx(testCtxRoutine) + }, + }, + { + desc: "AddGoCtx() and Start()", + fn: func(p *Pool) { + p.AddGoCtx(testCtxRoutine) + p.Start() + }, + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + p := NewPool(context.Background()) + + timer := time.NewTimer(500 * time.Millisecond) + defer timer.Stop() + + test.fn(p) + + testDone := make(chan bool, 1) + go func() { + p.Cleanup() + testDone <- true + }() + + select { + case <-timer.C: + t.Fatalf("Pool.Cleanup() did not complete in time with a panicking goroutine") + case <-testDone: + return + } + }) + } +} + func TestGoroutineRecover(t *testing.T) { // if recover fails the test will panic Go(func() {