From 051e77e33946f932e6879bf4df2773e3c998c3a7 Mon Sep 17 00:00:00 2001
From: Arseniy Krasnov <AVKrasnov@sberdevices.ru>
Date: Tue, 25 Jul 2023 20:29:09 +0300
Subject: [PATCH 1/4] virtio/vsock: rework MSG_PEEK for SOCK_STREAM

This reworks current implementation of MSG_PEEK logic:
1) Replaces 'skb_queue_walk_safe()' with 'skb_queue_walk()'. There is
   no need in the first one, as there are no removes of skb in loop.
2) Removes nested while loop - MSG_PEEK logic could be implemented
   without it: just iterate over skbs without removing it and copy
   data from each until destination buffer is not full.

Signed-off-by: Arseniy Krasnov <AVKrasnov@sberdevices.ru>
Reviewed-by: Bobby Eshleman <bobby.eshleman@bytedance.com>
Reviewed-by: Stefano Garzarella <sgarzare@redhat.com>
Acked-by: Michael S. Tsirkin <mst@redhat.com>
Signed-off-by: Paolo Abeni <pabeni@redhat.com>
---
 net/vmw_vsock/virtio_transport_common.c | 47 ++++++++++++-------------
 1 file changed, 22 insertions(+), 25 deletions(-)

diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c
index b769fc258931..2ee40574c339 100644
--- a/net/vmw_vsock/virtio_transport_common.c
+++ b/net/vmw_vsock/virtio_transport_common.c
@@ -348,37 +348,34 @@ virtio_transport_stream_do_peek(struct vsock_sock *vsk,
 				size_t len)
 {
 	struct virtio_vsock_sock *vvs = vsk->trans;
-	size_t bytes, total = 0, off;
-	struct sk_buff *skb, *tmp;
-	int err = -EFAULT;
+	struct sk_buff *skb;
+	size_t total = 0;
+	int err;
 
 	spin_lock_bh(&vvs->rx_lock);
 
-	skb_queue_walk_safe(&vvs->rx_queue, skb,  tmp) {
-		off = 0;
+	skb_queue_walk(&vvs->rx_queue, skb) {
+		size_t bytes;
+
+		bytes = len - total;
+		if (bytes > skb->len)
+			bytes = skb->len;
+
+		spin_unlock_bh(&vvs->rx_lock);
+
+		/* sk_lock is held by caller so no one else can dequeue.
+		 * Unlock rx_lock since memcpy_to_msg() may sleep.
+		 */
+		err = memcpy_to_msg(msg, skb->data, bytes);
+		if (err)
+			goto out;
+
+		total += bytes;
+
+		spin_lock_bh(&vvs->rx_lock);
 
 		if (total == len)
 			break;
-
-		while (total < len && off < skb->len) {
-			bytes = len - total;
-			if (bytes > skb->len - off)
-				bytes = skb->len - off;
-
-			/* sk_lock is held by caller so no one else can dequeue.
-			 * Unlock rx_lock since memcpy_to_msg() may sleep.
-			 */
-			spin_unlock_bh(&vvs->rx_lock);
-
-			err = memcpy_to_msg(msg, skb->data + off, bytes);
-			if (err)
-				goto out;
-
-			spin_lock_bh(&vvs->rx_lock);
-
-			total += bytes;
-			off += bytes;
-		}
 	}
 
 	spin_unlock_bh(&vvs->rx_lock);

From a75f501de88e58db24dc8822505f924cd1085ac6 Mon Sep 17 00:00:00 2001
From: Arseniy Krasnov <AVKrasnov@sberdevices.ru>
Date: Tue, 25 Jul 2023 20:29:10 +0300
Subject: [PATCH 2/4] virtio/vsock: support MSG_PEEK for SOCK_SEQPACKET

This adds support of MSG_PEEK flag for SOCK_SEQPACKET type of socket.
Difference with SOCK_STREAM is that this callback returns either length
of the message or error.

Signed-off-by: Arseniy Krasnov <AVKrasnov@sberdevices.ru>
Reviewed-by: Stefano Garzarella <sgarzare@redhat.com>
Acked-by: Michael S. Tsirkin <mst@redhat.com>
Signed-off-by: Paolo Abeni <pabeni@redhat.com>
---
 net/vmw_vsock/virtio_transport_common.c | 63 +++++++++++++++++++++++--
 1 file changed, 60 insertions(+), 3 deletions(-)

diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c
index 2ee40574c339..352d042b130b 100644
--- a/net/vmw_vsock/virtio_transport_common.c
+++ b/net/vmw_vsock/virtio_transport_common.c
@@ -460,6 +460,63 @@ out:
 	return err;
 }
 
+static ssize_t
+virtio_transport_seqpacket_do_peek(struct vsock_sock *vsk,
+				   struct msghdr *msg)
+{
+	struct virtio_vsock_sock *vvs = vsk->trans;
+	struct sk_buff *skb;
+	size_t total, len;
+
+	spin_lock_bh(&vvs->rx_lock);
+
+	if (!vvs->msg_count) {
+		spin_unlock_bh(&vvs->rx_lock);
+		return 0;
+	}
+
+	total = 0;
+	len = msg_data_left(msg);
+
+	skb_queue_walk(&vvs->rx_queue, skb) {
+		struct virtio_vsock_hdr *hdr;
+
+		if (total < len) {
+			size_t bytes;
+			int err;
+
+			bytes = len - total;
+			if (bytes > skb->len)
+				bytes = skb->len;
+
+			spin_unlock_bh(&vvs->rx_lock);
+
+			/* sk_lock is held by caller so no one else can dequeue.
+			 * Unlock rx_lock since memcpy_to_msg() may sleep.
+			 */
+			err = memcpy_to_msg(msg, skb->data, bytes);
+			if (err)
+				return err;
+
+			spin_lock_bh(&vvs->rx_lock);
+		}
+
+		total += skb->len;
+		hdr = virtio_vsock_hdr(skb);
+
+		if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOM) {
+			if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOR)
+				msg->msg_flags |= MSG_EOR;
+
+			break;
+		}
+	}
+
+	spin_unlock_bh(&vvs->rx_lock);
+
+	return total;
+}
+
 static int virtio_transport_seqpacket_do_dequeue(struct vsock_sock *vsk,
 						 struct msghdr *msg,
 						 int flags)
