linux/net/vmw_vsock/virtio_transport_common.c
John Fastabend 78fa0d61d9 bpf, sockmap: Pass skb ownership through read_skb
The read_skb hook calls consume_skb() now, but this means that if the
recv_actor program wants to use the skb it needs to inc the ref cnt
so that the consume_skb() doesn't kfree the sk_buff.

This is problematic because in some error cases under memory pressure
we may need to linearize the sk_buff from sk_psock_skb_ingress_enqueue().
Then we get this,

 skb_linearize()
   __pskb_pull_tail()
     pskb_expand_head()
       BUG_ON(skb_shared(skb))

Because we incremented the users refcnt in sk_psock_verdict_recv(), the
skb is shared (refcnt > 1) and we trip the BUG_ON.

To fix, let's simply pass ownership of the sk_buff through the read_skb
call. Then we can drop the consume from the read_skb handlers and assume
the verdict recv does any required kfree.

Bug found while testing in our CI which runs in VMs that hit memory
constraints rather regularly. William tested TCP read_skb handlers.

[  106.536188] ------------[ cut here ]------------
[  106.536197] kernel BUG at net/core/skbuff.c:1693!
[  106.536479] invalid opcode: 0000 [#1] PREEMPT SMP PTI
[  106.536726] CPU: 3 PID: 1495 Comm: curl Not tainted 5.19.0-rc5 #1
[  106.537023] Hardware name: QEMU Standard PC (i440FX + PIIX, 1996), BIOS ArchLinux 1.16.0-1 04/01/2014
[  106.537467] RIP: 0010:pskb_expand_head+0x269/0x330
[  106.538585] RSP: 0018:ffffc90000138b68 EFLAGS: 00010202
[  106.538839] RAX: 000000000000003f RBX: ffff8881048940e8 RCX: 0000000000000a20
[  106.539186] RDX: 0000000000000002 RSI: 0000000000000000 RDI: ffff8881048940e8
[  106.539529] RBP: ffffc90000138be8 R08: 00000000e161fd1a R09: 0000000000000000
[  106.539877] R10: 0000000000000018 R11: 0000000000000000 R12: ffff8881048940e8
[  106.540222] R13: 0000000000000003 R14: 0000000000000000 R15: ffff8881048940e8
[  106.540568] FS:  00007f277dde9f00(0000) GS:ffff88813bd80000(0000) knlGS:0000000000000000
[  106.540954] CS:  0010 DS: 0000 ES: 0000 CR0: 0000000080050033
[  106.541227] CR2: 00007f277eeede64 CR3: 000000000ad3e000 CR4: 00000000000006e0
[  106.541569] DR0: 0000000000000000 DR1: 0000000000000000 DR2: 0000000000000000
[  106.541915] DR3: 0000000000000000 DR6: 00000000fffe0ff0 DR7: 0000000000000400
[  106.542255] Call Trace:
[  106.542383]  <IRQ>
[  106.542487]  __pskb_pull_tail+0x4b/0x3e0
[  106.542681]  skb_ensure_writable+0x85/0xa0
[  106.542882]  sk_skb_pull_data+0x18/0x20
[  106.543084]  bpf_prog_b517a65a242018b0_bpf_skskb_http_verdict+0x3a9/0x4aa9
[  106.543536]  ? migrate_disable+0x66/0x80
[  106.543871]  sk_psock_verdict_recv+0xe2/0x310
[  106.544258]  ? sk_psock_write_space+0x1f0/0x1f0
[  106.544561]  tcp_read_skb+0x7b/0x120
[  106.544740]  tcp_data_queue+0x904/0xee0
[  106.544931]  tcp_rcv_established+0x212/0x7c0
[  106.545142]  tcp_v4_do_rcv+0x174/0x2a0
[  106.545326]  tcp_v4_rcv+0xe70/0xf60
[  106.545500]  ip_protocol_deliver_rcu+0x48/0x290
[  106.545744]  ip_local_deliver_finish+0xa7/0x150

Fixes: 04919bed948dc ("tcp: Introduce tcp_read_skb()")
Reported-by: William Findlay <will@isovalent.com>
Signed-off-by: John Fastabend <john.fastabend@gmail.com>
Signed-off-by: Daniel Borkmann <daniel@iogearbox.net>
Tested-by: William Findlay <will@isovalent.com>
Reviewed-by: Jakub Sitnicki <jakub@cloudflare.com>
Link: https://lore.kernel.org/bpf/20230523025618.113937-2-john.fastabend@gmail.com
2023-05-23 16:09:47 +02:00

1463 lines
34 KiB
C

// SPDX-License-Identifier: GPL-2.0-only
/*
* common code for virtio vsock
*
* Copyright (C) 2013-2015 Red Hat, Inc.
* Author: Asias He <asias@redhat.com>
* Stefan Hajnoczi <stefanha@redhat.com>
*/
#include <linux/spinlock.h>
#include <linux/module.h>
#include <linux/sched/signal.h>
#include <linux/ctype.h>
#include <linux/list.h>
#include <linux/virtio_vsock.h>
#include <uapi/linux/vsockmon.h>
#include <net/sock.h>
#include <net/af_vsock.h>
#define CREATE_TRACE_POINTS
#include <trace/events/vsock_virtio_transport_common.h>
/* How long to wait for graceful shutdown of a connection */
#define VSOCK_CLOSE_TIMEOUT (8 * HZ)
/* Threshold for detecting small packets to copy */
#define GOOD_COPY_LEN 128
/* Resolve the virtio transport that embeds the vsock_transport assigned to
 * @vsk. Warns and returns NULL when the socket has no transport.
 */
static const struct virtio_transport *
virtio_transport_get_ops(struct vsock_sock *vsk)
{
	const struct vsock_transport *core_ops;

	core_ops = vsock_core_get_transport(vsk);
	if (WARN_ON(!core_ops))
		return NULL;

	return container_of(core_ops, struct virtio_transport, transport);
}
/* Allocate a packet addressed from (src_cid, src_port) to (dst_cid, dst_port)
 * and fill in its header from @info. When a message is attached and @len is
 * non-zero, @len payload bytes are copied from info->msg.
 *
 * Returns the new packet on success, otherwise NULL. NULL covers every
 * failure (allocation, payload copy, taking ownership for info->vsk); no
 * error code is reported to the caller.
 */
static struct sk_buff *
virtio_transport_alloc_skb(struct virtio_vsock_pkt_info *info,
			   size_t len,
			   u32 src_cid,
			   u32 src_port,
			   u32 dst_cid,
			   u32 dst_port)
{
	struct virtio_vsock_hdr *hdr;
	struct sk_buff *skb;

	skb = virtio_vsock_alloc_skb(VIRTIO_VSOCK_SKB_HEADROOM + len,
				     GFP_KERNEL);
	if (!skb)
		return NULL;

	hdr = virtio_vsock_hdr(skb);
	hdr->type	= cpu_to_le16(info->type);
	hdr->op		= cpu_to_le16(info->op);
	hdr->src_cid	= cpu_to_le64(src_cid);
	hdr->src_port	= cpu_to_le32(src_port);
	hdr->dst_cid	= cpu_to_le64(dst_cid);
	hdr->dst_port	= cpu_to_le32(dst_port);
	hdr->flags	= cpu_to_le32(info->flags);
	hdr->len	= cpu_to_le32(len);

	if (info->msg && len > 0) {
		void *dst = skb_put(skb, len);

		if (memcpy_from_msg(dst, info->msg, len))
			goto drop;

		/* The last fragment of a SEQPACKET message carries the
		 * end-of-message mark, plus end-of-record if requested.
		 */
		if (info->type == VIRTIO_VSOCK_TYPE_SEQPACKET &&
		    msg_data_left(info->msg) == 0) {
			u32 seq_flags = VIRTIO_VSOCK_SEQ_EOM;

			if (info->msg->msg_flags & MSG_EOR)
				seq_flags |= VIRTIO_VSOCK_SEQ_EOR;

			hdr->flags |= cpu_to_le32(seq_flags);
		}
	}

	if (info->reply)
		virtio_vsock_skb_set_reply(skb);

	trace_virtio_transport_alloc_pkt(src_cid, src_port,
					 dst_cid, dst_port,
					 len,
					 info->type,
					 info->op,
					 info->flags);

	if (info->vsk) {
		if (!skb_set_owner_sk_safe(skb, sk_vsock(info->vsk))) {
			WARN_ONCE(1, "failed to allocate skb on vsock socket with sk_refcnt == 0\n");
			goto drop;
		}
	}

	return skb;

drop:
	kfree_skb(skb);
	return NULL;
}
/* Packet capture */
static struct sk_buff *virtio_transport_build_skb(void *opaque)
{
struct virtio_vsock_hdr *pkt_hdr;
struct sk_buff *pkt = opaque;
struct af_vsockmon_hdr *hdr;
struct sk_buff *skb;
size_t payload_len;
void *payload_buf;
/* A packet could be split to fit the RX buffer, so we can retrieve
* the payload length from the header and the buffer pointer taking
* care of the offset in the original packet.
*/
pkt_hdr = virtio_vsock_hdr(pkt);
payload_len = pkt->len;
payload_buf = pkt->data;
skb = alloc_skb(sizeof(*hdr) + sizeof(*pkt_hdr) + payload_len,
GFP_ATOMIC);
if (!skb)
return NULL;
hdr = skb_put(skb, sizeof(*hdr));
/* pkt->hdr is little-endian so no need to byteswap here */
hdr->src_cid = pkt_hdr->src_cid;
hdr->src_port = pkt_hdr->src_port;
hdr->dst_cid = pkt_hdr->dst_cid;
hdr->dst_port = pkt_hdr->dst_port;
hdr->transport = cpu_to_le16(AF_VSOCK_TRANSPORT_VIRTIO);
hdr->len = cpu_to_le16(sizeof(*pkt_hdr));
memset(hdr->reserved, 0, sizeof(hdr->reserved));
switch (le16_to_cpu(pkt_hdr->op)) {
case VIRTIO_VSOCK_OP_REQUEST:
case VIRTIO_VSOCK_OP_RESPONSE:
hdr->op = cpu_to_le16(AF_VSOCK_OP_CONNECT);
break;
case VIRTIO_VSOCK_OP_RST:
case VIRTIO_VSOCK_OP_SHUTDOWN:
hdr->op = cpu_to_le16(AF_VSOCK_OP_DISCONNECT);
break;
case VIRTIO_VSOCK_OP_RW:
hdr->op = cpu_to_le16(AF_VSOCK_OP_PAYLOAD);
break;
case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
case VIRTIO_VSOCK_OP_CREDIT_REQUEST:
hdr->op = cpu_to_le16(AF_VSOCK_OP_CONTROL);
break;
default:
hdr->op = cpu_to_le16(AF_VSOCK_OP_UNKNOWN);
break;
}
skb_put_data(skb, pkt_hdr, sizeof(*pkt_hdr));
if (payload_len) {
skb_put_data(skb, payload_buf, payload_len);
}
return skb;
}
/* Hand @skb to the vsock tap once; packets already marked as delivered are
 * skipped.
 */
void virtio_transport_deliver_tap_pkt(struct sk_buff *skb)
{
	if (!virtio_vsock_skb_tap_delivered(skb)) {
		vsock_deliver_tap(virtio_transport_build_skb, skb);
		virtio_vsock_skb_set_tap_delivered(skb);
	}
}
EXPORT_SYMBOL_GPL(virtio_transport_deliver_tap_pkt);
/* Virtio packet type for @sk: STREAM for SOCK_STREAM sockets, SEQPACKET for
 * every other socket type.
 */
static u16 virtio_transport_get_type(struct sock *sk)
{
	return sk->sk_type == SOCK_STREAM ? VIRTIO_VSOCK_TYPE_STREAM
					  : VIRTIO_VSOCK_TYPE_SEQPACKET;
}
/* Send the packet(s) described by @info on behalf of @vsk.
 *
 * This function can only be used on connecting/connected sockets,
 * since a socket assigned to a transport is required.
 *
 * Do not use on listener sockets!
 *
 * info->pkt_len is first trimmed to the credit granted by
 * virtio_transport_get_credit(), then sent in chunks of at most
 * VIRTIO_VSOCK_MAX_PKT_BUF_SIZE bytes. Credit left unused is given back.
 *
 * Returns the number of bytes sent if any were sent. Otherwise returns the
 * result of the last step attempted: -EFAULT when no transport ops are
 * available, 0 for an OP_RW request that obtained no credit, -ENOMEM when a
 * packet could not be allocated, or the value returned by send_pkt().
 */
static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
struct virtio_vsock_pkt_info *info)
{
u32 src_cid, src_port, dst_cid, dst_port;
const struct virtio_transport *t_ops;
struct virtio_vsock_sock *vvs;
u32 pkt_len = info->pkt_len;
u32 rest_len;
int ret;
info->type = virtio_transport_get_type(sk_vsock(vsk));
t_ops = virtio_transport_get_ops(vsk);
if (unlikely(!t_ops))
return -EFAULT;
src_cid = t_ops->transport.get_local_cid();
src_port = vsk->local_addr.svm_port;
/* Use the peer recorded on the socket unless @info names one explicitly */
if (!info->remote_cid) {
dst_cid = vsk->remote_addr.svm_cid;
dst_port = vsk->remote_addr.svm_port;
} else {
dst_cid = info->remote_cid;
dst_port = info->remote_port;
}
vvs = vsk->trans;
/* virtio_transport_get_credit might return less than pkt_len credit */
pkt_len = virtio_transport_get_credit(vvs, pkt_len);
/* Do not send zero length OP_RW pkt */
if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
return pkt_len;
rest_len = pkt_len;
/* The body runs at least once, so a zero-length (non OP_RW) packet is
 * still built and sent.
 */
do {
struct sk_buff *skb;
size_t skb_len;
skb_len = min_t(u32, VIRTIO_VSOCK_MAX_PKT_BUF_SIZE, rest_len);
skb = virtio_transport_alloc_skb(info, skb_len,
src_cid, src_port,
dst_cid, dst_port);
if (!skb) {
ret = -ENOMEM;
break;
}
/* Stamp the current fwd_cnt/buf_alloc into the outgoing header */
virtio_transport_inc_tx_pkt(vvs, skb);
ret = t_ops->send_pkt(skb);
if (ret < 0)
break;
/* Both virtio and vhost 'send_pkt()' returns 'skb_len',
* but for reliability use 'ret' instead of 'skb_len'.
* Also if partial send happens (e.g. 'ret' != 'skb_len')
* somehow, we break this loop, but account such returned
* value in 'virtio_transport_put_credit()'.
*/
rest_len -= ret;
if (WARN_ONCE(ret != skb_len,
"'send_pkt()' returns %i, but %zu expected\n",
ret, skb_len))
break;
} while (rest_len);
/* Give back the credit reserved for the bytes that were not sent */
virtio_transport_put_credit(vvs, rest_len);
/* Return number of bytes, if any data has been sent. */
if (rest_len != pkt_len)
ret = pkt_len - rest_len;
return ret;
}
/* Account @len received bytes against the receive buffer.
 *
 * Returns false, leaving rx_bytes untouched, when the bytes would push
 * rx_bytes past buf_alloc; returns true after adding them otherwise.
 */
