refactor: make apid stop gracefully and be stopped late

This fixes apid and machined shutdown sequences to do graceful stop of
gRPC server with timeout.

Also sequences are restructured to stop apid/machined as late as
possible allowing access to the node while the long sequence is running
(e.g. upgrade or reset).

Signed-off-by: Andrey Smirnov <andrey.smirnov@talos-systems.com>
This commit is contained in:
Andrey Smirnov 2022-07-28 21:17:57 +04:00
parent 0cdf222431
commit 2e790526f7
No known key found for this signature in database
GPG Key ID: 7B26396447AB6DFD
5 changed files with 99 additions and 29 deletions

View File

@ -7,8 +7,12 @@ package apid
import (
"context"
"flag"
"fmt"
"log"
"os/signal"
"regexp"
"syscall"
"time"
"github.com/cosi-project/runtime/api/v1alpha1"
"github.com/cosi-project/runtime/pkg/state"
@ -46,21 +50,30 @@ func runDebugServer(ctx context.Context) {
// Main is the entrypoint of apid.
func Main() {
if err := apidMain(); err != nil {
log.Fatal(err)
}
}
func apidMain() error {
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
defer cancel()
log.SetFlags(log.Lshortfile | log.Ldate | log.Lmicroseconds | log.Ltime)
rbacEnabled = flag.Bool("enable-rbac", false, "enable RBAC for Talos API")
flag.Parse()
go runDebugServer(context.TODO())
go runDebugServer(ctx)
if err := startup.RandSeed(); err != nil {
log.Fatalf("failed to seed RNG: %v", err)
return fmt.Errorf("failed to seed RNG: %w", err)
}
runtimeConn, err := grpc.Dial("unix://"+constants.APIRuntimeSocketPath, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
log.Fatalf("failed to dial runtime connection: %v", err)
return fmt.Errorf("failed to dial runtime connection: %w", err)
}
stateClient := v1alpha1.NewStateClient(runtimeConn)
@ -68,17 +81,17 @@ func Main() {
tlsConfig, err := provider.NewTLSConfig(resources)
if err != nil {
log.Fatalf("failed to create remote certificate provider: %+v", err)
return fmt.Errorf("failed to create remote certificate provider: %w", err)
}
serverTLSConfig, err := tlsConfig.ServerConfig()
if err != nil {
log.Fatalf("failed to create OS-level TLS configuration: %v", err)
return fmt.Errorf("failed to create OS-level TLS configuration: %w", err)
}
clientTLSConfig, err := tlsConfig.ClientConfig()
if err != nil {
log.Fatalf("failed to create client TLS config: %v", err)
return fmt.Errorf("failed to create client TLS config: %w", err)
}
backendFactory := apidbackend.NewAPIDFactory(clientTLSConfig)
@ -109,9 +122,22 @@ func Main() {
// register future pattern: method should have suffix "Stream"
router.RegisterStreamedRegex("Stream$")
var errGroup errgroup.Group
networkListener, err := factory.NewListener(
factory.Port(constants.ApidPort),
)
if err != nil {
return fmt.Errorf("error creating listner: %w", err)
}
errGroup.Go(func() error {
socketListener, err := factory.NewListener(
factory.Network("unix"),
factory.SocketPath(constants.APISocketPath),
)
if err != nil {
return fmt.Errorf("error creating listner: %w", err)
}
networkServer := func() *grpc.Server {
mode := authz.Disabled
if *rbacEnabled {
mode = authz.Enabled
@ -122,9 +148,8 @@ func Main() {
Logger: log.New(log.Writer(), "apid/authz/injector/http ", log.Flags()).Printf,
}
return factory.ListenAndServe(
return factory.NewServer(
router,
factory.Port(constants.ApidPort),
factory.WithDefaultLog(),
factory.ServerOptions(
grpc.Creds(
@ -140,18 +165,16 @@ func Main() {
factory.WithUnaryInterceptor(injector.UnaryInterceptor()),
factory.WithStreamInterceptor(injector.StreamInterceptor()),
)
})
}()
errGroup.Go(func() error {
socketServer := func() *grpc.Server {
injector := &authz.Injector{
Mode: authz.MetadataOnly,
Logger: log.New(log.Writer(), "apid/authz/injector/unix ", log.Flags()).Printf,
}
return factory.ListenAndServe(
return factory.NewServer(
router,
factory.Network("unix"),
factory.SocketPath(constants.APISocketPath),
factory.WithDefaultLog(),
factory.ServerOptions(
grpc.CustomCodec(proxy.Codec()), //nolint:staticcheck
@ -164,9 +187,29 @@ func Main() {
factory.WithUnaryInterceptor(injector.UnaryInterceptor()),
factory.WithStreamInterceptor(injector.StreamInterceptor()),
)
}()
errGroup, ctx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
return networkServer.Serve(networkListener)
})
if err := errGroup.Wait(); err != nil {
log.Fatalf("listen: %v", err)
}
errGroup.Go(func() error {
return socketServer.Serve(socketListener)
})
errGroup.Go(func() error {
<-ctx.Done()
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second)
defer shutdownCancel()
factory.ServerGracefulStop(networkServer, shutdownCtx)
factory.ServerGracefulStop(socketServer, shutdownCtx)
return nil
})
return errGroup.Wait()
}

View File

@ -400,7 +400,7 @@ func (*Sequencer) Upgrade(r runtime.Runtime, in *machineapi.UpgradeRequest) []ru
LeaveEtcd,
).Append(
"stopServices",
StopServicesForUpgrade,
StopServicesEphemeral,
).Append(
"unmountUser",
UnmountUserDisks,
@ -421,9 +421,6 @@ func (*Sequencer) Upgrade(r runtime.Runtime, in *machineapi.UpgradeRequest) []ru
).Append(
"upgrade",
Upgrade,
).Append(
"stopEverything",
StopAllServices,
).Append(
"mountBoot",
MountBootPartition,
@ -433,6 +430,9 @@ func (*Sequencer) Upgrade(r runtime.Runtime, in *machineapi.UpgradeRequest) []ru
).Append(
"unmountBoot",
UnmountBootPartition,
).Append(
"stopEverything",
StopAllServices,
).Append(
"reboot",
Reboot,
@ -453,8 +453,8 @@ func stopAllPhaselist(r runtime.Runtime, enableKexec bool) PhaseList {
)
default:
phases = phases.Append(
"stopEverything",
StopAllServices,
"stopServices",
StopServicesEphemeral,
).Append(
"unmountUser",
UnmountUserDisks,
@ -481,6 +481,9 @@ func stopAllPhaselist(r runtime.Runtime, enableKexec bool) PhaseList {
enableKexec,
"unmountBoot",
UnmountBootPartition,
).Append(
"stopEverything",
StopAllServices,
)
}

View File

@ -796,11 +796,11 @@ func StartAllServices(seq runtime.Sequence, data interface{}) (runtime.TaskExecu
}, "startAllServices"
}
// StopServicesForUpgrade represents the StopServicesForUpgrade task.
func StopServicesForUpgrade(seq runtime.Sequence, data interface{}) (runtime.TaskExecutionFunc, string) {
// StopServicesEphemeral represents the StopServicesEphemeral task.
func StopServicesEphemeral(seq runtime.Sequence, data interface{}) (runtime.TaskExecutionFunc, string) {
return func(ctx context.Context, logger *log.Logger, r runtime.Runtime) (err error) {
// stopping 'cri' service stops everything which depends on it (kubelet, etcd, ...)
return system.Services(nil).StopWithRevDepenencies(ctx, "cri", "udevd")
return system.Services(nil).StopWithRevDepenencies(ctx, "cri", "udevd", "trustd")
}, "stopServicesForUpgrade"
}

View File

@ -10,6 +10,7 @@ import (
"log"
"os"
"path/filepath"
"time"
v1alpha1server "github.com/talos-systems/talos/internal/app/machined/internal/server/v1alpha1"
"github.com/talos-systems/talos/internal/app/machined/pkg/runtime"
@ -134,8 +135,6 @@ func (s *machinedService) Main(ctx context.Context, r runtime.Runtime, logWriter
return err
}
defer server.Stop()
go func() {
//nolint:errcheck
server.Serve(listener)
@ -143,6 +142,11 @@ func (s *machinedService) Main(ctx context.Context, r runtime.Runtime, logWriter
<-ctx.Done()
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second)
defer shutdownCancel()
factory.ServerGracefulStop(server, shutdownCtx)
return nil
}

View File

@ -5,6 +5,7 @@
package factory
import (
"context"
"crypto/tls"
"errors"
"fmt"
@ -257,3 +258,22 @@ func ListenAndServe(r Registrator, setters ...Option) (err error) {
return server.Serve(listener)
}
// ServerGracefulStop the server with a timeout.
//
// Core gRPC doesn't support timeouts.
func ServerGracefulStop(server *grpc.Server, shutdownCtx context.Context) { //nolint:revive
stopped := make(chan struct{})
go func() {
server.GracefulStop()
close(stopped)
}()
select {
case <-shutdownCtx.Done():
server.Stop()
case <-stopped:
server.Stop()
}
}