@@ -554,9 +611,9 @@ virtio_transport_seqpacket_dequeue(struct vsock_sock *vsk,
 				   int flags)
 {
 	if (flags & MSG_PEEK)
-		return -EOPNOTSUPP;
-
-	return virtio_transport_seqpacket_do_dequeue(vsk, msg, flags);
+		return virtio_transport_seqpacket_do_peek(vsk, msg);
+	else
+		return virtio_transport_seqpacket_do_dequeue(vsk, msg, flags);
 }
 EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_dequeue);
 

From 587ed79f62a7569cbe386bbe6102800992f73001 Mon Sep 17 00:00:00 2001
From: Arseniy Krasnov <AVKrasnov@sberdevices.ru>
Date: Tue, 25 Jul 2023 20:29:11 +0300
Subject: [PATCH 3/4] vsock/test: rework MSG_PEEK test for SOCK_STREAM

This new version makes test more complicated by adding empty read,
partial read and data comparisons between MSG_PEEK and normal reads.

Signed-off-by: Arseniy Krasnov <AVKrasnov@sberdevices.ru>
Reviewed-by: Stefano Garzarella <sgarzare@redhat.com>
Acked-by: Michael S. Tsirkin <mst@redhat.com>
Signed-off-by: Paolo Abeni <pabeni@redhat.com>
---
 tools/testing/vsock/vsock_test.c | 78 ++++++++++++++++++++++++++++++--
 1 file changed, 75 insertions(+), 3 deletions(-)

diff --git a/tools/testing/vsock/vsock_test.c b/tools/testing/vsock/vsock_test.c
index ac1bd3ac1533..444a3ff0681f 100644
--- a/tools/testing/vsock/vsock_test.c
+++ b/tools/testing/vsock/vsock_test.c
@@ -255,9 +255,14 @@ static void test_stream_multiconn_server(const struct test_opts *opts)
 		close(fds[i]);
 }
 