static bool virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
					u32 len)
{
	bool fits = vvs->rx_bytes + len <= vvs->buf_alloc;

	if (fits)
		vvs->rx_bytes += len;

	return fits;
}
/* Account @len bytes as consumed: they leave rx_bytes and are added to the
 * forwarded counter.
 */
static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
					u32 len)
{
	vvs->fwd_cnt += len;
	vvs->rx_bytes -= len;
}
void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct sk_buff *skb)
{
struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
spin_lock_bh(&vvs->rx_lock);
vvs->last_fwd_cnt = vvs->fwd_cnt;
hdr->fwd_cnt = cpu_to_le32(vvs->fwd_cnt);
hdr->buf_alloc = cpu_to_le32(vvs->buf_alloc);
spin_unlock_bh(&vvs->rx_lock);
}
EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);
/* Reserve up to @credit bytes of the peer's receive buffer for transmission.
 *
 * The reservation is recorded in tx_cnt under tx_lock; credit that ends up
 * unused is expected to be handed back with virtio_transport_put_credit().
 *
 * Returns the number of bytes actually reserved, which may be less than
 * @credit (including 0).
 */
u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit)
{
	s64 avail;
	u32 ret;

	if (!credit)
		return 0;

	spin_lock_bh(&vvs->tx_lock);
	/* tx_cnt - peer_fwd_cnt is the amount in flight and relies on
	 * wrap-around counter arithmetic, so it is kept 32-bit. Only the outer
	 * subtraction is widened: if the peer shrinks peer_buf_alloc below the
	 * amount in flight, the result must be "no credit" rather than a
	 * wrapped, huge unsigned value.
	 */
	avail = (s64)vvs->peer_buf_alloc -
		(u32)(vvs->tx_cnt - vvs->peer_fwd_cnt);
	if (avail < 0)
		avail = 0;
	ret = min_t(s64, credit, avail);
	vvs->tx_cnt += ret;
	spin_unlock_bh(&vvs->tx_lock);

	return ret;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_credit);
