chore: remove <-errCh where possible in grpc methods

Simplify the code by passing the error directly into the pipe closer.

Signed-off-by: Dmitriy Matrenichev <dmitry.matrenichev@siderolabs.com>
This commit is contained in:
Dmitriy Matrenichev 2023-08-07 20:46:55 +03:00
parent e0f383598e
commit c4a1ca8d61
No known key found for this signature in database
GPG Key ID: D3363CF894E68892
14 changed files with 64 additions and 201 deletions

View File

@ -10,7 +10,6 @@ import (
"io"
"os"
"path/filepath"
"sync"
"github.com/spf13/cobra"
@ -48,23 +47,11 @@ captures ownership and permission bits.`,
return err
}
r, errCh, err := c.Copy(ctx, args[0])
r, err := c.Copy(ctx, args[0])
if err != nil {
return fmt.Errorf("error copying: %w", err)
}
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
for err := range errCh {
fmt.Fprintln(os.Stderr, err.Error())
}
}()
defer wg.Wait()
localPath := args[1]
if localPath == "-" {

View File

@ -11,7 +11,6 @@ import (
"io"
"os"
"strings"
"sync"
"text/tabwriter"
"github.com/dustin/go-humanize"
@ -350,25 +349,13 @@ var etcdSnapshotCmd = &cobra.Command{
defer dest.Close() //nolint:errcheck
r, errCh, err := c.EtcdSnapshot(ctx, &machine.EtcdSnapshotRequest{})
r, err := c.EtcdSnapshot(ctx, &machine.EtcdSnapshotRequest{})
if err != nil {
return fmt.Errorf("error reading file: %w", err)
}
defer r.Close() //nolint:errcheck
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
for err := range errCh {
fmt.Fprintln(os.Stderr, err.Error())
}
}()
defer wg.Wait()
size, err := io.Copy(dest, r)
if err != nil {
return fmt.Errorf("error reading: %w", err)

View File

@ -11,7 +11,6 @@ import (
"os"
"path/filepath"
"strings"
"sync"
"github.com/mattn/go-isatty"
"github.com/siderolabs/go-kubeconfig"
@ -92,22 +91,11 @@ Otherwise kubeconfig will be written to PWD or [local-path] if specified.`,
}
}
r, errCh, err := c.KubeconfigRaw(ctx)
r, err := c.KubeconfigRaw(ctx)
if err != nil {
return fmt.Errorf("error copying: %w", err)
}
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
for err := range errCh {
fmt.Fprintln(os.Stderr, err.Error())
}
}()
defer wg.Wait()
defer r.Close() //nolint:errcheck
data, err := helpers.ExtractFileFromTarGz("kubeconfig", r)

View File

@ -11,7 +11,6 @@ import (
"io"
"os"
"strings"
"sync"
"time"
"github.com/gopacket/gopacket"
@ -90,27 +89,11 @@ e.g. by excluding packets with the port 50000.
return err
}
r, errCh, err := c.PacketCapture(ctx, &req)
r, err := c.PacketCapture(ctx, &req)
if err != nil {
return fmt.Errorf("error copying: %w", err)
}
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
for err := range errCh {
if client.StatusCode(err) == codes.DeadlineExceeded {
continue
}
fmt.Fprintln(os.Stderr, err.Error())
}
}()
defer wg.Wait()
if pcapCmdFlags.output == "" {
return dumpPackets(r)
}

View File

@ -11,7 +11,6 @@ import (
"os"
"github.com/spf13/cobra"
"golang.org/x/sync/errgroup"
"github.com/siderolabs/talos/cmd/talosctl/pkg/talos/helpers"
"github.com/siderolabs/talos/pkg/machinery/client"
@ -37,37 +36,19 @@ var readCmd = &cobra.Command{
return err
}
r, errCh, err := c.Read(ctx, args[0])
r, err := c.Read(ctx, args[0])
if err != nil {
return fmt.Errorf("error reading file: %w", err)
}
defer r.Close() //nolint:errcheck
var eg errgroup.Group
eg.Go(func() error {
var errors error
for err := range errCh {
if err != nil {
errors = helpers.AppendErrors(errors, err)
}
}
return errors
})
_, err = io.Copy(os.Stdout, r)
if err != nil {
return fmt.Errorf("error reading: %w", err)
}
if err = r.Close(); err != nil {
return err
}
return eg.Wait()
return r.Close()
})
},
}

View File