+#define MSG_PEEK_BUF_LEN 64
+
 static void test_stream_msg_peek_client(const struct test_opts *opts)
 {
+	unsigned char buf[MSG_PEEK_BUF_LEN];
+	ssize_t send_size;
 	int fd;
+	int i;
 
 	fd = vsock_stream_connect(opts->peer_cid, 1234);
 	if (fd < 0) {
@@ -265,12 +270,32 @@ static void test_stream_msg_peek_client(const struct test_opts *opts)
 		exit(EXIT_FAILURE);
 	}
 
-	send_byte(fd, 1, 0);
+	for (i = 0; i < sizeof(buf); i++)
+		buf[i] = rand() & 0xFF;
+
+	control_expectln("SRVREADY");
+
+	send_size = send(fd, buf, sizeof(buf), 0);
+
+	if (send_size < 0) {
+		perror("send");
+		exit(EXIT_FAILURE);
+	}
+
+	if (send_size != sizeof(buf)) {
+		fprintf(stderr, "Invalid send size %zi\n", send_size);
+		exit(EXIT_FAILURE);
+	}
+
 	close(fd);
 }
 
 static void test_stream_msg_peek_server(const struct test_opts *opts)
 {
+	unsigned char buf_half[MSG_PEEK_BUF_LEN / 2];
+	unsigned char buf_normal[MSG_PEEK_BUF_LEN];
+	unsigned char buf_peek[MSG_PEEK_BUF_LEN];
+	ssize_t res;
 	int fd;
 
 	fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
@@ -279,8 +304,55 @@ static void test_stream_msg_peek_server(const struct test_opts *opts)
 		exit(EXIT_FAILURE);
 	}
 
-	recv_byte(fd, 1, MSG_PEEK);
-	recv_byte(fd, 1, 0);
+	/* Peek from empty socket. */
+	res = recv(fd, buf_peek, sizeof(buf_peek), MSG_PEEK | MSG_DONTWAIT);
+	if (res != -1) {
+		fprintf(stderr, "expected recv(2) failure, got %zi\n", res);
+		exit(EXIT_FAILURE);
+	}
+
+	if (errno != EAGAIN) {
+		perror("EAGAIN expected");
+		exit(EXIT_FAILURE);
+	}
+
+	control_writeln("SRVREADY");
+
+	/* Peek part of data. */
+	res = recv(fd, buf_half, sizeof(buf_half), MSG_PEEK);
+	if (res != sizeof(buf_half)) {
+		fprintf(stderr, "recv(2) + MSG_PEEK, expected %zu, got %zi\n",
+			sizeof(buf_half), res);
+		exit(EXIT_FAILURE);
+	}
+
+	/* Peek whole data. */
+	res = recv(fd, buf_peek, sizeof(buf_peek), MSG_PEEK);
+	if (res != sizeof(buf_peek)) {
+		fprintf(stderr, "recv(2) + MSG_PEEK, expected %zu, got %zi\n",
+			sizeof(buf_peek), res);
+		exit(EXIT_FAILURE);
+	}
+
+	/* Compare partial and full peek. */
+	if (memcmp(buf_half, buf_peek, sizeof(buf_half))) {
+		fprintf(stderr, "Partial peek data mismatch\n");
+		exit(EXIT_FAILURE);
+	}
+
+	res = recv(fd, buf_normal, sizeof(buf_normal), 0);
+	if (res != sizeof(buf_normal)) {
+		fprintf(stderr, "recv(2), expected %zu, got %zi\n",
+			sizeof(buf_normal), res);
+		exit(EXIT_FAILURE);
+	}
+
+	/* Compare full peek and normal read. */
+	if (memcmp(buf_peek, buf_normal, sizeof(buf_peek))) {
+		fprintf(stderr, "Full peek data mismatch\n");
+		exit(EXIT_FAILURE);
+	}
+
 	close(fd);
 }
 

From 8a0697f23e5a6e03a31f9dfb96755521aa9b9fc1 Mon Sep 17 00:00:00 2001
From: Arseniy Krasnov <AVKrasnov@sberdevices.ru>
Date: Tue, 25 Jul 2023 20:29:12 +0300
Subject: [PATCH 4/4] vsock/test: MSG_PEEK test for SOCK_SEQPACKET

This adds MSG_PEEK test for SOCK_SEQPACKET. It works in the same way as
SOCK_STREAM test, except it also tests MSG_TRUNC flag.

Signed-off-by: Arseniy Krasnov <AVKrasnov@sberdevices.ru>
Reviewed-by: Stefano Garzarella <sgarzare@redhat.com>
Acked-by: Michael S. Tsirkin <mst@redhat.com>
Signed-off-by: Paolo Abeni <pabeni@redhat.com>
---
 tools/testing/vsock/vsock_test.c | 58 +++++++++++++++++++++++++++++---
 1 file changed, 54 insertions(+), 4 deletions(-)