/* Return @credit bytes of previously reserved transmit credit by subtracting
 * them from tx_cnt under tx_lock. A zero @credit is a no-op.
 */
void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
{
	if (credit) {
		spin_lock_bh(&vvs->tx_lock);
		vvs->tx_cnt -= credit;
		spin_unlock_bh(&vvs->tx_lock);
	}
}
EXPORT_SYMBOL_GPL(virtio_transport_put_credit);
static int virtio_transport_send_credit_update(struct vsock_sock *vsk)
{
struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
.vsk = vsk,
};
return virtio_transport_send_pkt_info(vsk, &info);
}
/* Copy up to @len bytes of queued stream data into @msg without removing
 * anything from rx_queue.
 *
 * Returns the number of bytes copied. If memcpy_to_msg() fails, returns the
 * bytes copied so far when that is non-zero, otherwise the error it returned.
 */
static ssize_t
virtio_transport_stream_do_peek(struct vsock_sock *vsk,
struct msghdr *msg,
size_t len)
{
struct virtio_vsock_sock *vvs = vsk->trans;
size_t bytes, total = 0, off;
struct sk_buff *skb, *tmp;
int err = -EFAULT;
spin_lock_bh(&vvs->rx_lock);
skb_queue_walk_safe(&vvs->rx_queue, skb, tmp) {
/* 'off' is the read position inside the current packet */
off = 0;
if (total == len)
break;
while (total < len && off < skb->len) {
/* Copy whichever is smaller: what the caller still wants or
 * what is left in this packet.
 */
bytes = len - total;
if (bytes > skb->len - off)
bytes = skb->len - off;
/* sk_lock is held by caller so no one else can dequeue.
* Unlock rx_lock since memcpy_to_msg() may sleep.
*/
spin_unlock_bh(&vvs->rx_lock);
err = memcpy_to_msg(msg, skb->data + off, bytes);
/* rx_lock is not held at this point, so 'out' must not unlock */
if (err)
goto out;
spin_lock_bh(&vvs->rx_lock);
total += bytes;
off += bytes;
}
}
spin_unlock_bh(&vvs->rx_lock);
return total;
out:
if (total)
err = total;
return err;
}
/* Copy up to @len bytes of queued stream data into @msg and consume them.
 * Packets that become empty are unlinked from rx_queue and released.
 *
 * Returns the number of bytes copied. If memcpy_to_msg() fails, returns the
 * bytes copied so far when that is non-zero, otherwise the error it returned.
 * Returns -EFAULT when rx_queue is empty while rx_bytes is non-zero.
 */
