fix: service restart (including extension services)

Fixes #6707

There was a race condition between different parts of the service code:
`Stop` waits for the event which is published before the service is
removed from the `running[id]` map, so if one does `Stop` followed by
`Start` (this is what `services restart` API does), by the time it goes
to `Start` it might be still in the `running[id]` map, so `Start` does
nothing.

Overall this code should be rewritten and simplified, but for now move
out sending these "terminal" events out so that by the time the event is
published, the service is stopped and removed from the `running[id]`
map.

Signed-off-by: Andrey Smirnov <andrey.smirnov@talos-systems.com>
This commit is contained in:
Andrey Smirnov 2023-01-16 23:45:41 +04:00
parent 680fd5e452
commit 18122ae73e
No known key found for this signature in database
GPG Key ID: 7B26396447AB6DFD
3 changed files with 83 additions and 86 deletions

View File

@ -6,6 +6,7 @@ package system
import (
"context"
"errors"
"fmt"
"log"
"sync"
@ -182,12 +183,17 @@ func (svcrunner *ServiceRunner) waitFor(ctx context.Context, condition condition
}
}
// Start initializes the service and runs it
// ErrSkip is returned by Run when service is skipped.
var ErrSkip = errors.New("service skipped")
// Run initializes the service and runs it.
//
// Start should be run in a goroutine.
// Run returns an error when a service stops.
//
// Run should be run in a goroutine.
//
//nolint:gocyclo
func (svcrunner *ServiceRunner) Start() {
func (svcrunner *ServiceRunner) Run() error {
defer func() {
// reset context for the next run
svcrunner.ctxMu.Lock()
@ -215,27 +221,21 @@ func (svcrunner *ServiceRunner) Start() {
if condition != nil {
if err := svcrunner.waitFor(ctx, condition); err != nil {
svcrunner.UpdateState(ctx, events.StateFailed, "Condition failed: %v", err)
return
return fmt.Errorf("condition failed: %w", err)
}
}
svcrunner.UpdateState(ctx, events.StatePreparing, "Running pre state")
if err := svcrunner.service.PreFunc(ctx, svcrunner.runtime); err != nil {
svcrunner.UpdateState(ctx, events.StateFailed, "Failed to run pre stage: %v", err)
return
return fmt.Errorf("failed to run pre stage: %w", err)
}
svcrunner.UpdateState(ctx, events.StatePreparing, "Creating service runner")
runnr, err := svcrunner.service.Runner(svcrunner.runtime)
if err != nil {
svcrunner.UpdateState(ctx, events.StateFailed, "Failed to create runner: %v", err)
return
return fmt.Errorf("failed to create runner: %w", err)
}
defer func() {
@ -248,16 +248,14 @@ func (svcrunner *ServiceRunner) Start() {
}()
if runnr == nil {
svcrunner.UpdateState(ctx, events.StateSkipped, "Service skipped")
return
return ErrSkip
}
if err := svcrunner.run(ctx, runnr); err != nil {
svcrunner.UpdateState(ctx, events.StateFailed, "Failed running service: %v", err)
} else {
svcrunner.UpdateState(ctx, events.StateFinished, "Service finished successfully")
return fmt.Errorf("failed running service: %w", err)
}
return nil
}
//nolint:gocyclo

View File

@ -36,11 +36,10 @@ func (suite *ServiceRunnerSuite) TestFullFlow() {
condition: conditions.None(),
}, nil)
finished := make(chan struct{})
errCh := make(chan error)
go func() {
defer close(finished)
sr.Start()
errCh <- sr.Run()
}()
suite.Require().NoError(retry.Constant(time.Minute, retry.WithUnits(10*time.Millisecond)).Retry(func() error {
@ -53,38 +52,36 @@ func (suite *ServiceRunnerSuite) TestFullFlow() {
}))
select {
case <-finished:
case <-errCh:
suite.Require().Fail("service running should be still running")
default:
}
sr.Shutdown()
<-finished
suite.Assert().NoError(<-errCh)
suite.assertStateSequence([]events.ServiceState{
events.StateWaiting,
events.StatePreparing,
events.StatePreparing,
events.StateRunning,
events.StateFinished,
}, sr)
protoService := sr.AsProto()
suite.Assert().Equal("MockRunner", protoService.Id)
suite.Assert().Equal("Finished", protoService.State)
suite.Assert().Equal("Running", protoService.State)
suite.Assert().True(protoService.Health.Unknown)
suite.Assert().Len(protoService.Events.Events, 5)
suite.Assert().Len(protoService.Events.Events, 4)
}
func (suite *ServiceRunnerSuite) TestFullFlowHealthy() {
sr := system.NewServiceRunner(&MockHealthcheckedService{}, nil)
finished := make(chan struct{})
errCh := make(chan error)
go func() {
defer close(finished)
sr.Start()
errCh <- sr.Run()
}()
suite.Require().NoError(retry.Constant(time.Minute, retry.WithUnits(10*time.Millisecond)).Retry(func() error {
@ -97,21 +94,20 @@ func (suite *ServiceRunnerSuite) TestFullFlowHealthy() {
}))
select {
case <-finished:
case <-errCh:
suite.Require().Fail("service running should be still running")
default:
}
sr.Shutdown()
<-finished
suite.Assert().NoError(<-errCh)
suite.assertStateSequence([]events.ServiceState{
events.StatePreparing,
events.StatePreparing,
events.StateRunning,
events.StateRunning, // one more notification when service is healthy
events.StateFinished,
}, sr)
}
@ -123,11 +119,10 @@ func (suite *ServiceRunnerSuite) TestFullFlowHealthChanges() {
}
sr := system.NewServiceRunner(&m, nil)
finished := make(chan struct{})
errCh := make(chan error)
go func() {
defer close(finished)
sr.Start()
errCh <- sr.Run()
}()
suite.Require().NoError(retry.Constant(time.Minute, retry.WithUnits(10*time.Millisecond)).Retry(func() error {
@ -163,7 +158,7 @@ func (suite *ServiceRunnerSuite) TestFullFlowHealthChanges() {
sr.Shutdown()
<-finished
suite.Assert().NoError(<-errCh)
suite.assertStateSequence([]events.ServiceState{
events.StateWaiting,
@ -173,7 +168,6 @@ func (suite *ServiceRunnerSuite) TestFullFlowHealthChanges() {
events.StateRunning, // initial: healthy
events.StateRunning, // not healthy
events.StateRunning, // once again healthy
events.StateFinished,
}, sr)
}
@ -191,11 +185,10 @@ func (suite *ServiceRunnerSuite) TestWaitingDescriptionChange() {
condition: conditions.WaitForAll(cond1, cond2),
}, nil)
finished := make(chan struct{})
errCh := make(chan error)
go func() {
defer close(finished)
sr.Start()
errCh <- sr.Run()
}()
suite.Require().NoError(retry.Constant(time.Minute, retry.WithUnits(10*time.Millisecond)).Retry(func() error {
@ -208,7 +201,7 @@ func (suite *ServiceRunnerSuite) TestWaitingDescriptionChange() {
}))
select {
case <-finished:
case <-errCh:
suite.Require().Fail("service running should be still running")
default:
}
@ -226,7 +219,7 @@ func (suite *ServiceRunnerSuite) TestWaitingDescriptionChange() {
}))
select {
case <-finished:
case <-errCh:
suite.Require().Fail("service running should be still running")
default:
}
@ -244,7 +237,7 @@ func (suite *ServiceRunnerSuite) TestWaitingDescriptionChange() {
sr.Shutdown()
<-finished
suite.Assert().NoError(<-errCh)
suite.assertStateSequence([]events.ServiceState{
events.StateWaiting,
@ -252,7 +245,6 @@ func (suite *ServiceRunnerSuite) TestWaitingDescriptionChange() {
events.StatePreparing,
events.StatePreparing,
events.StateRunning,
events.StateFinished,
}, sr)
events := sr.GetEventHistory(10000)
@ -265,12 +257,12 @@ func (suite *ServiceRunnerSuite) TestPreStageFail() {
preError: errors.New("pre failed"),
}
sr := system.NewServiceRunner(svc, nil)
sr.Start()
err := sr.Run()
suite.assertStateSequence([]events.ServiceState{
events.StatePreparing,
events.StateFailed,
}, sr)
suite.Assert().EqualError(err, "failed to run pre stage: pre failed")
}
func (suite *ServiceRunnerSuite) TestRunnerStageFail() {
@ -278,13 +270,13 @@ func (suite *ServiceRunnerSuite) TestRunnerStageFail() {
runnerError: errors.New("runner failed"),
}
sr := system.NewServiceRunner(svc, nil)
sr.Start()
err := sr.Run()
suite.assertStateSequence([]events.ServiceState{
events.StatePreparing,
events.StatePreparing,
events.StateFailed,
}, sr)
suite.Assert().EqualError(err, "failed to create runner: runner failed")
}
func (suite *ServiceRunnerSuite) TestRunnerStageSkipped() {
@ -292,13 +284,13 @@ func (suite *ServiceRunnerSuite) TestRunnerStageSkipped() {
nilRunner: true,
}
sr := system.NewServiceRunner(svc, nil)
sr.Start()
err := sr.Run()
suite.assertStateSequence([]events.ServiceState{
events.StatePreparing,
events.StatePreparing,
events.StateSkipped,
}, sr)
suite.Assert().ErrorIs(err, system.ErrSkip)
}
func (suite *ServiceRunnerSuite) TestAbortOnCondition() {
@ -307,11 +299,10 @@ func (suite *ServiceRunnerSuite) TestAbortOnCondition() {
}
sr := system.NewServiceRunner(svc, nil)
finished := make(chan struct{})
errCh := make(chan error)
go func() {
defer close(finished)
sr.Start()
errCh <- sr.Run()
}()
suite.Require().NoError(retry.Constant(time.Minute, retry.WithUnits(10*time.Millisecond)).Retry(func() error {
@ -324,18 +315,17 @@ func (suite *ServiceRunnerSuite) TestAbortOnCondition() {
}))
select {
case <-finished:
case <-errCh:
suite.Require().Fail("service running should be still running")
default:
}
sr.Shutdown()
<-finished
suite.Assert().EqualError(<-errCh, "condition failed: context canceled")
suite.assertStateSequence([]events.ServiceState{
events.StateWaiting,
events.StateFailed,
}, sr)
}
@ -346,23 +336,21 @@ func (suite *ServiceRunnerSuite) TestPostStateFail() {
}
sr := system.NewServiceRunner(svc, nil)
finished := make(chan struct{})
errCh := make(chan error)
go func() {
defer close(finished)
sr.Start()
errCh <- sr.Run()
}()
sr.Shutdown()
<-finished
suite.Assert().NoError(<-errCh)
suite.assertStateSequence([]events.ServiceState{
events.StateWaiting,
events.StatePreparing,
events.StatePreparing,
events.StateRunning,
events.StateFinished,
events.StateFailed,
}, sr)
}
@ -372,22 +360,20 @@ func (suite *ServiceRunnerSuite) TestRunFail() {
svc := &MockService{runner: runner}
sr := system.NewServiceRunner(svc, nil)
finished := make(chan struct{})
errCh := make(chan error)
go func() {
defer close(finished)
sr.Start()
errCh <- sr.Run()
}()
runner.exitCh <- errors.New("run failed")
<-finished
suite.Assert().EqualError(<-errCh, "failed running service: error running service: run failed")
suite.assertStateSequence([]events.ServiceState{
events.StatePreparing,
events.StatePreparing,
events.StateRunning,
events.StateFailed,
}, sr)
}
@ -396,11 +382,10 @@ func (suite *ServiceRunnerSuite) TestFullFlowRestart() {
condition: conditions.None(),
}, nil)
finished := make(chan struct{})
errCh := make(chan error)
go func() {
defer close(finished)
sr.Start()
errCh <- sr.Run()
}()
suite.Require().NoError(retry.Constant(time.Minute, retry.WithUnits(10*time.Millisecond)).Retry(func() error {
@ -413,20 +398,17 @@ func (suite *ServiceRunnerSuite) TestFullFlowRestart() {
}))
select {
case <-finished:
case <-errCh:
suite.Require().Fail("service running should be still running")
default:
}
sr.Shutdown()
<-finished
finished = make(chan struct{})
suite.Assert().NoError(<-errCh)
go func() {
defer close(finished)
sr.Start()
errCh <- sr.Run()
}()
suite.Require().NoError(retry.Constant(time.Minute, retry.WithUnits(10*time.Millisecond)).Retry(func() error {
@ -439,26 +421,24 @@ func (suite *ServiceRunnerSuite) TestFullFlowRestart() {
}))
select {
case <-finished:
case <-errCh:
suite.Require().Fail("service running should be still running")
default:
}
sr.Shutdown()
<-finished
suite.Assert().NoError(<-errCh)
suite.assertStateSequence([]events.ServiceState{
events.StateWaiting,
events.StatePreparing,
events.StatePreparing,
events.StateRunning,
events.StateFinished,
events.StateWaiting,
events.StatePreparing,
events.StatePreparing,
events.StateRunning,
events.StateFinished,
}, sr)
}

View File

@ -6,9 +6,11 @@ package system
import (
"context"
"errors"
"fmt"
"log"
"sort"
"strings"
"sync"
"time"
@ -17,6 +19,7 @@ import (
"github.com/siderolabs/gen/slices"
"github.com/siderolabs/talos/internal/app/machined/pkg/runtime"
"github.com/siderolabs/talos/internal/app/machined/pkg/system/events"
"github.com/siderolabs/talos/pkg/conditions"
)
@ -162,6 +165,7 @@ func (s *singleton) Start(serviceIDs ...string) error {
s.wg.Add(1)
go func(id string, svcrunner *ServiceRunner) {
err := func() error {
defer func() {
s.runningMu.Lock()
delete(s.running, id)
@ -169,7 +173,22 @@ func (s *singleton) Start(serviceIDs ...string) error {
}()
defer s.wg.Done()
svcrunner.Start()
return svcrunner.Run()
}()
switch {
case err == nil:
svcrunner.UpdateState(context.Background(), events.StateFinished, "Service finished successfully")
case errors.Is(err, ErrSkip):
svcrunner.UpdateState(context.Background(), events.StateSkipped, "Service skipped")
default:
msg := err.Error()
if len(msg) > 0 {
msg = strings.ToUpper(msg[:1]) + msg[1:]
}
svcrunner.UpdateState(context.Background(), events.StateFailed, msg)
}
}(id, svcrunner)
}