From 18122ae73e0489a0497956c6d4621c05c6a77387 Mon Sep 17 00:00:00 2001 From: Andrey Smirnov Date: Mon, 16 Jan 2023 23:45:41 +0400 Subject: [PATCH] fix: service restart (including extension services) Fixes #6707 There was a race condition between different parts of the service code: `Stop` waits for the event which is published before the service is removed from the `running[id]` map, so if one does `Stop` followed by `Start` (this is what the `services restart` API does), by the time it gets to `Start` the service might still be in the `running[id]` map, so `Start` does nothing. Overall this code should be rewritten and simplified, but for now move the sending of these "terminal" events out, so that by the time the event is published, the service is stopped and removed from the `running[id]` map. Signed-off-by: Andrey Smirnov --- .../app/machined/pkg/system/service_runner.go | 34 +++--- .../pkg/system/service_runner_test.go | 102 +++++++----------- internal/app/machined/pkg/system/system.go | 33 ++++-- 3 files changed, 83 insertions(+), 86 deletions(-) diff --git a/internal/app/machined/pkg/system/service_runner.go b/internal/app/machined/pkg/system/service_runner.go index 8028860b3..4ad9f5318 100644 --- a/internal/app/machined/pkg/system/service_runner.go +++ b/internal/app/machined/pkg/system/service_runner.go @@ -6,6 +6,7 @@ package system import ( "context" + "errors" "fmt" "log" "sync" @@ -182,12 +183,17 @@ func (svcrunner *ServiceRunner) waitFor(ctx context.Context, condition condition } } -// Start initializes the service and runs it +// ErrSkip is returned by Run when service is skipped. +var ErrSkip = errors.New("service skipped") + +// Run initializes the service and runs it. // -// Start should be run in a goroutine. +// Run returns an error when a service stops. +// +// Run should be run in a goroutine. 
// //nolint:gocyclo -func (svcrunner *ServiceRunner) Start() { +func (svcrunner *ServiceRunner) Run() error { defer func() { // reset context for the next run svcrunner.ctxMu.Lock() @@ -215,27 +221,21 @@ func (svcrunner *ServiceRunner) Start() { if condition != nil { if err := svcrunner.waitFor(ctx, condition); err != nil { - svcrunner.UpdateState(ctx, events.StateFailed, "Condition failed: %v", err) - - return + return fmt.Errorf("condition failed: %w", err) } } svcrunner.UpdateState(ctx, events.StatePreparing, "Running pre state") if err := svcrunner.service.PreFunc(ctx, svcrunner.runtime); err != nil { - svcrunner.UpdateState(ctx, events.StateFailed, "Failed to run pre stage: %v", err) - - return + return fmt.Errorf("failed to run pre stage: %w", err) } svcrunner.UpdateState(ctx, events.StatePreparing, "Creating service runner") runnr, err := svcrunner.service.Runner(svcrunner.runtime) if err != nil { - svcrunner.UpdateState(ctx, events.StateFailed, "Failed to create runner: %v", err) - - return + return fmt.Errorf("failed to create runner: %w", err) } defer func() { @@ -248,16 +248,14 @@ func (svcrunner *ServiceRunner) Start() { }() if runnr == nil { - svcrunner.UpdateState(ctx, events.StateSkipped, "Service skipped") - - return + return ErrSkip } if err := svcrunner.run(ctx, runnr); err != nil { - svcrunner.UpdateState(ctx, events.StateFailed, "Failed running service: %v", err) - } else { - svcrunner.UpdateState(ctx, events.StateFinished, "Service finished successfully") + return fmt.Errorf("failed running service: %w", err) } + + return nil } //nolint:gocyclo diff --git a/internal/app/machined/pkg/system/service_runner_test.go b/internal/app/machined/pkg/system/service_runner_test.go index a98887dc5..c33e6422d 100644 --- a/internal/app/machined/pkg/system/service_runner_test.go +++ b/internal/app/machined/pkg/system/service_runner_test.go @@ -36,11 +36,10 @@ func (suite *ServiceRunnerSuite) TestFullFlow() { condition: conditions.None(), }, nil) - finished := 
make(chan struct{}) + errCh := make(chan error) go func() { - defer close(finished) - sr.Start() + errCh <- sr.Run() }() suite.Require().NoError(retry.Constant(time.Minute, retry.WithUnits(10*time.Millisecond)).Retry(func() error { @@ -53,38 +52,36 @@ func (suite *ServiceRunnerSuite) TestFullFlow() { })) select { - case <-finished: + case <-errCh: suite.Require().Fail("service running should be still running") default: } sr.Shutdown() - <-finished + suite.Assert().NoError(<-errCh) suite.assertStateSequence([]events.ServiceState{ events.StateWaiting, events.StatePreparing, events.StatePreparing, events.StateRunning, - events.StateFinished, }, sr) protoService := sr.AsProto() suite.Assert().Equal("MockRunner", protoService.Id) - suite.Assert().Equal("Finished", protoService.State) + suite.Assert().Equal("Running", protoService.State) suite.Assert().True(protoService.Health.Unknown) - suite.Assert().Len(protoService.Events.Events, 5) + suite.Assert().Len(protoService.Events.Events, 4) } func (suite *ServiceRunnerSuite) TestFullFlowHealthy() { sr := system.NewServiceRunner(&MockHealthcheckedService{}, nil) - finished := make(chan struct{}) + errCh := make(chan error) go func() { - defer close(finished) - sr.Start() + errCh <- sr.Run() }() suite.Require().NoError(retry.Constant(time.Minute, retry.WithUnits(10*time.Millisecond)).Retry(func() error { @@ -97,21 +94,20 @@ func (suite *ServiceRunnerSuite) TestFullFlowHealthy() { })) select { - case <-finished: + case <-errCh: suite.Require().Fail("service running should be still running") default: } sr.Shutdown() - <-finished + suite.Assert().NoError(<-errCh) suite.assertStateSequence([]events.ServiceState{ events.StatePreparing, events.StatePreparing, events.StateRunning, events.StateRunning, // one more notification when service is healthy - events.StateFinished, }, sr) } @@ -123,11 +119,10 @@ func (suite *ServiceRunnerSuite) TestFullFlowHealthChanges() { } sr := system.NewServiceRunner(&m, nil) - finished := make(chan 
struct{}) + errCh := make(chan error) go func() { - defer close(finished) - sr.Start() + errCh <- sr.Run() }() suite.Require().NoError(retry.Constant(time.Minute, retry.WithUnits(10*time.Millisecond)).Retry(func() error { @@ -163,7 +158,7 @@ func (suite *ServiceRunnerSuite) TestFullFlowHealthChanges() { sr.Shutdown() - <-finished + suite.Assert().NoError(<-errCh) suite.assertStateSequence([]events.ServiceState{ events.StateWaiting, @@ -173,7 +168,6 @@ func (suite *ServiceRunnerSuite) TestFullFlowHealthChanges() { events.StateRunning, // initial: healthy events.StateRunning, // not healthy events.StateRunning, // once again healthy - events.StateFinished, }, sr) } @@ -191,11 +185,10 @@ func (suite *ServiceRunnerSuite) TestWaitingDescriptionChange() { condition: conditions.WaitForAll(cond1, cond2), }, nil) - finished := make(chan struct{}) + errCh := make(chan error) go func() { - defer close(finished) - sr.Start() + errCh <- sr.Run() }() suite.Require().NoError(retry.Constant(time.Minute, retry.WithUnits(10*time.Millisecond)).Retry(func() error { @@ -208,7 +201,7 @@ func (suite *ServiceRunnerSuite) TestWaitingDescriptionChange() { })) select { - case <-finished: + case <-errCh: suite.Require().Fail("service running should be still running") default: } @@ -226,7 +219,7 @@ func (suite *ServiceRunnerSuite) TestWaitingDescriptionChange() { })) select { - case <-finished: + case <-errCh: suite.Require().Fail("service running should be still running") default: } @@ -244,7 +237,7 @@ func (suite *ServiceRunnerSuite) TestWaitingDescriptionChange() { sr.Shutdown() - <-finished + suite.Assert().NoError(<-errCh) suite.assertStateSequence([]events.ServiceState{ events.StateWaiting, @@ -252,7 +245,6 @@ func (suite *ServiceRunnerSuite) TestWaitingDescriptionChange() { events.StatePreparing, events.StatePreparing, events.StateRunning, - events.StateFinished, }, sr) events := sr.GetEventHistory(10000) @@ -265,12 +257,12 @@ func (suite *ServiceRunnerSuite) TestPreStageFail() { 
preError: errors.New("pre failed"), } sr := system.NewServiceRunner(svc, nil) - sr.Start() + err := sr.Run() suite.assertStateSequence([]events.ServiceState{ events.StatePreparing, - events.StateFailed, }, sr) + suite.Assert().EqualError(err, "failed to run pre stage: pre failed") } func (suite *ServiceRunnerSuite) TestRunnerStageFail() { @@ -278,13 +270,13 @@ func (suite *ServiceRunnerSuite) TestRunnerStageFail() { runnerError: errors.New("runner failed"), } sr := system.NewServiceRunner(svc, nil) - sr.Start() + err := sr.Run() suite.assertStateSequence([]events.ServiceState{ events.StatePreparing, events.StatePreparing, - events.StateFailed, }, sr) + suite.Assert().EqualError(err, "failed to create runner: runner failed") } func (suite *ServiceRunnerSuite) TestRunnerStageSkipped() { @@ -292,13 +284,13 @@ func (suite *ServiceRunnerSuite) TestRunnerStageSkipped() { nilRunner: true, } sr := system.NewServiceRunner(svc, nil) - sr.Start() + err := sr.Run() suite.assertStateSequence([]events.ServiceState{ events.StatePreparing, events.StatePreparing, - events.StateSkipped, }, sr) + suite.Assert().ErrorIs(err, system.ErrSkip) } func (suite *ServiceRunnerSuite) TestAbortOnCondition() { @@ -307,11 +299,10 @@ func (suite *ServiceRunnerSuite) TestAbortOnCondition() { } sr := system.NewServiceRunner(svc, nil) - finished := make(chan struct{}) + errCh := make(chan error) go func() { - defer close(finished) - sr.Start() + errCh <- sr.Run() }() suite.Require().NoError(retry.Constant(time.Minute, retry.WithUnits(10*time.Millisecond)).Retry(func() error { @@ -324,18 +315,17 @@ func (suite *ServiceRunnerSuite) TestAbortOnCondition() { })) select { - case <-finished: + case <-errCh: suite.Require().Fail("service running should be still running") default: } sr.Shutdown() - <-finished + suite.Assert().EqualError(<-errCh, "condition failed: context canceled") suite.assertStateSequence([]events.ServiceState{ events.StateWaiting, - events.StateFailed, }, sr) } @@ -346,23 +336,21 @@ func 
(suite *ServiceRunnerSuite) TestPostStateFail() { } sr := system.NewServiceRunner(svc, nil) - finished := make(chan struct{}) + errCh := make(chan error) go func() { - defer close(finished) - sr.Start() + errCh <- sr.Run() }() sr.Shutdown() - <-finished + suite.Assert().NoError(<-errCh) suite.assertStateSequence([]events.ServiceState{ events.StateWaiting, events.StatePreparing, events.StatePreparing, events.StateRunning, - events.StateFinished, events.StateFailed, }, sr) } @@ -372,22 +360,20 @@ func (suite *ServiceRunnerSuite) TestRunFail() { svc := &MockService{runner: runner} sr := system.NewServiceRunner(svc, nil) - finished := make(chan struct{}) + errCh := make(chan error) go func() { - defer close(finished) - sr.Start() + errCh <- sr.Run() }() runner.exitCh <- errors.New("run failed") - <-finished + suite.Assert().EqualError(<-errCh, "failed running service: error running service: run failed") suite.assertStateSequence([]events.ServiceState{ events.StatePreparing, events.StatePreparing, events.StateRunning, - events.StateFailed, }, sr) } @@ -396,11 +382,10 @@ func (suite *ServiceRunnerSuite) TestFullFlowRestart() { condition: conditions.None(), }, nil) - finished := make(chan struct{}) + errCh := make(chan error) go func() { - defer close(finished) - sr.Start() + errCh <- sr.Run() }() suite.Require().NoError(retry.Constant(time.Minute, retry.WithUnits(10*time.Millisecond)).Retry(func() error { @@ -413,20 +398,17 @@ func (suite *ServiceRunnerSuite) TestFullFlowRestart() { })) select { - case <-finished: + case <-errCh: suite.Require().Fail("service running should be still running") default: } sr.Shutdown() - <-finished - - finished = make(chan struct{}) + suite.Assert().NoError(<-errCh) go func() { - defer close(finished) - sr.Start() + errCh <- sr.Run() }() suite.Require().NoError(retry.Constant(time.Minute, retry.WithUnits(10*time.Millisecond)).Retry(func() error { @@ -439,26 +421,24 @@ func (suite *ServiceRunnerSuite) TestFullFlowRestart() { })) select { - 
case <-finished: + case <-errCh: suite.Require().Fail("service running should be still running") default: } sr.Shutdown() - <-finished + suite.Assert().NoError(<-errCh) suite.assertStateSequence([]events.ServiceState{ events.StateWaiting, events.StatePreparing, events.StatePreparing, events.StateRunning, - events.StateFinished, events.StateWaiting, events.StatePreparing, events.StatePreparing, events.StateRunning, - events.StateFinished, }, sr) } diff --git a/internal/app/machined/pkg/system/system.go b/internal/app/machined/pkg/system/system.go index 3733707d8..e5a544d96 100644 --- a/internal/app/machined/pkg/system/system.go +++ b/internal/app/machined/pkg/system/system.go @@ -6,9 +6,11 @@ package system import ( "context" + "errors" "fmt" "log" "sort" + "strings" "sync" "time" @@ -17,6 +19,7 @@ import ( "github.com/siderolabs/gen/slices" "github.com/siderolabs/talos/internal/app/machined/pkg/runtime" + "github.com/siderolabs/talos/internal/app/machined/pkg/system/events" "github.com/siderolabs/talos/pkg/conditions" ) @@ -162,14 +165,30 @@ func (s *singleton) Start(serviceIDs ...string) error { s.wg.Add(1) go func(id string, svcrunner *ServiceRunner) { - defer func() { - s.runningMu.Lock() - delete(s.running, id) - s.runningMu.Unlock() - }() - defer s.wg.Done() + err := func() error { + defer func() { + s.runningMu.Lock() + delete(s.running, id) + s.runningMu.Unlock() + }() + defer s.wg.Done() - svcrunner.Start() + return svcrunner.Run() + }() + + switch { + case err == nil: + svcrunner.UpdateState(context.Background(), events.StateFinished, "Service finished successfully") + case errors.Is(err, ErrSkip): + svcrunner.UpdateState(context.Background(), events.StateSkipped, "Service skipped") + default: + msg := err.Error() + if len(msg) > 0 { + msg = strings.ToUpper(msg[:1]) + msg[1:] + } + + svcrunner.UpdateState(context.Background(), events.StateFailed, msg) + } }(id, svcrunner) }