static ssize_t
virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
struct msghdr *msg,
size_t len)
{
struct virtio_vsock_sock *vvs = vsk->trans;
size_t bytes, total = 0;
struct sk_buff *skb;
int err = -EFAULT;
u32 free_space;
spin_lock_bh(&vvs->rx_lock);
/* Accounting says there is data but nothing is queued: bail out */
if (WARN_ONCE(skb_queue_empty(&vvs->rx_queue) && vvs->rx_bytes,
"rx_queue is empty, but rx_bytes is non-zero\n")) {
spin_unlock_bh(&vvs->rx_lock);
return err;
}
while (total < len && !skb_queue_empty(&vvs->rx_queue)) {
skb = skb_peek(&vvs->rx_queue);
bytes = len - total;
if (bytes > skb->len)
bytes = skb->len;
/* sk_lock is held by caller so no one else can dequeue.
* Unlock rx_lock since memcpy_to_msg() may sleep.
*/
spin_unlock_bh(&vvs->rx_lock);
err = memcpy_to_msg(msg, skb->data, bytes);
/* rx_lock is not held at this point, so 'out' must not unlock */
if (err)
goto out;
spin_lock_bh(&vvs->rx_lock);
total += bytes;
/* Drop the copied bytes from the front of the packet */
skb_pull(skb, bytes);
if (skb->len == 0) {
/* Fully consumed: account the header length, then release it */
u32 pkt_len = le32_to_cpu(virtio_vsock_hdr(skb)->len);
virtio_transport_dec_rx_pkt(vvs, pkt_len);
__skb_unlink(skb, &vvs->rx_queue);
consume_skb(skb);
}
}
/* buf_alloc minus the bytes forwarded since the last fwd_cnt we sent */
free_space = vvs->buf_alloc - (vvs->fwd_cnt - vvs->last_fwd_cnt);
spin_unlock_bh(&vvs->rx_lock);
/* To reduce the number of credit update messages,
* don't update credits as long as lots of space is available.
* Note: the limit chosen here is arbitrary. Setting the limit
* too high causes extra messages. Too low causes transmitter
* stalls. As stalls are in theory more expensive than extra
* messages, we set the limit to a high value. TODO: experiment
* with different values.
*/
if (free_space < VIRTIO_VSOCK_MAX_PKT_BUF_SIZE)
virtio_transport_send_credit_update(vsk);
return total;
out:
if (total)
err = total;
return err;
}
/* Dequeue one SEQPACKET message: packets are taken off rx_queue until one
 * carrying VIRTIO_VSOCK_SEQ_EOM has been processed (or the queue runs dry),
 * copying as much payload as fits in @msg. Every dequeued packet is freed,
 * whether or not its payload was copied.
 *
 * Returns 0 when no complete message is queued (msg_count == 0). Otherwise
 * returns the sum of the header lengths of the dequeued packets, which can
 * exceed the bytes copied when the user buffer is smaller, or the error
 * returned by memcpy_to_msg() if a copy failed. @flags is not used here.
 */
static int virtio_transport_seqpacket_do_dequeue(struct vsock_sock *vsk,
struct msghdr *msg,
int flags)
{
struct virtio_vsock_sock *vvs = vsk->trans;
int dequeued_len = 0;
size_t user_buf_len = msg_data_left(msg);
bool msg_ready = false;
struct sk_buff *skb;
spin_lock_bh(&vvs->rx_lock);
if (vvs->msg_count == 0) {
spin_unlock_bh(&vvs->rx_lock);
return 0;
}
while (!msg_ready) {
struct virtio_vsock_hdr *hdr;
size_t pkt_len;
skb = __skb_dequeue(&vvs->rx_queue);
if (!skb)
break;
hdr = virtio_vsock_hdr(skb);
pkt_len = (size_t)le32_to_cpu(hdr->len);
/* A negative dequeued_len records an earlier copy failure; from
 * then on packets are only drained, not copied.
 */
if (dequeued_len >= 0) {
size_t bytes_to_copy;
bytes_to_copy = min(user_buf_len, pkt_len);
if (bytes_to_copy) {
int err;
/* sk_lock is held by caller so no one else can dequeue.
* Unlock rx_lock since memcpy_to_msg() may sleep.
*/
spin_unlock_bh(&vvs->rx_lock);
err = memcpy_to_msg(msg, skb->data, bytes_to_copy);
if (err) {
/* Copy of message failed. Rest of
* fragments will be freed without copy.
*/
dequeued_len = err;
} else {
user_buf_len -= bytes_to_copy;
}
spin_lock_bh(&vvs->rx_lock);
}
/* Count the full packet length even if only part of it was copied */
if (dequeued_len >= 0)
dequeued_len += pkt_len;
}
if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOM) {
/* Last packet of the message: stop after this iteration */
msg_ready = true;
vvs->msg_count--;
if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOR)
msg->msg_flags |= MSG_EOR;
}
virtio_transport_dec_rx_pkt(vvs, pkt_len);
kfree_skb(skb);
}
spin_unlock_bh(&vvs->rx_lock);
/* Sent unconditionally once at least one message was queued */
virtio_transport_send_credit_update(vsk);
return dequeued_len;
}
/* Stream receive entry point: with MSG_PEEK the data is copied without
 * being consumed, otherwise it is copied and dequeued.
 */
ssize_t
virtio_transport_stream_dequeue(struct vsock_sock *vsk,
				struct msghdr *msg,
				size_t len, int flags)
{
	if (flags & MSG_PEEK)
		return virtio_transport_stream_do_peek(vsk, msg, len);

	return virtio_transport_stream_do_dequeue(vsk, msg, len);
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);
/* SEQPACKET receive entry point. MSG_PEEK is rejected with -EOPNOTSUPP;
 * anything else dequeues one message.
 */
ssize_t
virtio_transport_seqpacket_dequeue(struct vsock_sock *vsk,
				   struct msghdr *msg,
				   int flags)
{
	return (flags & MSG_PEEK) ?
		-EOPNOTSUPP :
		virtio_transport_seqpacket_do_dequeue(vsk, msg, flags);
}
EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_dequeue);
/* Send one SEQPACKET message of @len bytes.
 *
 * Returns -EMSGSIZE when @len exceeds the peer's advertised buffer size
 * (peer_buf_alloc); otherwise returns the result of the stream enqueue path.
 */
int
virtio_transport_seqpacket_enqueue(struct vsock_sock *vsk,
				   struct msghdr *msg,
				   size_t len)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	bool too_big;

	spin_lock_bh(&vvs->tx_lock);
	too_big = len > vvs->peer_buf_alloc;
	spin_unlock_bh(&vvs->tx_lock);

	if (too_big)
		return -EMSGSIZE;

	return virtio_transport_stream_enqueue(vsk, msg, len);
}
EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_enqueue);
/* Datagram receive is not implemented here: always fails with -EOPNOTSUPP. */
int
virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
			       struct msghdr *msg,
			       size_t len, int flags)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue);
/* Snapshot, taken under rx_lock, of the bytes currently accounted in
 * rx_bytes.
 */
s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	s64 pending;

	spin_lock_bh(&vvs->rx_lock);
	pending = vvs->rx_bytes;
	spin_unlock_bh(&vvs->rx_lock);

	return pending;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);
/* Snapshot, taken under rx_lock, of the number of complete SEQPACKET
 * messages counted in msg_count.
 */
u32 virtio_transport_seqpacket_has_data(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	u32 queued_msgs;

	spin_lock_bh(&vvs->rx_lock);
	queued_msgs = vvs->msg_count;
	spin_unlock_bh(&vvs->rx_lock);

	return queued_msgs;
}
EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_has_data);
static s64 virtio_transport_has_space(struct vsock_sock *vsk)
{
struct virtio_vsock_sock *vvs = vsk->trans;
s64 bytes;
bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
if (bytes < 0)
bytes = 0;
return bytes;
}
/* Locked wrapper around virtio_transport_has_space(): returns the free space
 * at the peer, sampled under tx_lock.
 */
s64 virtio_transport_stream_has_space(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	s64 space;

	spin_lock_bh(&vvs->tx_lock);
	space = virtio_transport_has_space(vsk);
	spin_unlock_bh(&vvs->tx_lock);

	return space;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space);
/* Allocate and attach the per-socket virtio state for @vsk.
 *
 * When a parent @psk with transport state is given, its peer_buf_alloc is
 * inherited. vsk->buffer_size is capped at VIRTIO_VSOCK_MAX_BUF_SIZE and
 * used as the initial buf_alloc.
 *
 * Returns 0 on success or -ENOMEM if the state cannot be allocated.
 */