@ -320,7 +320,7 @@ func (a *Tracker) processNodeUpdate(update nodeUpdate) reporter.Update {
// getBootID reads the boot ID from the node.
// It returns the node as the first return value and the boot ID as the second.
func getBootID(ctx context.Context, c *client.Client) (string, error) {
reader, errCh, err := c.Read(ctx, "/proc/sys/kernel/random/boot_id")
reader, err := c.Read(ctx, "/proc/sys/kernel/random/boot_id")
if err != nil {
return "", err
}
@ -334,11 +334,5 @@ func getBootID(ctx context.Context, c *client.Client) (string, error) {
bootID := strings.TrimSpace(string(body))
for err = range errCh {
if err != nil {
return "", err
}
}
return bootID, reader.Close()
}

View File

@ -128,7 +128,7 @@ func (suite *CGroupsSuite) TestCGroupsVersion() {
//nolint:gocyclo
func (suite *CGroupsSuite) readCmdline(ctx context.Context) (string, error) {
reader, errCh, err := suite.Client.Read(ctx, "/proc/cmdline")
reader, err := suite.Client.Read(ctx, "/proc/cmdline")
if err != nil {
return "", err
}
@ -147,12 +147,6 @@ func (suite *CGroupsSuite) readCmdline(ctx context.Context) (string, error) {
return "", err
}
for err = range errCh {
if err != nil {
return "", err
}
}
return bootID, reader.Close()
}

View File

@ -51,14 +51,12 @@ func (suite *DmesgSuite) TestNodeHasDmesg() {
)
suite.Require().NoError(err)
logReader, errCh, err := client.ReadStream(dmesgStream)
logReader, err := client.ReadStream(dmesgStream)
suite.Require().NoError(err)
n, err := io.Copy(io.Discard, logReader)
suite.Require().NoError(err)
suite.Require().NoError(<-errCh)
// dmesg shouldn't be empty
suite.Require().Greater(n, int64(1024))
}

View File

@ -11,7 +11,6 @@ import (
"context"
"fmt"
"io"
"sync"
"testing"
"time"
@ -184,27 +183,13 @@ func (suite *EtcdRecoverSuite) TestSnapshotRecover() {
func (suite *EtcdRecoverSuite) snapshotEtcd(snapshotNode string, dest io.Writer) error {
ctx := client.WithNodes(suite.ctx, snapshotNode)
r, errCh, err := suite.Client.EtcdSnapshot(ctx, &machineapi.EtcdSnapshotRequest{})
r, err := suite.Client.EtcdSnapshot(ctx, &machineapi.EtcdSnapshotRequest{})
if err != nil {
return fmt.Errorf("error reading snapshot: %w", err)
}
defer r.Close() //nolint:errcheck
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
for err := range errCh {
suite.T().Logf("read error: %s", err)
}
}()
defer wg.Wait()
_, err = io.Copy(dest, r)
return err

View File

@ -69,15 +69,13 @@ func (suite *LogsSuite) TestServicesHaveLogs() {
)
suite.Require().NoError(err)
logReader, errCh, err := client.ReadStream(logsStream)
logReader, err := client.ReadStream(logsStream)
suite.Require().NoError(err)
n, err := io.Copy(io.Discard, logReader)
suite.Require().NoError(err)
logsSize += n
suite.Require().NoError(<-errCh)
}
// overall logs shouldn't be empty
@ -104,7 +102,7 @@ func (suite *LogsSuite) TestTail() {
)
suite.Require().NoError(err)
logReader, errCh, err := client.ReadStream(logsStream)
logReader, err := client.ReadStream(logsStream)
suite.Require().NoError(err)
scanner := bufio.NewScanner(logReader)
@ -116,8 +114,6 @@ func (suite *LogsSuite) TestTail() {
suite.Require().NoError(scanner.Err())
suite.Require().NoError(<-errCh)
suite.Assert().EqualValues(tailLines, lines)
}
}

View File

@ -15,7 +15,6 @@ import (
"io"
"math/rand"
"strings"
"sync"
"time"
"github.com/siderolabs/go-retry/retry"
@ -200,7 +199,7 @@ func (apiSuite *APISuite) ReadBootID(ctx context.Context) (string, error) {
reqCtx, reqCtxCancel := context.WithTimeout(ctx, 10*time.Second)
defer reqCtxCancel()
reader, errCh, err := apiSuite.Client.Read(reqCtx, "/proc/sys/kernel/random/boot_id")
reader, err := apiSuite.Client.Read(reqCtx, "/proc/sys/kernel/random/boot_id")
if err != nil {
return "", err
}
@ -219,12 +218,6 @@ func (apiSuite *APISuite) ReadBootID(ctx context.Context) (string, error) {
return "", err
}
for err = range errCh {
if err != nil {
return "", err
}
}
return bootID, reader.Close()
}
@ -411,7 +404,7 @@ func (apiSuite *APISuite) HashKubeletCert(ctx context.Context, node string) (str
reqCtx = client.WithNodes(reqCtx, node)
reader, errCh, err := apiSuite.Client.Read(reqCtx, "/var/lib/kubelet/pki/kubelet-client-current.pem")
reader, err := apiSuite.Client.Read(reqCtx, "/var/lib/kubelet/pki/kubelet-client-current.pem")
if err != nil {
return "", err
}
@ -425,12 +418,6 @@ func (apiSuite *APISuite) HashKubeletCert(ctx context.Context, node string) (str
return "", err
}
for err = range errCh {
if err != nil {
return "", err
}
}
return hex.EncodeToString(hash.Sum(nil)), reader.Close()
}
@ -439,13 +426,13 @@ func (apiSuite *APISuite) ReadConfigFromNode(nodeCtx context.Context) (config.Pr
// Load the current node machine config
cfgData := new(bytes.Buffer)
reader, errCh, err := apiSuite.Client.Read(nodeCtx, constants.ConfigPath)
reader, err := apiSuite.Client.Read(nodeCtx, constants.ConfigPath)
if err != nil {
return nil, fmt.Errorf("error creating reader: %w", err)
}
defer reader.Close() //nolint:errcheck
if err = copyFromReaderWithErrChan(cfgData, reader, errCh); err != nil {
if _, err = io.Copy(cfgData, reader); err != nil {
return nil, fmt.Errorf("error reading: %w", err)
}
@ -457,34 +444,6 @@ func (apiSuite *APISuite) ReadConfigFromNode(nodeCtx context.Context) (config.Pr
return provider, nil
}
func copyFromReaderWithErrChan(out io.Writer, in io.Reader, errCh <-chan error) (err error) {
var wg sync.WaitGroup
var chanErr error
wg.Add(1)
go func() {
defer wg.Done()
// StreamReader is only singly-buffered, so we need to process any errors as we get them.
for chanErr = range errCh { //nolint:revive
}
}()
defer func() {
wg.Wait()
if err == nil {
err = chanErr
}
}()
_, err = io.Copy(out, in)
return err
}
// TearDownSuite closes Talos API client.
func (apiSuite *APISuite) TearDownSuite() {
if apiSuite.Client != nil {

View File

@ -75,7 +75,7 @@ func (s *APICrashDumper) CrashDump(ctx context.Context, out io.Writer) {
continue
}
r, errCh, err := client.ReadStream(stream)
r, err := client.ReadStream(stream)
if err != nil {
fmt.Fprintf(out, "error getting service logs for %s: %s\n", svc.Id, err)
@ -89,11 +89,6 @@ func (s *APICrashDumper) CrashDump(ctx context.Context, out io.Writer) {
fmt.Fprintf(out, "error streaming service logs: %s\n", err)
}
err = <-errCh
if err != nil {
fmt.Fprintf(out, "error streaming service logs: %s\n", err)
}
r.Close() //nolint:errcheck
}
}

View File

@ -627,7 +627,7 @@ func mounts(ctx context.Context, options *BundleOptions) ([]byte, error) {
func devices(ctx context.Context, options *BundleOptions) ([]byte, error) {
options.Log("reading devices")
r, _, err := options.Client.Read(ctx, "/proc/bus/pci/devices")
r, err := options.Client.Read(ctx, "/proc/bus/pci/devices")
if err != nil {
return nil, err
}

View File

@ -178,10 +178,13 @@ func (c *Client) Close() error {
}
// KubeconfigRaw returns K8s client config (kubeconfig).
func (c *Client) KubeconfigRaw(ctx context.Context) (io.ReadCloser, <-chan error, error) {
//
// This method doesn't support multiplexing of the result:
// * either client.WithNodes is not used, or it contains a single node in the list.
func (c *Client) KubeconfigRaw(ctx context.Context) (io.ReadCloser, error) {
stream, err := c.MachineClient.Kubeconfig(ctx, &emptypb.Empty{})
if err != nil {
return nil, nil, err
return nil, err
}
return ReadStream(stream)
@ -225,20 +228,12 @@ func (c *Client) extractKubeconfig(r io.ReadCloser) ([]byte, error) {
// Kubeconfig returns K8s client config (kubeconfig).
func (c *Client) Kubeconfig(ctx context.Context) ([]byte, error) {
r, errCh, err := c.KubeconfigRaw(ctx)
r, err := c.KubeconfigRaw(ctx)
if err != nil {
return nil, err
}
kubeconfig, err := c.extractKubeconfig(r)
if err2 := <-errCh; err2 != nil {
// prefer errCh (error from server) as if server failed,
// extractKubeconfig failed as well, but server failure is more descriptive
return nil, err2
}
return kubeconfig, err
return c.extractKubeconfig(r)
}
// ApplyConfiguration implements proto.MachineServiceClient interface.
@ -532,12 +527,15 @@ func (c *Client) DiskUsage(ctx context.Context, req *machineapi.DiskUsageRequest
}
// Copy implements the proto.MachineServiceClient interface.
func (c *Client) Copy(ctx context.Context, rootPath string) (io.ReadCloser, <-chan error, error) {
//
// This method doesn't support multiplexing of the result:
// * either client.WithNodes is not used, or it contains a single node in the list.
func (c *Client) Copy(ctx context.Context, rootPath string) (io.ReadCloser, error) {
stream, err := c.MachineClient.Copy(ctx, &machineapi.CopyRequest{
RootPath: rootPath,
})
if err != nil {
return nil, nil, err
return nil, err
}
return ReadStream(stream)
@ -762,10 +760,13 @@ func (c *Client) TimeCheck(ctx context.Context, server string, callOptions ...gr
}
// Read reads a file.
func (c *Client) Read(ctx context.Context, path string) (io.ReadCloser, <-chan error, error) {
//
// This method doesn't support multiplexing of the result:
// * either client.WithNodes is not used, or it contains a single node in the list.
func (c *Client) Read(ctx context.Context, path string) (io.ReadCloser, error) {
stream, err := c.MachineClient.Read(ctx, &machineapi.ReadRequest{Path: path})
if err != nil {
return nil, nil, err
return nil, err
}
return ReadStream(stream)
@ -837,10 +838,13 @@ func (c *Client) EtcdMemberList(ctx context.Context, req *machineapi.EtcdMemberL
}
// EtcdSnapshot receives a snapshot of the etcd from the node.
func (c *Client) EtcdSnapshot(ctx context.Context, req *machineapi.EtcdSnapshotRequest, callOptions ...grpc.CallOption) (io.ReadCloser, <-chan error, error) {
//
// This method doesn't support multiplexing of the result:
// * either client.WithNodes is not used, or it contains a single node in the list.
func (c *Client) EtcdSnapshot(ctx context.Context, req *machineapi.EtcdSnapshotRequest, callOptions ...grpc.CallOption) (io.ReadCloser, error) {
stream, err := c.MachineClient.EtcdSnapshot(ctx, req, callOptions...)
if err != nil {
return nil, nil, err
return nil, err
}
return ReadStream(stream)
@ -960,10 +964,13 @@ func (c *Client) GenerateClientConfiguration(ctx context.Context, req *machineap
}
// PacketCapture implements the proto.MachineServiceClient interface.
func (c *Client) PacketCapture(ctx context.Context, req *machineapi.PacketCaptureRequest) (io.ReadCloser, <-chan error, error) {
//
// This method doesn't support multiplexing of the result:
// * either client.WithNodes is not used, or it contains a single node in the list.
func (c *Client) PacketCapture(ctx context.Context, req *machineapi.PacketCaptureRequest) (io.ReadCloser, error) {
stream, err := c.MachineClient.PacketCapture(ctx, req)
if err != nil {
return nil, nil, err
return nil, err
}
return ReadStream(stream)
@ -978,19 +985,17 @@ type MachineStream interface {
// ReadStream converts grpc stream into io.Reader.
//
//nolint:gocyclo
func ReadStream(stream MachineStream) (io.ReadCloser, <-chan error, error) {
errCh := make(chan error, 1)
func ReadStream(stream MachineStream) (io.ReadCloser, error) {
pr, pw := io.Pipe()
go func() {
//nolint:errcheck
defer pw.Close()
defer close(errCh)
for {
data, err := stream.Recv()
if err != nil {
if err == io.EOF || StatusCode(err) == codes.Canceled || StatusCode(err) == codes.DeadlineExceeded {
if errors.Is(err, io.EOF) || StatusCode(err) == codes.Canceled || StatusCode(err) == codes.DeadlineExceeded {
return
}
//nolint:errcheck
@ -1007,16 +1012,27 @@ func ReadStream(stream MachineStream) (io.ReadCloser, <-chan error, error) {
}
if data.Metadata != nil && data.Metadata.Error != "" {
if data.Metadata.Status != nil {
errCh <- status.FromProto(data.Metadata.Status).Err()
} else {
errCh <- errors.New(data.Metadata.Error)
}
pw.CloseWithError(metaToErr(data.Metadata)) //nolint:errcheck
return
}
}
}()
return pr, errCh, stream.CloseSend()
return pr, stream.CloseSend()
}
// metaToErr converts the error information carried in stream metadata into a Go error.
//
// When no structured gRPC status is present, the plain error string is used.
// Canceled and DeadlineExceeded statuses are treated as benign stream
// termination and map to nil; any other status is converted back to its
// gRPC error form.
func metaToErr(md *common.Metadata) error {
	if md.Status == nil {
		return errors.New(md.Error)
	}

	switch codes.Code(md.Status.Code) {
	case codes.Canceled, codes.DeadlineExceeded:
		// expected stream shutdown — not an error from the caller's perspective
		return nil
	default:
		return status.FromProto(md.Status).Err()
	}
}
// Netstat lists the network sockets on the current node.