Merge branch 'tls-rx-refactoring-part-2'
Jakub Kicinski says:

====================
tls: rx: random refactoring part 2

TLS Rx refactoring. Part 2 of 3. This one focusing on the main loop.
A couple of features to follow.
====================
This commit is contained in:
commit
516a2f1f6f
@ -152,7 +152,6 @@ struct tls_sw_context_rx {
|
||||
atomic_t decrypt_pending;
|
||||
/* protect crypto_wait with decrypt_pending*/
|
||||
spinlock_t decrypt_compl_lock;
|
||||
bool async_notify;
|
||||
};
|
||||
|
||||
struct tls_record_info {
|
||||
|
258
net/tls/tls_sw.c
258
net/tls/tls_sw.c
@ -44,6 +44,11 @@
|
||||
#include <net/strparser.h>
|
||||
#include <net/tls.h>
|
||||
|
||||
struct tls_decrypt_arg {
|
||||
bool zc;
|
||||
bool async;
|
||||
};
|
||||
|
||||
noinline void tls_err_abort(struct sock *sk, int err)
|
||||
{
|
||||
WARN_ON_ONCE(err >= 0);
|
||||
@ -168,7 +173,6 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
|
||||
struct scatterlist *sg;
|
||||
struct sk_buff *skb;
|
||||
unsigned int pages;
|
||||
int pending;
|
||||
|
||||
skb = (struct sk_buff *)req->data;
|
||||
tls_ctx = tls_get_ctx(skb->sk);
|
||||
@ -216,9 +220,7 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
|
||||
kfree(aead_req);
|
||||
|
||||
spin_lock_bh(&ctx->decrypt_compl_lock);
|
||||
pending = atomic_dec_return(&ctx->decrypt_pending);
|
||||
|
||||
if (!pending && ctx->async_notify)
|
||||
if (!atomic_dec_return(&ctx->decrypt_pending))
|
||||
complete(&ctx->async_wait.completion);
|
||||
spin_unlock_bh(&ctx->decrypt_compl_lock);
|
||||
}
|
||||
@ -1345,15 +1347,14 @@ static struct sk_buff *tls_wait_data(struct sock *sk, struct sk_psock *psock,
|
||||
return skb;
|
||||
}
|
||||
|
||||
static int tls_setup_from_iter(struct sock *sk, struct iov_iter *from,
|
||||
static int tls_setup_from_iter(struct iov_iter *from,
|
||||
int length, int *pages_used,
|
||||
unsigned int *size_used,
|
||||
struct scatterlist *to,
|
||||
int to_max_pages)
|
||||
{
|
||||
int rc = 0, i = 0, num_elem = *pages_used, maxpages;
|
||||
struct page *pages[MAX_SKB_FRAGS];
|
||||
unsigned int size = *size_used;
|
||||
unsigned int size = 0;
|
||||
ssize_t copied, use;
|
||||
size_t offset;
|
||||
|
||||
@ -1396,8 +1397,7 @@ static int tls_setup_from_iter(struct sock *sk, struct iov_iter *from,
|
||||
sg_mark_end(&to[num_elem - 1]);
|
||||
out:
|
||||
if (rc)
|
||||
iov_iter_revert(from, size - *size_used);
|
||||
*size_used = size;
|
||||
iov_iter_revert(from, size);
|
||||
*pages_used = num_elem;
|
||||
|
||||
return rc;
|
||||
@ -1414,7 +1414,7 @@ out:
|
||||
static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
|
||||
struct iov_iter *out_iov,
|
||||
struct scatterlist *out_sg,
|
||||
int *chunk, bool *zc, bool async)
|
||||
struct tls_decrypt_arg *darg)
|
||||
{
|
||||
struct tls_context *tls_ctx = tls_get_ctx(sk);
|
||||
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
|
||||
@ -1431,7 +1431,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
|
||||
prot->tail_size;
|
||||
int iv_offset = 0;
|
||||
|
||||
if (*zc && (out_iov || out_sg)) {
|
||||
if (darg->zc && (out_iov || out_sg)) {
|
||||
if (out_iov)
|
||||
n_sgout = 1 +
|
||||
iov_iter_npages_cap(out_iov, INT_MAX, data_len);
|
||||
@ -1441,7 +1441,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
|
||||
rxm->full_len - prot->prepend_size);
|
||||
} else {
|
||||
n_sgout = 0;
|
||||
*zc = false;
|
||||
darg->zc = false;
|
||||
n_sgin = skb_cow_data(skb, 0, &unused);
|
||||
}
|
||||
|
||||
@ -1523,9 +1523,8 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
|
||||
sg_init_table(sgout, n_sgout);
|
||||
sg_set_buf(&sgout[0], aad, prot->aad_size);
|
||||
|
||||
*chunk = 0;
|
||||
err = tls_setup_from_iter(sk, out_iov, data_len,
|
||||
&pages, chunk, &sgout[1],
|
||||
err = tls_setup_from_iter(out_iov, data_len,
|
||||
&pages, &sgout[1],
|
||||
(n_sgout - 1));
|
||||
if (err < 0)
|
||||
goto fallback_to_reg_recv;
|
||||
@ -1538,13 +1537,12 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
|
||||
fallback_to_reg_recv:
|
||||
sgout = sgin;
|
||||
pages = 0;
|
||||
*chunk = data_len;
|
||||
*zc = false;
|
||||
darg->zc = false;
|
||||
}
|
||||
|
||||
/* Prepare and submit AEAD request */
|
||||
err = tls_do_decryption(sk, skb, sgin, sgout, iv,
|
||||
data_len, aead_req, async);
|
||||
data_len, aead_req, darg->async);
|
||||
if (err == -EINPROGRESS)
|
||||
return err;
|
||||
|
||||
@ -1557,8 +1555,8 @@ fallback_to_reg_recv:
|
||||
}
|
||||
|
||||
static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
|
||||
struct iov_iter *dest, int *chunk, bool *zc,
|
||||
bool async)
|
||||
struct iov_iter *dest,
|
||||
struct tls_decrypt_arg *darg)
|
||||
{
|
||||
struct tls_context *tls_ctx = tls_get_ctx(sk);
|
||||
struct tls_prot_info *prot = &tls_ctx->prot_info;
|
||||
@ -1567,7 +1565,7 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
|
||||
int pad, err;
|
||||
|
||||
if (tlm->decrypted) {
|
||||
*zc = false;
|
||||
darg->zc = false;
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -1577,12 +1575,12 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
|
||||
return err;
|
||||
if (err > 0) {
|
||||
tlm->decrypted = 1;
|
||||
*zc = false;
|
||||
darg->zc = false;
|
||||
goto decrypt_done;
|
||||
}
|
||||
}
|
||||
|
||||
err = decrypt_internal(sk, skb, dest, NULL, chunk, zc, async);
|
||||
err = decrypt_internal(sk, skb, dest, NULL, darg);
|
||||
if (err < 0) {
|
||||
if (err == -EINPROGRESS)
|
||||
tls_advance_record_sn(sk, prot, &tls_ctx->rx);
|
||||
@ -1608,34 +1606,32 @@ decrypt_done:
|
||||
int decrypt_skb(struct sock *sk, struct sk_buff *skb,
|
||||
struct scatterlist *sgout)
|
||||
{
|
||||
bool zc = true;
|
||||
int chunk;
|
||||
struct tls_decrypt_arg darg = { .zc = true, };
|
||||
|
||||
return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc, false);
|
||||
return decrypt_internal(sk, skb, NULL, sgout, &darg);
|
||||
}
|
||||
|
||||
static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
|
||||
unsigned int len)
|
||||
static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm,
|
||||
u8 *control)
|
||||
{
|
||||
struct tls_context *tls_ctx = tls_get_ctx(sk);
|
||||
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
|
||||
int err;
|
||||
|
||||
if (skb) {
|
||||
struct strp_msg *rxm = strp_msg(skb);
|
||||
if (!*control) {
|
||||
*control = tlm->control;
|
||||
if (!*control)
|
||||
return -EBADMSG;
|
||||
|
||||
if (len < rxm->full_len) {
|
||||
rxm->offset += len;
|
||||
rxm->full_len -= len;
|
||||
return false;
|
||||
err = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
|
||||
sizeof(*control), control);
|
||||
if (*control != TLS_RECORD_TYPE_DATA) {
|
||||
if (err || msg->msg_flags & MSG_CTRUNC)
|
||||
return -EIO;
|
||||
}
|
||||
consume_skb(skb);
|
||||
} else if (*control != tlm->control) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
/* Finished with message */
|
||||
ctx->recv_pkt = NULL;
|
||||
__strp_unpause(&ctx->strp);
|
||||
|
||||
return true;
|
||||
return 1;
|
||||
}
|
||||
|
||||
/* This function traverses the rx_list in tls receive context to copies the
|
||||
@ -1646,31 +1642,23 @@ static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
|
||||
static int process_rx_list(struct tls_sw_context_rx *ctx,
|
||||
struct msghdr *msg,
|
||||
u8 *control,
|
||||
bool *cmsg,
|
||||
size_t skip,
|
||||
size_t len,
|
||||
bool zc,
|
||||
bool is_peek)
|
||||
{
|
||||
struct sk_buff *skb = skb_peek(&ctx->rx_list);
|
||||
u8 ctrl = *control;
|
||||
u8 msgc = *cmsg;
|
||||
struct tls_msg *tlm;
|
||||
ssize_t copied = 0;
|
||||
|
||||
/* Set the record type in 'control' if caller didn't pass it */
|
||||
if (!ctrl && skb) {
|
||||
tlm = tls_msg(skb);
|
||||
ctrl = tlm->control;
|
||||
}
|
||||
int err;
|
||||
|
||||
while (skip && skb) {
|
||||
struct strp_msg *rxm = strp_msg(skb);
|
||||
tlm = tls_msg(skb);
|
||||
|
||||
/* Cannot process a record of different type */
|
||||
if (ctrl != tlm->control)
|
||||
return 0;
|
||||
err = tls_record_content_type(msg, tlm, control);
|
||||
if (err <= 0)
|
||||
return err;
|
||||
|
||||
if (skip < rxm->full_len)
|
||||
break;
|
||||
@ -1686,27 +1674,12 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,
|
||||
|
||||
tlm = tls_msg(skb);
|
||||
|
||||
/* Cannot process a record of different type */
|
||||
if (ctrl != tlm->control)
|
||||
return 0;
|
||||
|
||||
/* Set record type if not already done. For a non-data record,
|
||||
* do not proceed if record type could not be copied.
|
||||
*/
|
||||
if (!msgc) {
|
||||
int cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
|
||||
sizeof(ctrl), &ctrl);
|
||||
msgc = true;
|
||||
if (ctrl != TLS_RECORD_TYPE_DATA) {
|
||||
if (cerr || msg->msg_flags & MSG_CTRUNC)
|
||||
return -EIO;
|
||||
|
||||
*cmsg = msgc;
|
||||
}
|
||||
}
|
||||
err = tls_record_content_type(msg, tlm, control);
|
||||
if (err <= 0)
|
||||
return err;
|
||||
|
||||
if (!zc || (rxm->full_len - skip) > len) {
|
||||
int err = skb_copy_datagram_msg(skb, rxm->offset + skip,
|
||||
err = skb_copy_datagram_msg(skb, rxm->offset + skip,
|
||||
msg, chunk);
|
||||
if (err < 0)
|
||||
return err;
|
||||
@ -1743,7 +1716,6 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,
|
||||
skb = next_skb;
|
||||
}
|
||||
|
||||
*control = ctrl;
|
||||
return copied;
|
||||
}
|
||||
|
||||
@ -1758,19 +1730,19 @@ int tls_sw_recvmsg(struct sock *sk,
|
||||
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
|
||||
struct tls_prot_info *prot = &tls_ctx->prot_info;
|
||||
struct sk_psock *psock;
|
||||
int num_async, pending;
|
||||
unsigned char control = 0;
|
||||
ssize_t decrypted = 0;
|
||||
struct strp_msg *rxm;
|
||||
struct tls_msg *tlm;
|
||||
struct sk_buff *skb;
|
||||
ssize_t copied = 0;
|
||||
bool cmsg = false;
|
||||
bool async = false;
|
||||
int target, err = 0;
|
||||
long timeo;
|
||||
bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
|
||||
bool is_peek = flags & MSG_PEEK;
|
||||
bool bpf_strp_enabled;
|
||||
bool zc_capable;
|
||||
|
||||
flags |= nonblock;
|
||||
|
||||
@ -1782,8 +1754,7 @@ int tls_sw_recvmsg(struct sock *sk,
|
||||
bpf_strp_enabled = sk_psock_strp_enabled(psock);
|
||||
|
||||
/* Process pending decrypted records. It must be non-zero-copy */
|
||||
err = process_rx_list(ctx, msg, &control, &cmsg, 0, len, false,
|
||||
is_peek);
|
||||
err = process_rx_list(ctx, msg, &control, 0, len, false, is_peek);
|
||||
if (err < 0) {
|
||||
tls_err_abort(sk, err);
|
||||
goto end;
|
||||
@ -1797,15 +1768,12 @@ int tls_sw_recvmsg(struct sock *sk,
|
||||
len = len - copied;
|
||||
timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
|
||||
|
||||
zc_capable = !bpf_strp_enabled && !is_kvec && !is_peek &&
|
||||
prot->version != TLS_1_3_VERSION;
|
||||
decrypted = 0;
|
||||
num_async = 0;
|
||||
while (len && (decrypted + copied < target || ctx->recv_pkt)) {
|
||||
bool retain_skb = false;
|
||||
bool zc = false;
|
||||
int to_decrypt;
|
||||
int chunk = 0;
|
||||
bool async_capable;
|
||||
bool async = false;
|
||||
struct tls_decrypt_arg darg = {};
|
||||
int to_decrypt, chunk;
|
||||
|
||||
skb = tls_wait_data(sk, psock, flags & MSG_DONTWAIT, timeo, &err);
|
||||
if (!skb) {
|
||||
@ -1827,29 +1795,24 @@ int tls_sw_recvmsg(struct sock *sk,
|
||||
|
||||
to_decrypt = rxm->full_len - prot->overhead_size;
|
||||
|
||||
if (to_decrypt <= len && !is_kvec && !is_peek &&
|
||||
tlm->control == TLS_RECORD_TYPE_DATA &&
|
||||
prot->version != TLS_1_3_VERSION &&
|
||||
!bpf_strp_enabled)
|
||||
zc = true;
|
||||
if (zc_capable && to_decrypt <= len &&
|
||||
tlm->control == TLS_RECORD_TYPE_DATA)
|
||||
darg.zc = true;
|
||||
|
||||
/* Do not use async mode if record is non-data */
|
||||
if (tlm->control == TLS_RECORD_TYPE_DATA && !bpf_strp_enabled)
|
||||
async_capable = ctx->async_capable;
|
||||
darg.async = ctx->async_capable;
|
||||
else
|
||||
async_capable = false;
|
||||
darg.async = false;
|
||||
|
||||
err = decrypt_skb_update(sk, skb, &msg->msg_iter,
|
||||
&chunk, &zc, async_capable);
|
||||
err = decrypt_skb_update(sk, skb, &msg->msg_iter, &darg);
|
||||
if (err < 0 && err != -EINPROGRESS) {
|
||||
tls_err_abort(sk, -EBADMSG);
|
||||
goto recv_end;
|
||||
}
|
||||
|
||||
if (err == -EINPROGRESS) {
|
||||
if (err == -EINPROGRESS)
|
||||
async = true;
|
||||
num_async++;
|
||||
}
|
||||
|
||||
/* If the type of records being processed is not known yet,
|
||||
* set it to record type just dequeued. If it is already known,
|
||||
@ -1858,92 +1821,79 @@ int tls_sw_recvmsg(struct sock *sk,
|
||||
* is known just after record is dequeued from stream parser.
|
||||
* For tls1.3, we disable async.
|
||||
*/
|
||||
|
||||
if (!control)
|
||||
control = tlm->control;
|
||||
else if (control != tlm->control)
|
||||
err = tls_record_content_type(msg, tlm, &control);
|
||||
if (err <= 0)
|
||||
goto recv_end;
|
||||
|
||||
if (!cmsg) {
|
||||
int cerr;
|
||||
ctx->recv_pkt = NULL;
|
||||
__strp_unpause(&ctx->strp);
|
||||
skb_queue_tail(&ctx->rx_list, skb);
|
||||
|
||||
cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
|
||||
sizeof(control), &control);
|
||||
cmsg = true;
|
||||
if (control != TLS_RECORD_TYPE_DATA) {
|
||||
if (cerr || msg->msg_flags & MSG_CTRUNC) {
|
||||
err = -EIO;
|
||||
goto recv_end;
|
||||
}
|
||||
}
|
||||
if (async) {
|
||||
/* TLS 1.2-only, to_decrypt must be text length */
|
||||
chunk = min_t(int, to_decrypt, len);
|
||||
leave_on_list:
|
||||
decrypted += chunk;
|
||||
len -= chunk;
|
||||
continue;
|
||||
}
|
||||
/* TLS 1.3 may have updated the length by more than overhead */
|
||||
chunk = rxm->full_len;
|
||||
|
||||
if (async)
|
||||
goto pick_next_record;
|
||||
if (!darg.zc) {
|
||||
bool partially_consumed = chunk > len;
|
||||
|
||||
if (!zc) {
|
||||
if (bpf_strp_enabled) {
|
||||
err = sk_psock_tls_strp_read(psock, skb);
|
||||
if (err != __SK_PASS) {
|
||||
rxm->offset = rxm->offset + rxm->full_len;
|
||||
rxm->full_len = 0;
|
||||
skb_unlink(skb, &ctx->rx_list);
|
||||
if (err == __SK_DROP)
|
||||
consume_skb(skb);
|
||||
ctx->recv_pkt = NULL;
|
||||
__strp_unpause(&ctx->strp);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (rxm->full_len > len) {
|
||||
retain_skb = true;
|
||||
if (partially_consumed)
|
||||
chunk = len;
|
||||
} else {
|
||||
chunk = rxm->full_len;
|
||||
}
|
||||
|
||||
err = skb_copy_datagram_msg(skb, rxm->offset,
|
||||
msg, chunk);
|
||||
if (err < 0)
|
||||
goto recv_end;
|
||||
|
||||
if (!is_peek) {
|
||||
rxm->offset = rxm->offset + chunk;
|
||||
rxm->full_len = rxm->full_len - chunk;
|
||||
if (is_peek)
|
||||
goto leave_on_list;
|
||||
|
||||
if (partially_consumed) {
|
||||
rxm->offset += chunk;
|
||||
rxm->full_len -= chunk;
|
||||
goto leave_on_list;
|
||||
}
|
||||
}
|
||||
|
||||
pick_next_record:
|
||||
if (chunk > len)
|
||||
chunk = len;
|
||||
|
||||
decrypted += chunk;
|
||||
len -= chunk;
|
||||
|
||||
/* For async or peek case, queue the current skb */
|
||||
if (async || is_peek || retain_skb) {
|
||||
skb_queue_tail(&ctx->rx_list, skb);
|
||||
skb = NULL;
|
||||
}
|
||||
skb_unlink(skb, &ctx->rx_list);
|
||||
consume_skb(skb);
|
||||
|
||||
if (tls_sw_advance_skb(sk, skb, chunk)) {
|
||||
/* Return full control message to
|
||||
* userspace before trying to parse
|
||||
* another message type
|
||||
*/
|
||||
msg->msg_flags |= MSG_EOR;
|
||||
if (control != TLS_RECORD_TYPE_DATA)
|
||||
goto recv_end;
|
||||
} else {
|
||||
/* Return full control message to userspace before trying
|
||||
* to parse another message type
|
||||
*/
|
||||
msg->msg_flags |= MSG_EOR;
|
||||
if (control != TLS_RECORD_TYPE_DATA)
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
recv_end:
|
||||
if (num_async) {
|
||||
if (async) {
|
||||
int pending;
|
||||
|
||||
/* Wait for all previously submitted records to be decrypted */
|
||||
spin_lock_bh(&ctx->decrypt_compl_lock);
|
||||
ctx->async_notify = true;
|
||||
reinit_completion(&ctx->async_wait.completion);
|
||||
pending = atomic_read(&ctx->decrypt_pending);
|
||||
spin_unlock_bh(&ctx->decrypt_compl_lock);
|
||||
if (pending) {
|
||||
@ -1955,21 +1905,14 @@ recv_end:
|
||||
decrypted = 0;
|
||||
goto end;
|
||||
}
|
||||
} else {
|
||||
reinit_completion(&ctx->async_wait.completion);
|
||||
}
|
||||
|
||||
/* There can be no concurrent accesses, since we have no
|
||||
* pending decrypt operations
|
||||
*/
|
||||
WRITE_ONCE(ctx->async_notify, false);
|
||||
|
||||
/* Drain records from the rx_list & copy if required */
|
||||
if (is_peek || is_kvec)
|
||||
err = process_rx_list(ctx, msg, &control, &cmsg, copied,
|
||||
err = process_rx_list(ctx, msg, &control, copied,
|
||||
decrypted, false, is_peek);
|
||||
else
|
||||
err = process_rx_list(ctx, msg, &control, &cmsg, 0,
|
||||
err = process_rx_list(ctx, msg, &control, 0,
|
||||
decrypted, true, is_peek);
|
||||
if (err < 0) {
|
||||
tls_err_abort(sk, err);
|
||||
@ -2003,7 +1946,6 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
|
||||
int err = 0;
|
||||
long timeo;
|
||||
int chunk;
|
||||
bool zc = false;
|
||||
|
||||
lock_sock(sk);
|
||||
|
||||
@ -2013,12 +1955,14 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
|
||||
if (from_queue) {
|
||||
skb = __skb_dequeue(&ctx->rx_list);
|
||||
} else {
|
||||
struct tls_decrypt_arg darg = {};
|
||||
|
||||
skb = tls_wait_data(sk, NULL, flags & SPLICE_F_NONBLOCK, timeo,
|
||||
&err);
|
||||
if (!skb)
|
||||
goto splice_read_end;
|
||||
|
||||
err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc, false);
|
||||
err = decrypt_skb_update(sk, skb, NULL, &darg);
|
||||
if (err < 0) {
|
||||
tls_err_abort(sk, -EBADMSG);
|
||||
goto splice_read_end;
|
||||
|
Loading…
x
Reference in New Issue
Block a user