int virtio_transport_do_socket_init(struct vsock_sock *vsk,
				    struct vsock_sock *psk)
{
	struct virtio_vsock_sock *vvs = kzalloc(sizeof(*vvs), GFP_KERNEL);

	if (!vvs)
		return -ENOMEM;

	vsk->trans = vvs;
	vvs->vsk = vsk;

	if (psk && psk->trans) {
		struct virtio_vsock_sock *parent = psk->trans;

		vvs->peer_buf_alloc = parent->peer_buf_alloc;
	}

	if (vsk->buffer_size > VIRTIO_VSOCK_MAX_BUF_SIZE)
		vsk->buffer_size = VIRTIO_VSOCK_MAX_BUF_SIZE;
	vvs->buf_alloc = vsk->buffer_size;

	spin_lock_init(&vvs->rx_lock);
	spin_lock_init(&vvs->tx_lock);
	skb_queue_head_init(&vvs->rx_queue);

	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);
/* Apply a new receive buffer size. *val is capped in place at
 * VIRTIO_VSOCK_MAX_BUF_SIZE, stored as buf_alloc, and a credit update is
 * sent.
 *
 * sk_lock held by the caller.
 */
void virtio_transport_notify_buffer_size(struct vsock_sock *vsk, u64 *val)
{
	struct virtio_vsock_sock *trans = vsk->trans;

	if (*val > VIRTIO_VSOCK_MAX_BUF_SIZE)
		*val = VIRTIO_VSOCK_MAX_BUF_SIZE;

	trans->buf_alloc = *val;

	virtio_transport_send_credit_update(vsk);
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_buffer_size);
/* Poll-in helper: reports through @data_ready_now whether at least @target
 * bytes are available to read. Always returns 0.
 */
int
virtio_transport_notify_poll_in(struct vsock_sock *vsk,
				size_t target,
				bool *data_ready_now)
{
	*data_ready_now = vsock_stream_has_data(vsk) >= target;

	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in);
/* Poll-out helper: reports through @space_avail_now whether any send space
 * is available. Always returns 0. @target is not consulted.
 */
int
virtio_transport_notify_poll_out(struct vsock_sock *vsk,
				 size_t target,
				 bool *space_avail_now)
{
	s64 space = vsock_stream_has_space(vsk);

	/* A negative result leaves *space_avail_now untouched. */
	if (space >= 0)
		*space_avail_now = space > 0;

	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out);
/* Receive-init notification: nothing to do, always returns 0. */
int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
				      size_t target,
				      struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init);
/* Receive pre-block notification: nothing to do, always returns 0. */
int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
					   size_t target,
					   struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);
/* Receive pre-dequeue notification: nothing to do, always returns 0. */
int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
					     size_t target,
					     struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);
/* Receive post-dequeue notification: nothing to do, always returns 0. */
int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
					      size_t target,
					      ssize_t copied,
					      bool data_read,
					      struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);
/* Send-init notification: nothing to do, always returns 0. */
int virtio_transport_notify_send_init(struct vsock_sock *vsk,
				      struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);
/* Send pre-block notification: nothing to do, always returns 0. */
int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
					   struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);
/* Send pre-enqueue notification: nothing to do, always returns 0. */
int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
					     struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);
/* Send post-enqueue notification: nothing to do, always returns 0. */
int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
					      ssize_t written,
					      struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);
/* Receive high-water mark: simply the socket's configured buffer_size. */
u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
{
	return vsk->buffer_size;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);
/* Unconditionally reports the stream as active. */
bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
{
	return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);
/* Stream connections are allowed for every (cid, port) pair. */
bool virtio_transport_stream_allow(u32 cid, u32 port)
{
	return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);
/* Datagram bind is not implemented here: always fails with -EOPNOTSUPP. */
int virtio_transport_dgram_bind(struct vsock_sock *vsk,
				struct sockaddr_vm *addr)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);
/* Datagrams are refused for every (cid, port) pair. */
bool virtio_transport_dgram_allow(u32 cid, u32 port)
{
	return false;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);
int virtio_transport_connect(struct vsock_sock *vsk)
{
struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_REQUEST,
.vsk = vsk,
};
return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_connect);
/* Send a SHUTDOWN packet whose flags mirror the RCV_SHUTDOWN/SEND_SHUTDOWN
 * bits set in @mode. Returns the result of virtio_transport_send_pkt_info().
 */
int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_SHUTDOWN,
		.vsk = vsk,
	};

	/* info.flags starts at zero thanks to the designated initializer */
	if (mode & RCV_SHUTDOWN)
		info.flags |= VIRTIO_VSOCK_SHUTDOWN_RCV;
	if (mode & SEND_SHUTDOWN)
		info.flags |= VIRTIO_VSOCK_SHUTDOWN_SEND;

	return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_shutdown);
/* Datagram send is not implemented here: always fails with -EOPNOTSUPP. */
int
virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
			       struct sockaddr_vm *remote_addr,
			       struct msghdr *msg,
			       size_t dgram_len)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);
/* Queue @len bytes from @msg for transmission as OP_RW packets. Returns the
 * result of virtio_transport_send_pkt_info().
 */
ssize_t
virtio_transport_stream_enqueue(struct vsock_sock *vsk,
				struct msghdr *msg,
				size_t len)
{
	struct virtio_vsock_pkt_info data = {
		.vsk = vsk,
		.op = VIRTIO_VSOCK_OP_RW,
		.msg = msg,
		.pkt_len = len,
	};

	return virtio_transport_send_pkt_info(vsk, &data);
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);
/* Free the per-socket virtio state attached to @vsk. */
void virtio_transport_destruct(struct vsock_sock *vsk)
{
	kfree(vsk->trans);
}
EXPORT_SYMBOL_GPL(virtio_transport_destruct);
static int virtio_transport_reset(struct vsock_sock *vsk,
struct sk_buff *skb)
{
struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_RST,
.reply = !!skb,
.vsk = vsk,
};
/* Send RST only if the original pkt is not a RST pkt */
if (skb && le16_to_cpu(virtio_vsock_hdr(skb)->op) == VIRTIO_VSOCK_OP_RST)
return 0;
return virtio_transport_send_pkt_info(vsk, &info);
}
/* Normally packets are associated with a socket. There may be no socket if an
* attempt was made to connect to a socket that does not exist.
*/
static int virtio_transport_reset_no_sock(const struct virtio_transport *t,
struct sk_buff *skb)
{
struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_RST,
.type = le16_to_cpu(hdr->type),
.reply = true,
};
struct sk_buff *reply;
/* Send RST only if the original pkt is not a RST pkt */
if (le16_to_cpu(hdr->op) == VIRTIO_VSOCK_OP_RST)
return 0;
if (!t)
return -ENOTCONN;
reply = virtio_transport_alloc_skb(&info, 0,
le64_to_cpu(hdr->dst_cid),
le32_to_cpu(hdr->dst_port),
le64_to_cpu(hdr->src_cid),
le32_to_cpu(hdr->src_port));
if (!reply)
return -ENOMEM;
return t->send_pkt(reply);
}
/* Drop every packet still queued for receive and remove the socket.
 *
 * This function should be called with sk_lock held and SOCK_DONE set.
 */
