refactor: make apid stop gracefully and be stopped late
This fixes the apid and machined shutdown sequences to perform a graceful stop of the gRPC server with a timeout. The sequences are also restructured to stop apid/machined as late as possible, allowing access to the node while a long sequence (e.g. upgrade or reset) is running. Signed-off-by: Andrey Smirnov <andrey.smirnov@talos-systems.com>
This commit is contained in:
parent
0cdf222431
commit
2e790526f7
@ -7,8 +7,12 @@ package apid
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os/signal"
|
||||
"regexp"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/cosi-project/runtime/api/v1alpha1"
|
||||
"github.com/cosi-project/runtime/pkg/state"
|
||||
@ -46,21 +50,30 @@ func runDebugServer(ctx context.Context) {
|
||||
|
||||
// Main is the entrypoint of apid.
|
||||
func Main() {
|
||||
if err := apidMain(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func apidMain() error {
|
||||
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
|
||||
defer cancel()
|
||||
|
||||
log.SetFlags(log.Lshortfile | log.Ldate | log.Lmicroseconds | log.Ltime)
|
||||
|
||||
rbacEnabled = flag.Bool("enable-rbac", false, "enable RBAC for Talos API")
|
||||
|
||||
flag.Parse()
|
||||
|
||||
go runDebugServer(context.TODO())
|
||||
go runDebugServer(ctx)
|
||||
|
||||
if err := startup.RandSeed(); err != nil {
|
||||
log.Fatalf("failed to seed RNG: %v", err)
|
||||
return fmt.Errorf("failed to seed RNG: %w", err)
|
||||
}
|
||||
|
||||
runtimeConn, err := grpc.Dial("unix://"+constants.APIRuntimeSocketPath, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
if err != nil {
|
||||
log.Fatalf("failed to dial runtime connection: %v", err)
|
||||
return fmt.Errorf("failed to dial runtime connection: %w", err)
|
||||
}
|
||||
|
||||
stateClient := v1alpha1.NewStateClient(runtimeConn)
|
||||
@ -68,17 +81,17 @@ func Main() {
|
||||
|
||||
tlsConfig, err := provider.NewTLSConfig(resources)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create remote certificate provider: %+v", err)
|
||||
return fmt.Errorf("failed to create remote certificate provider: %w", err)
|
||||
}
|
||||
|
||||
serverTLSConfig, err := tlsConfig.ServerConfig()
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create OS-level TLS configuration: %v", err)
|
||||
return fmt.Errorf("failed to create OS-level TLS configuration: %w", err)
|
||||
}
|
||||
|
||||
clientTLSConfig, err := tlsConfig.ClientConfig()
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create client TLS config: %v", err)
|
||||
return fmt.Errorf("failed to create client TLS config: %w", err)
|
||||
}
|
||||
|
||||
backendFactory := apidbackend.NewAPIDFactory(clientTLSConfig)
|
||||
@ -109,9 +122,22 @@ func Main() {
|
||||
// register future pattern: method should have suffix "Stream"
|
||||
router.RegisterStreamedRegex("Stream$")
|
||||
|
||||
var errGroup errgroup.Group
|
||||
networkListener, err := factory.NewListener(
|
||||
factory.Port(constants.ApidPort),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating listner: %w", err)
|
||||
}
|
||||
|
||||
errGroup.Go(func() error {
|
||||
socketListener, err := factory.NewListener(
|
||||
factory.Network("unix"),
|
||||
factory.SocketPath(constants.APISocketPath),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating listner: %w", err)
|
||||
}
|
||||
|
||||
networkServer := func() *grpc.Server {
|
||||
mode := authz.Disabled
|
||||
if *rbacEnabled {
|
||||
mode = authz.Enabled
|
||||
@ -122,9 +148,8 @@ func Main() {
|
||||
Logger: log.New(log.Writer(), "apid/authz/injector/http ", log.Flags()).Printf,
|
||||
}
|
||||
|
||||
return factory.ListenAndServe(
|
||||
return factory.NewServer(
|
||||
router,
|
||||
factory.Port(constants.ApidPort),
|
||||
factory.WithDefaultLog(),
|
||||
factory.ServerOptions(
|
||||
grpc.Creds(
|
||||
@ -140,18 +165,16 @@ func Main() {
|
||||
factory.WithUnaryInterceptor(injector.UnaryInterceptor()),
|
||||
factory.WithStreamInterceptor(injector.StreamInterceptor()),
|
||||
)
|
||||
})
|
||||
}()
|
||||
|
||||
errGroup.Go(func() error {
|
||||
socketServer := func() *grpc.Server {
|
||||
injector := &authz.Injector{
|
||||
Mode: authz.MetadataOnly,
|
||||
Logger: log.New(log.Writer(), "apid/authz/injector/unix ", log.Flags()).Printf,
|
||||
}
|
||||
|
||||
return factory.ListenAndServe(
|
||||
return factory.NewServer(
|
||||
router,
|
||||
factory.Network("unix"),
|
||||
factory.SocketPath(constants.APISocketPath),
|
||||
factory.WithDefaultLog(),
|
||||
factory.ServerOptions(
|
||||
grpc.CustomCodec(proxy.Codec()), //nolint:staticcheck
|
||||
@ -164,9 +187,29 @@ func Main() {
|
||||
factory.WithUnaryInterceptor(injector.UnaryInterceptor()),
|
||||
factory.WithStreamInterceptor(injector.StreamInterceptor()),
|
||||
)
|
||||
}()
|
||||
|
||||
errGroup, ctx := errgroup.WithContext(ctx)
|
||||
|
||||
errGroup.Go(func() error {
|
||||
return networkServer.Serve(networkListener)
|
||||
})
|
||||
|
||||
if err := errGroup.Wait(); err != nil {
|
||||
log.Fatalf("listen: %v", err)
|
||||
}
|
||||
errGroup.Go(func() error {
|
||||
return socketServer.Serve(socketListener)
|
||||
})
|
||||
|
||||
errGroup.Go(func() error {
|
||||
<-ctx.Done()
|
||||
|
||||
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer shutdownCancel()
|
||||
|
||||
factory.ServerGracefulStop(networkServer, shutdownCtx)
|
||||
factory.ServerGracefulStop(socketServer, shutdownCtx)
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return errGroup.Wait()
|
||||
}
|
||||
|
@ -400,7 +400,7 @@ func (*Sequencer) Upgrade(r runtime.Runtime, in *machineapi.UpgradeRequest) []ru
|
||||
LeaveEtcd,
|
||||
).Append(
|
||||
"stopServices",
|
||||
StopServicesForUpgrade,
|
||||
StopServicesEphemeral,
|
||||
).Append(
|
||||
"unmountUser",
|
||||
UnmountUserDisks,
|
||||
@ -421,9 +421,6 @@ func (*Sequencer) Upgrade(r runtime.Runtime, in *machineapi.UpgradeRequest) []ru
|
||||
).Append(
|
||||
"upgrade",
|
||||
Upgrade,
|
||||
).Append(
|
||||
"stopEverything",
|
||||
StopAllServices,
|
||||
).Append(
|
||||
"mountBoot",
|
||||
MountBootPartition,
|
||||
@ -433,6 +430,9 @@ func (*Sequencer) Upgrade(r runtime.Runtime, in *machineapi.UpgradeRequest) []ru
|
||||
).Append(
|
||||
"unmountBoot",
|
||||
UnmountBootPartition,
|
||||
).Append(
|
||||
"stopEverything",
|
||||
StopAllServices,
|
||||
).Append(
|
||||
"reboot",
|
||||
Reboot,
|
||||
@ -453,8 +453,8 @@ func stopAllPhaselist(r runtime.Runtime, enableKexec bool) PhaseList {
|
||||
)
|
||||
default:
|
||||
phases = phases.Append(
|
||||
"stopEverything",
|
||||
StopAllServices,
|
||||
"stopServices",
|
||||
StopServicesEphemeral,
|
||||
).Append(
|
||||
"unmountUser",
|
||||
UnmountUserDisks,
|
||||
@ -481,6 +481,9 @@ func stopAllPhaselist(r runtime.Runtime, enableKexec bool) PhaseList {
|
||||
enableKexec,
|
||||
"unmountBoot",
|
||||
UnmountBootPartition,
|
||||
).Append(
|
||||
"stopEverything",
|
||||
StopAllServices,
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -796,11 +796,11 @@ func StartAllServices(seq runtime.Sequence, data interface{}) (runtime.TaskExecu
|
||||
}, "startAllServices"
|
||||
}
|
||||
|
||||
// StopServicesForUpgrade represents the StopServicesForUpgrade task.
|
||||
func StopServicesForUpgrade(seq runtime.Sequence, data interface{}) (runtime.TaskExecutionFunc, string) {
|
||||
// StopServicesEphemeral represents the task which stops the "cri", "udevd" and "trustd" services together with the services depending on them.
|
||||
func StopServicesEphemeral(seq runtime.Sequence, data interface{}) (runtime.TaskExecutionFunc, string) {
|
||||
return func(ctx context.Context, logger *log.Logger, r runtime.Runtime) (err error) {
|
||||
// stopping 'cri' service stops everything which depends on it (kubelet, etcd, ...)
|
||||
return system.Services(nil).StopWithRevDepenencies(ctx, "cri", "udevd")
|
||||
return system.Services(nil).StopWithRevDepenencies(ctx, "cri", "udevd", "trustd")
|
||||
}, "stopServicesForUpgrade"
|
||||
}
|
||||
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
v1alpha1server "github.com/talos-systems/talos/internal/app/machined/internal/server/v1alpha1"
|
||||
"github.com/talos-systems/talos/internal/app/machined/pkg/runtime"
|
||||
@ -134,8 +135,6 @@ func (s *machinedService) Main(ctx context.Context, r runtime.Runtime, logWriter
|
||||
return err
|
||||
}
|
||||
|
||||
defer server.Stop()
|
||||
|
||||
go func() {
|
||||
//nolint:errcheck
|
||||
server.Serve(listener)
|
||||
@ -143,6 +142,11 @@ func (s *machinedService) Main(ctx context.Context, r runtime.Runtime, logWriter
|
||||
|
||||
<-ctx.Done()
|
||||
|
||||
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer shutdownCancel()
|
||||
|
||||
factory.ServerGracefulStop(server, shutdownCtx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -5,6 +5,7 @@
|
||||
package factory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
@ -257,3 +258,22 @@ func ListenAndServe(r Registrator, setters ...Option) (err error) {
|
||||
|
||||
return server.Serve(listener)
|
||||
}
|
||||
|
||||
// ServerGracefulStop the server with a timeout.
|
||||
//
|
||||
// Core gRPC doesn't support timeouts.
|
||||
func ServerGracefulStop(server *grpc.Server, shutdownCtx context.Context) { //nolint:revive
|
||||
stopped := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
server.GracefulStop()
|
||||
close(stopped)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-shutdownCtx.Done():
|
||||
server.Stop()
|
||||
case <-stopped:
|
||||
server.Stop()
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user