diff --git a/tools/testing/vsock/vsock_test.c b/tools/testing/vsock/vsock_test.c
index 444a3ff0681f..90718c2fd4ea 100644
--- a/tools/testing/vsock/vsock_test.c
+++ b/tools/testing/vsock/vsock_test.c
@@ -257,14 +257,19 @@ static void test_stream_multiconn_server(const struct test_opts *opts)
 
 #define MSG_PEEK_BUF_LEN 64
 
-static void test_stream_msg_peek_client(const struct test_opts *opts)
+static void test_msg_peek_client(const struct test_opts *opts,
+				 bool seqpacket)
 {
 	unsigned char buf[MSG_PEEK_BUF_LEN];
 	ssize_t send_size;
 	int fd;
 	int i;
 
-	fd = vsock_stream_connect(opts->peer_cid, 1234);
+	if (seqpacket)
+		fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
+	else
+		fd = vsock_stream_connect(opts->peer_cid, 1234);
+
 	if (fd < 0) {
 		perror("connect");
 		exit(EXIT_FAILURE);
@@ -290,7 +295,8 @@ static void test_stream_msg_peek_client(const struct test_opts *opts)
 	close(fd);
 }
 
-static void test_stream_msg_peek_server(const struct test_opts *opts)
+static void test_msg_peek_server(const struct test_opts *opts,
+				 bool seqpacket)
 {
 	unsigned char buf_half[MSG_PEEK_BUF_LEN / 2];
 	unsigned char buf_normal[MSG_PEEK_BUF_LEN];
@@ -298,7 +304,11 @@ static void test_stream_msg_peek_server(const struct test_opts *opts)
 	ssize_t res;
 	int fd;
 
-	fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
+	if (seqpacket)
+		fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
+	else
+		fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
+
 	if (fd < 0) {
 		perror("accept");
 		exit(EXIT_FAILURE);
@@ -340,6 +350,21 @@ static void test_stream_msg_peek_server(const struct test_opts *opts)
 		exit(EXIT_FAILURE);
 	}
 
+	if (seqpacket) {
+		/* This type of socket supports MSG_TRUNC flag,
+		 * so check it with MSG_PEEK. We must get length
+		 * of the message.
+		 */
+		res = recv(fd, buf_half, sizeof(buf_half), MSG_PEEK |
+			   MSG_TRUNC);
+		if (res != sizeof(buf_peek)) {
+			fprintf(stderr,
+				"recv(2) + MSG_PEEK | MSG_TRUNC, exp %zu, got %zi\n",
+				sizeof(buf_half), res);
+			exit(EXIT_FAILURE);
+		}
+	}
+
 	res = recv(fd, buf_normal, sizeof(buf_normal), 0);
 	if (res != sizeof(buf_normal)) {
 		fprintf(stderr, "recv(2), expected %zu, got %zi\n",
@@ -356,6 +381,16 @@ static void test_stream_msg_peek_server(const struct test_opts *opts)
 	close(fd);
 }
 
+static void test_stream_msg_peek_client(const struct test_opts *opts)
+{
+	return test_msg_peek_client(opts, false);
+}
+
+static void test_stream_msg_peek_server(const struct test_opts *opts)
+{
+	return test_msg_peek_server(opts, false);
+}
+
 #define SOCK_BUF_SIZE (2 * 1024 * 1024)
 #define MAX_MSG_SIZE (32 * 1024)
 
@@ -1125,6 +1160,16 @@ static void test_stream_virtio_skb_merge_server(const struct test_opts *opts)
 	close(fd);
 }
 
+static void test_seqpacket_msg_peek_client(const struct test_opts *opts)
+{
+	return test_msg_peek_client(opts, true);
+}
+
+static void test_seqpacket_msg_peek_server(const struct test_opts *opts)
+{
+	return test_msg_peek_server(opts, true);
+}
+
 static struct test_case test_cases[] = {
 	{
 		.name = "SOCK_STREAM connection reset",
@@ -1200,6 +1245,11 @@ static struct test_case test_cases[] = {
 		.run_client = test_stream_virtio_skb_merge_client,
 		.run_server = test_stream_virtio_skb_merge_server,
 	},
+	{
+		.name = "SOCK_SEQPACKET MSG_PEEK",
+		.run_client = test_seqpacket_msg_peek_client,
+		.run_server = test_seqpacket_msg_peek_server,
+	},
 	{},
 };