static void virtio_transport_remove_sock(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *trans = vsk->trans;

	/* We don't need to take rx_lock, as the socket is closing and we are
	 * removing it.
	 */
	__skb_queue_purge(&trans->rx_queue);

	vsock_remove_sock(vsk);
}
/* Wait up to @timeout for SOCK_DONE to be set on @sk. The wait also ends when
 * a signal is pending or the timeout runs out. A zero @timeout returns
 * immediately.
 */
static void virtio_transport_wait_close(struct sock *sk, long timeout)
{
	DEFINE_WAIT_FUNC(wait, woken_wake_function);

	if (!timeout)
		return;

	add_wait_queue(sk_sleep(sk), &wait);

	do {
		if (sk_wait_event(sk, &timeout,
				  sock_flag(sk, SOCK_DONE), &wait))
			break;
	} while (!signal_pending(current) && timeout);

	remove_wait_queue(sk_sleep(sk), &wait);
}
/* Mark the connection as closed: set SOCK_DONE, record a full peer shutdown,
 * move to TCP_CLOSING when no data is left to read, and signal the state
 * change.
 *
 * If close work is scheduled, the socket is removed and the reference taken
 * for that work is dropped. With @cancel_timeout set, this happens only
 * when the delayed work could actually be cancelled.
 */
static void virtio_transport_do_close(struct vsock_sock *vsk,
				      bool cancel_timeout)
{
	struct sock *sk = sk_vsock(vsk);

	sock_set_flag(sk, SOCK_DONE);
	vsk->peer_shutdown = SHUTDOWN_MASK;
	if (vsock_stream_has_data(vsk) <= 0)
		sk->sk_state = TCP_CLOSING;
	sk->sk_state_change(sk);

	if (!vsk->close_work_scheduled)
		return;

	if (cancel_timeout && !cancel_delayed_work(&vsk->close_work))
		return;

	vsk->close_work_scheduled = false;

	virtio_transport_remove_sock(vsk);

	/* Release refcnt obtained when we scheduled the timeout */
	sock_put(sk);
}
/* Delayed-work handler armed by virtio_transport_close(): if the socket has
 * not reached SOCK_DONE by now, reset the connection and close it.
 */
static void virtio_transport_close_timeout(struct work_struct *work)
{
	struct vsock_sock *vsk;
	struct sock *sk;

	vsk = container_of(work, struct vsock_sock, close_work.work);
	sk = sk_vsock(vsk);

	/* Hold a reference for the duration of the handler */
	sock_hold(sk);
	lock_sock(sk);

	if (!sock_flag(sk, SOCK_DONE)) {
		(void)virtio_transport_reset(vsk, NULL);

		virtio_transport_do_close(vsk, false);
	}

	vsk->close_work_scheduled = false;

	release_sock(sk);
	sock_put(sk);
}
/* Begin closing a connection-oriented socket.
 *
 * Returns true when the caller may remove the socket right away, false when
 * removal has been deferred to the close timeout work scheduled here.
 *
 * User context, vsk->sk is locked.
 */
static bool virtio_transport_close(struct vsock_sock *vsk)
{
	struct sock *sk = &vsk->sk;

	/* Only established or closing connections need a handshake */
	if (sk->sk_state != TCP_ESTABLISHED && sk->sk_state != TCP_CLOSING)
		return true;

	/* Already received SHUTDOWN from peer, reply with RST */
	if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) {
		(void)virtio_transport_reset(vsk, NULL);
		return true;
	}

	if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
		(void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK);

	if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING))
		virtio_transport_wait_close(sk, sk->sk_lingertime);

	if (sock_flag(sk, SOCK_DONE))
		return true;

	/* This reference is dropped once the close work has run its course */
	sock_hold(sk);
	INIT_DELAYED_WORK(&vsk->close_work,
			  virtio_transport_close_timeout);
	vsk->close_work_scheduled = true;
	schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT);

	return false;
}
/* Release @vsk. STREAM and SEQPACKET sockets first go through the close
 * handshake, which may defer removal. Otherwise the socket is marked
 * SOCK_DONE and removed immediately.
 */
void virtio_transport_release(struct vsock_sock *vsk)
{
	struct sock *sk = &vsk->sk;
	bool remove_now = true;

	switch (sk->sk_type) {
	case SOCK_STREAM:
	case SOCK_SEQPACKET:
		remove_now = virtio_transport_close(vsk);
		break;
	default:
		break;
	}

	if (!remove_now)
		return;

	sock_set_flag(sk, SOCK_DONE);
	virtio_transport_remove_sock(vsk);
}
EXPORT_SYMBOL_GPL(virtio_transport_release);
/* Handle a packet received while @sk is connecting.
 *
 * RESPONSE completes the connection and INVALID is ignored; both return 0.
 * RST and any other operation tear the attempt down: a reset is sent in
 * reply, the socket moves to TCP_CLOSE and sk_err is reported (ECONNRESET
 * for RST, EPROTO otherwise). Returns 0 for RST and -EINVAL for unexpected
 * operations.
 */
static int
virtio_transport_recv_connecting(struct sock *sk,
				 struct sk_buff *skb)
{
	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
	struct vsock_sock *vsk = vsock_sk(sk);
	int sock_err;
	int ret;

	switch (le16_to_cpu(hdr->op)) {
	case VIRTIO_VSOCK_OP_RESPONSE:
		sk->sk_state = TCP_ESTABLISHED;
		sk->sk_socket->state = SS_CONNECTED;
		vsock_insert_connected(vsk);
		sk->sk_state_change(sk);
		return 0;
	case VIRTIO_VSOCK_OP_INVALID:
		return 0;
	case VIRTIO_VSOCK_OP_RST:
		sock_err = ECONNRESET;
		ret = 0;
		break;
	default:
		sock_err = EPROTO;
		ret = -EINVAL;
		break;
	}

	/* Connection attempt failed: reset the peer and report the error */
	virtio_transport_reset(vsk, skb);
	sk->sk_state = TCP_CLOSE;
	sk->sk_err = sock_err;
	sk_error_report(sk);

	return ret;
}
/* Queue the received data packet @skb on the socket's rx_queue.
 *
 * The packet is dropped (freed) when accounting its header length would push
 * rx_bytes past buf_alloc. A small packet may instead be merged into the last
 * queued packet, in which case @skb is freed after the copy. Otherwise @skb
 * itself is appended to rx_queue. All queue and counter updates happen under
 * rx_lock.
 */
static void
virtio_transport_recv_enqueue(struct vsock_sock *vsk,
struct sk_buff *skb)
{
struct virtio_vsock_sock *vvs = vsk->trans;
bool can_enqueue, free_pkt = false;
struct virtio_vsock_hdr *hdr;
u32 len;
hdr = virtio_vsock_hdr(skb);
len = le32_to_cpu(hdr->len);
spin_lock_bh(&vvs->rx_lock);
can_enqueue = virtio_transport_inc_rx_pkt(vvs, len);
if (!can_enqueue) {
/* No room left in the receive buffer: drop the packet */
free_pkt = true;
goto out;
}
/* An EOM flag means one more complete SEQPACKET message is available */
if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOM)
vvs->msg_count++;
/* Try to copy small packets into the buffer of last packet queued,
* to avoid wasting memory queueing the entire buffer with a small
* payload.
*/
if (len <= GOOD_COPY_LEN && !skb_queue_empty(&vvs->rx_queue)) {
struct virtio_vsock_hdr *last_hdr;
struct sk_buff *last_skb;
last_skb = skb_peek_tail(&vvs->rx_queue);
last_hdr = virtio_vsock_hdr(last_skb);
/* If there is space in the last packet queued, we copy the
* new packet in its buffer. We avoid this if the last packet
* queued has VIRTIO_VSOCK_SEQ_EOM set, because this is
* delimiter of SEQPACKET message, so 'pkt' is the first packet
* of a new message.
*/
if (skb->len < skb_tailroom(last_skb) &&
!(le32_to_cpu(last_hdr->flags) & VIRTIO_VSOCK_SEQ_EOM)) {
memcpy(skb_put(last_skb, skb->len), skb->data, skb->len);
free_pkt = true;
/* The merged packet inherits the new packet's flags and its
 * header length grows by the merged amount.
 */
last_hdr->flags |= hdr->flags;
le32_add_cpu(&last_hdr->len, len);
goto out;
}
}
__skb_queue_tail(&vvs->rx_queue, skb);
out:
spin_unlock_bh(&vvs->rx_lock);
/* Freed outside the lock; only set for dropped or merged packets */
if (free_pkt)
kfree_skb(skb);
}
static int
virtio_transport_recv_connected(struct sock *sk,
struct sk_buff *skb)
{
struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
struct vsock_sock *vsk = vsock_sk(sk);
int err = 0;
switch (le16_to_cpu(hdr->op)) {
case VIRTIO_VSOCK_OP_RW:
virtio_transport_recv_enqueue(vsk, skb);
vsock_data_ready(sk);
return err;
case VIRTIO_VSOCK_OP_CREDIT_REQUEST:
virtio_transport_send_credit_update(vsk);
break;
case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
sk->sk_write_space(sk);
break;
case VIRTIO_VSOCK_OP_SHUTDOWN:
if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
vsk->peer_shutdown |= RCV_SHUTDOWN;
if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
vsk->peer_shutdown |= SEND_SHUTDOWN;
if (vsk->peer_shutdown == SHUTDOWN_MASK &&
vsock_stream_has_data(vsk) <= 0 &&
!sock_flag(sk, SOCK_DONE)) {
(void)virtio_transport_reset(vsk, NULL);
virtio_transport_do_close(vsk, true);
}
if (le32_to_cpu(virtio_vsock_hdr(skb)->flags))
sk->sk_state_change(sk);
break;
case VIRTIO_VSOCK_OP_RST:
virtio_transport_do_close(vsk, true);
break;
default:
err = -EINVAL;
break;
}
kfree_skb(skb);
return err;
}
static void
virtio_transport_recv_disconnecting(struct sock *sk,
struct sk_buff *skb)
{
struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
struct vsock_sock *vsk = vsock_sk(sk);
if (le16_to_cpu(hdr->op) == VIRTIO_VSOCK_OP_RST)
virtio_transport_do_close(vsk, true);
}
static int
virtio_transport_send_response(struct vsock_sock *vsk,
struct sk_buff *skb)
{
struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_RESPONSE,
.remote_cid = le64_to_cpu(hdr->src_cid),
.remote_port = le32_to_cpu(hdr->src_port),
.reply = true,
.vsk = vsk,
};
return virtio_transport_send_pkt_info(vsk, &info);
}
/* Refresh the peer's credit state (buf_alloc, fwd_cnt) from the header of
 * @skb. Returns true when the peer has space available afterwards.
 */
static bool virtio_transport_space_update(struct sock *sk,
					  struct sk_buff *skb)
{
	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
	struct vsock_sock *vsk = vsock_sk(sk);
	struct virtio_vsock_sock *trans = vsk->trans;
	bool has_space;

	/* Listener sockets are not associated with any transport, so we are
	 * not able to take the state to see if there is space available in the
	 * remote peer, but since they are only used to receive requests, we
	 * can assume that there is always space available in the other peer.
	 */
	if (!trans)
		return true;

	/* buf_alloc and fwd_cnt is always included in the hdr */
	spin_lock_bh(&trans->tx_lock);
	trans->peer_buf_alloc = le32_to_cpu(hdr->buf_alloc);
	trans->peer_fwd_cnt = le32_to_cpu(hdr->fwd_cnt);
	has_space = virtio_transport_has_space(vsk);
	spin_unlock_bh(&trans->tx_lock);

	return has_space;
}
/* Handle server socket
 *
 * Process a packet received on listening socket @sk. Only REQUEST packets
 * are accepted: a connected child socket is created, bound to the addresses
 * found in the request, queued for accept() and a RESPONSE is sent back.
 * Every rejection answers the request with an RST.
 *
 * Returns 0 on success, -EINVAL for a non-REQUEST packet, -ENOMEM when the
 * accept queue is full or the child cannot be created, or the result of
 * vsock_assign_transport() when the child cannot use transport @t. That
 * last result is 0 if assignment succeeded but picked a different transport.
 */
static int
virtio_transport_recv_listen(struct sock *sk, struct sk_buff *skb,
			     struct virtio_transport *t)
{
	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
	struct vsock_sock *vsk = vsock_sk(sk);
	struct vsock_sock *vchild;
	struct sock *child;
	int ret;

	if (le16_to_cpu(hdr->op) != VIRTIO_VSOCK_OP_REQUEST) {
		virtio_transport_reset_no_sock(t, skb);
		return -EINVAL;
	}

	if (sk_acceptq_is_full(sk)) {
		virtio_transport_reset_no_sock(t, skb);
		return -ENOMEM;
	}

	child = vsock_create_connected(sk);
	if (!child) {
		virtio_transport_reset_no_sock(t, skb);
		return -ENOMEM;
	}

	lock_sock_nested(child, SINGLE_DEPTH_NESTING);

	child->sk_state = TCP_ESTABLISHED;

	vchild = vsock_sk(child);
	vsock_addr_init(&vchild->local_addr, le64_to_cpu(hdr->dst_cid),
			le32_to_cpu(hdr->dst_port));
	vsock_addr_init(&vchild->remote_addr, le64_to_cpu(hdr->src_cid),
			le32_to_cpu(hdr->src_port));

	ret = vsock_assign_transport(vchild, vsk);
	/* Transport assigned (looking at remote_addr) must be the same
	 * where we received the request.
	 */
	if (ret || vchild->transport != &t->transport) {
		release_sock(child);
		virtio_transport_reset_no_sock(t, skb);
		sock_put(child);
		return ret;
	}

	/* Count the child against the accept backlog only once it is certain
	 * to be queued. Doing it before the transport check would leave the
	 * backlog permanently incremented on the failure path above, which
	 * never undoes the accounting.
	 */
	sk_acceptq_added(sk);

	if (virtio_transport_space_update(child, skb))
		child->sk_write_space(child);

	vsock_insert_connected(vchild);
	vsock_enqueue_accept(sk, child);
	virtio_transport_send_response(vchild, skb);

	release_sock(child);

	sk->sk_data_ready(sk);
	return 0;
}
/* Tell whether @type is one of the socket types handled here:
 * VIRTIO_VSOCK_TYPE_STREAM or VIRTIO_VSOCK_TYPE_SEQPACKET.
 */
static bool virtio_transport_valid_type(u16 type)
{
	switch (type) {
	case VIRTIO_VSOCK_TYPE_STREAM:
	case VIRTIO_VSOCK_TYPE_SEQPACKET:
		return true;
	default:
		return false;
	}
}
/* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
* lock.
 *
 * Dispatch one received packet to the socket it is addressed to.
 *
 * The header of @skb supplies the source and destination addresses. The
 * packet type is validated, the destination socket is looked up (connected
 * table first, then bound table) and, with the socket locked, the packet is
 * handed to the handler matching the socket's sk_state.
 *
 * Whenever the packet cannot be delivered to a handler,
 * virtio_transport_reset_no_sock() is called for it and the skb is freed.
*/
void virtio_transport_recv_pkt(struct virtio_transport *t,
struct sk_buff *skb)
{
struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
struct sockaddr_vm src, dst;
struct vsock_sock *vsk;
struct sock *sk;
bool space_available;
/* Build the lookup addresses from the little-endian header fields. */
vsock_addr_init(&src, le64_to_cpu(hdr->src_cid),
le32_to_cpu(hdr->src_port));
vsock_addr_init(&dst, le64_to_cpu(hdr->dst_cid),
le32_to_cpu(hdr->dst_port));
trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port,
dst.svm_cid, dst.svm_port,
le32_to_cpu(hdr->len),
le16_to_cpu(hdr->type),
le16_to_cpu(hdr->op),
le32_to_cpu(hdr->flags),
le32_to_cpu(hdr->buf_alloc),
le32_to_cpu(hdr->fwd_cnt));
/* Reject packets whose type is neither stream nor seqpacket. */
if (!virtio_transport_valid_type(le16_to_cpu(hdr->type))) {
(void)virtio_transport_reset_no_sock(t, skb);
goto free_pkt;
}
/* The socket must be in connected or bound table
* otherwise send reset back
*/
sk = vsock_find_connected_socket(&src, &dst);
if (!sk) {
sk = vsock_find_bound_socket(&dst);
if (!sk) {
(void)virtio_transport_reset_no_sock(t, skb);
goto free_pkt;
}
}
/* The packet type must match the type of the socket that was found. */
if (virtio_transport_get_type(sk) != le16_to_cpu(hdr->type)) {
(void)virtio_transport_reset_no_sock(t, skb);
sock_put(sk);
goto free_pkt;
}
/* NOTE(review): this path frees the skb without a sock_put(sk) and
 * without sending a reset. Presumably intentional, since per the warning
 * text it is only reached with a zero refcount - confirm.
 */
if (!skb_set_owner_sk_safe(skb, sk)) {
WARN_ONCE(1, "receiving vsock socket has sk_refcnt == 0\n");
goto free_pkt;
}
vsk = vsock_sk(sk);
lock_sock(sk);
/* Check if sk has been closed before lock_sock */
if (sock_flag(sk, SOCK_DONE)) {
(void)virtio_transport_reset_no_sock(t, skb);
release_sock(sk);
sock_put(sk);
goto free_pkt;
}
/* Refresh the peer's credit from the header; the write-space callback is
 * only invoked further down, after the local CID has been updated.
 */
space_available = virtio_transport_space_update(sk, skb);
/* Update CID in case it has changed after a transport reset event */
if (vsk->local_addr.svm_cid != VMADDR_CID_ANY)
vsk->local_addr.svm_cid = dst.svm_cid;
if (space_available)
sk->sk_write_space(sk);
/* Every case below frees the skb itself, except TCP_ESTABLISHED where the
 * skb is passed on without a kfree_skb() here - presumably
 * virtio_transport_recv_connected() takes ownership; confirm against that
 * function.
 */
switch (sk->sk_state) {
case TCP_LISTEN:
virtio_transport_recv_listen(sk, skb, t);
kfree_skb(skb);
break;
case TCP_SYN_SENT:
virtio_transport_recv_connecting(sk, skb);
kfree_skb(skb);
break;
case TCP_ESTABLISHED:
virtio_transport_recv_connected(sk, skb);
break;
case TCP_CLOSING:
virtio_transport_recv_disconnecting(sk, skb);
kfree_skb(skb);
break;
default:
(void)virtio_transport_reset_no_sock(t, skb);
kfree_skb(skb);
break;
}
release_sock(sk);
/* Release refcnt obtained when we fetched this socket out of the
* bound or connected list.
*/
sock_put(sk);
return;
/* Common exit for packets dropped before reaching a state handler. */
free_pkt:
kfree_skb(skb);
}
EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);
/* Remove skbs found in a queue that have a vsk that matches.
*
* Each skb is freed.
*
* Returns the count of skbs that were reply packets.
*/
int virtio_transport_purge_skbs(void *vsk, struct sk_buff_head *queue)
{
struct sk_buff_head freeme;
struct sk_buff *skb, *tmp;
int cnt = 0;
skb_queue_head_init(&freeme);
spin_lock_bh(&queue->lock);
skb_queue_walk_safe(queue, skb, tmp) {
if (vsock_sk(skb->sk) != vsk)
continue;
__skb_unlink(skb, queue);
__skb_queue_tail(&freeme, skb);
if (virtio_vsock_skb_reply(skb))
cnt++;
}
spin_unlock_bh(&queue->lock);
__skb_queue_purge(&freeme);
return cnt;
}
EXPORT_SYMBOL_GPL(virtio_transport_purge_skbs);
/* Take one skb off the socket's rx queue and hand it to @recv_actor.
 *
 * The dequeue is done with MSG_DONTWAIT while rx_lock is held. The skb is
 * passed to @recv_actor after the lock is dropped and is not freed here:
 * ownership goes to the actor.
 *
 * Returns the value of @recv_actor, or the error reported by
 * __skb_recv_datagram() when no skb could be obtained.
 */
int virtio_transport_read_skb(struct vsock_sock *vsk, skb_read_actor_t recv_actor)
{
	struct sock *sk = sk_vsock(vsk);
	struct virtio_vsock_sock *vvs = vsk->trans;
	struct sk_buff *pkt;
	int offset = 0;
	int rc;

	spin_lock_bh(&vvs->rx_lock);
	/* __skb_recv_datagram() gives race-free handling of the receive and
	 * is not limited to datagram sockets.
	 */
	pkt = __skb_recv_datagram(sk, &vvs->rx_queue, MSG_DONTWAIT, &offset, &rc);
	spin_unlock_bh(&vvs->rx_lock);

	return pkt ? recv_actor(sk, pkt) : rc;
}
EXPORT_SYMBOL_GPL(virtio_transport_read_skb);
/* Module metadata: license, author and one-line description. */
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Asias He");
MODULE_DESCRIPTION("common code for virtio vsock");