diff --git a/include/net/strparser.h b/include/net/strparser.h
index 88900b05443e..41e2ce9e9e10 100644
--- a/include/net/strparser.h
+++ b/include/net/strparser.h
@@ -72,7 +72,6 @@ struct sk_skb_cb {
 	/* strp users' data follows */
 	struct tls_msg {
 		u8 control;
-		u8 decrypted;
 	} tls;
 	/* temp_reg is a temporary register used for bpf_convert_data_end_access
 	 * when dst_reg == src_reg.
diff --git a/include/net/tls.h b/include/net/tls.h
index 8742e13bc362..181c496b01b8 100644
--- a/include/net/tls.h
+++ b/include/net/tls.h
@@ -116,11 +116,15 @@ struct tls_sw_context_rx {
 	void (*saved_data_ready)(struct sock *sk);
 
 	struct sk_buff *recv_pkt;
+	u8 reader_present;
 	u8 async_capable:1;
 	u8 zc_capable:1;
+	u8 reader_contended:1;
 	atomic_t decrypt_pending;
 	/* protect crypto_wait with decrypt_pending*/
 	spinlock_t decrypt_compl_lock;
+	struct sk_buff_head async_hold;
+	struct wait_queue_head wq;
 };
 
 struct tls_record_info {
diff --git a/net/tls/Makefile b/net/tls/Makefile
index f1ffbfe8968d..e41c800489ac 100644
--- a/net/tls/Makefile
+++ b/net/tls/Makefile
@@ -7,7 +7,7 @@ CFLAGS_trace.o := -I$(src)
 
 obj-$(CONFIG_TLS) += tls.o
 
-tls-y := tls_main.o tls_sw.o tls_proc.o trace.o
+tls-y := tls_main.o tls_sw.o tls_proc.o trace.o tls_strp.o
 
 tls-$(CONFIG_TLS_TOE) += tls_toe.o
 tls-$(CONFIG_TLS_DEVICE) += tls_device.o tls_device_fallback.o
diff --git a/net/tls/tls.h b/net/tls/tls.h
index e0ccc96a0850..3740740504e3 100644
--- a/net/tls/tls.h
+++ b/net/tls/tls.h
@@ -39,6 +39,9 @@
 #include <linux/skmsg.h>
 #include <net/tls.h>
 
+#define TLS_PAGE_ORDER	(min_t(unsigned int, PAGE_ALLOC_COSTLY_ORDER,	\
+			       TLS_MAX_PAYLOAD_SIZE >> PAGE_SHIFT))
+
 #define __TLS_INC_STATS(net, field)				\
 	__SNMP_INC_STATS((net)->mib.tls_statistics, field)
 #define TLS_INC_STATS(net, field)				\
@@ -118,13 +121,15 @@ void tls_device_write_space(struct sock *sk, struct tls_context *ctx);
 int tls_process_cmsg(struct sock *sk, struct msghdr *msg,
 		     unsigned char *record_type);
 
-int decrypt_skb(struct sock *sk, struct sk_buff *skb,
-		struct scatterlist *sgout);
+int decrypt_skb(struct sock *sk, struct scatterlist *sgout);
 
 int tls_sw_fallback_init(struct sock *sk,
 			 struct tls_offload_context_tx *offload_ctx,
 			 struct tls_crypto_info *crypto_info);
 
+int tls_strp_msg_hold(struct sock *sk, struct sk_buff *skb,
+		      struct sk_buff_head *dst);
+
 static inline struct tls_msg *tls_msg(struct sk_buff *skb)
 {
 	struct sk_skb_cb *scb = (struct sk_skb_cb *)skb->cb;
@@ -132,6 +137,11 @@ static inline struct tls_msg *tls_msg(struct sk_buff *skb)
 	return &scb->tls;
 }
 
+static inline struct sk_buff *tls_strp_msg(struct tls_sw_context_rx *ctx)
+{
+	return ctx->recv_pkt;
+}
+
 #ifdef CONFIG_TLS_DEVICE
 int tls_device_init(void);
 void tls_device_cleanup(void);
@@ -140,8 +150,7 @@ void tls_device_free_resources_tx(struct sock *sk);
 int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx);
 void tls_device_offload_cleanup_rx(struct sock *sk);
 void tls_device_rx_resync_new_rec(struct sock *sk, u32 rcd_len, u32 seq);
-int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx,
-			 struct sk_buff *skb, struct strp_msg *rxm);
+int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx);
#else
 static inline int tls_device_init(void) { return 0; }
 static inline void tls_device_cleanup(void) {}
@@ -165,8 +174,7 @@ static inline void
 tls_device_rx_resync_new_rec(struct sock *sk, u32 rcd_len, u32 seq) {}
 
 static inline int
-tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx,
-		     struct sk_buff *skb, struct strp_msg *rxm)
+tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx)
 {
 	return 0;
 }
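Aside on TLS_PAGE_ORDER above: it bounds the page order handed to alloc_skb_with_frags() in tls_alloc_clrtxt_skb() further down. A standalone sketch of the arithmetic, not part of the patch — the constant values are assumptions matching a typical x86-64 build (PAGE_SHIFT = 12, PAGE_ALLOC_COSTLY_ORDER = 3) and the uapi TLS_MAX_PAYLOAD_SIZE of 1 << 14:

#include <stdio.h>

#define PAGE_SHIFT		12	/* assumed: 4 KiB pages */
#define PAGE_ALLOC_COSTLY_ORDER	3	/* assumed: mm's "costly" threshold */
#define TLS_MAX_PAYLOAD_SIZE	(1 << 14)	/* 16 KiB max TLS record */

#define MIN(a, b)	((a) < (b) ? (a) : (b))
#define TLS_PAGE_ORDER	MIN(PAGE_ALLOC_COSTLY_ORDER, \
			    TLS_MAX_PAYLOAD_SIZE >> PAGE_SHIFT)

int main(void)
{
	/* TLS_MAX_PAYLOAD_SIZE >> PAGE_SHIFT == 4 page-sized chunks;
	 * the min() against PAGE_ALLOC_COSTLY_ORDER caps the order at 3,
	 * so fragment allocations never cross the "costly" threshold.
	 */
	printf("TLS_PAGE_ORDER = %d\n", TLS_PAGE_ORDER);	/* prints 3 */
	return 0;
}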
diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c
index 6f9dff1c11f6..6abbe3c2520c 100644
--- a/net/tls/tls_device.c
+++ b/net/tls/tls_device.c
@@ -889,14 +889,19 @@ static void tls_device_core_ctrl_rx_resync(struct tls_context *tls_ctx,
 	}
 }
 
-static int tls_device_reencrypt(struct sock *sk, struct sk_buff *skb)
+static int
+tls_device_reencrypt(struct sock *sk, struct tls_sw_context_rx *sw_ctx)
 {
-	struct strp_msg *rxm = strp_msg(skb);
-	int err = 0, offset = rxm->offset, copy, nsg, data_len, pos;
-	struct sk_buff *skb_iter, *unused;
+	int err = 0, offset, copy, nsg, data_len, pos;
+	struct sk_buff *skb, *skb_iter, *unused;
 	struct scatterlist sg[1];
+	struct strp_msg *rxm;
 	char *orig_buf, *buf;
 
+	skb = tls_strp_msg(sw_ctx);
+	rxm = strp_msg(skb);
+	offset = rxm->offset;
+
 	orig_buf = kmalloc(rxm->full_len + TLS_HEADER_SIZE +
 			   TLS_CIPHER_AES_GCM_128_IV_SIZE, sk->sk_allocation);
 	if (!orig_buf)
@@ -919,7 +924,7 @@ static int tls_device_reencrypt(struct sock *sk, struct sk_buff *skb)
 		goto free_buf;
 
 	/* We are interested only in the decrypted data not the auth */
-	err = decrypt_skb(sk, skb, sg);
+	err = decrypt_skb(sk, sg);
 	if (err != -EBADMSG)
 		goto free_buf;
 	else
@@ -974,10 +979,12 @@ free_buf:
 	return err;
 }
 
-int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx,
-			 struct sk_buff *skb, struct strp_msg *rxm)
+int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx)
 {
 	struct tls_offload_context_rx *ctx = tls_offload_ctx_rx(tls_ctx);
+	struct tls_sw_context_rx *sw_ctx = tls_sw_ctx_rx(tls_ctx);
+	struct sk_buff *skb = tls_strp_msg(sw_ctx);
+	struct strp_msg *rxm = strp_msg(skb);
 	int is_decrypted = skb->decrypted;
 	int is_encrypted = !is_decrypted;
 	struct sk_buff *skb_iter;
@@ -1000,7 +1007,7 @@ int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx,
 	 * likely have initial fragments decrypted, and final ones not
 	 * decrypted. We need to reencrypt that single SKB.
 	 */
-	return tls_device_reencrypt(sk, skb);
+	return tls_device_reencrypt(sk, sw_ctx);
 }
 
 /* Return immediately if the record is either entirely plaintext or
@@ -1017,7 +1024,7 @@ int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx,
 	}
 
 	ctx->resync_nh_reset = 1;
-	return tls_device_reencrypt(sk, skb);
+	return tls_device_reencrypt(sk, sw_ctx);
 }
 
 static void tls_device_attach(struct tls_context *ctx, struct sock *sk,
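With the skb argument dropped from the device path, everything that needs the current input record now reaches it through the rx context. A hypothetical helper showing the access pattern (a sketch, not part of the patch; example_cur_rxm() is invented for illustration):

/* Sketch: the record strparser produced is fetched via the accessor
 * rather than being threaded through every call as an argument,
 * presumably so the record's storage can later change without
 * touching every signature.
 */
static struct strp_msg *example_cur_rxm(struct tls_sw_context_rx *ctx)
{
	return strp_msg(tls_strp_msg(ctx));
}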
diff --git a/net/tls/tls_strp.c b/net/tls/tls_strp.c
new file mode 100644
index 000000000000..9ccab79a6e1e
--- /dev/null
+++ b/net/tls/tls_strp.c
@@ -0,0 +1,17 @@
+// SPDX-License-Identifier: GPL-2.0-only
+
+#include <linux/skbuff.h>
+
+#include "tls.h"
+
+int tls_strp_msg_hold(struct sock *sk, struct sk_buff *skb,
+		      struct sk_buff_head *dst)
+{
+	struct sk_buff *clone;
+
+	clone = skb_clone(skb, sk->sk_allocation);
+	if (!clone)
+		return -ENOMEM;
+	__skb_queue_tail(dst, clone);
+	return 0;
+}
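The new helper is small but central to the async-decrypt changes: the input skb must outlive the crypto request. A sketch of the intended use (hypothetical caller; the real call site is in tls_decrypt_sg() below):

/* Sketch: park a clone of the input record on the context's hold
 * queue before kicking off async crypto, so the record survives
 * until tls_sw_recvmsg() purges the queue after waiting for all
 * pending completions.
 */
static int example_begin_async(struct sock *sk,
			       struct tls_sw_context_rx *ctx,
			       struct sk_buff *skb)
{
	int err;

	err = tls_strp_msg_hold(sk, skb, &ctx->async_hold);
	if (err)
		return err;	/* -ENOMEM: caller falls back or aborts */

	/* ... submit the aead_request with darg->async set ... */
	return 0;
}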
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index 68d79ee48a56..859ea02022c0 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -47,9 +47,13 @@
 #include "tls.h"
 
 struct tls_decrypt_arg {
+	struct_group(inargs,
 	bool zc;
 	bool async;
 	u8 tail;
+	);
+
+	struct sk_buff *skb;
 };
 
 struct tls_decrypt_ctx {
@@ -180,39 +184,22 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
 	struct scatterlist *sgin = aead_req->src;
 	struct tls_sw_context_rx *ctx;
 	struct tls_context *tls_ctx;
-	struct tls_prot_info *prot;
 	struct scatterlist *sg;
-	struct sk_buff *skb;
 	unsigned int pages;
+	struct sock *sk;
 
-	skb = (struct sk_buff *)req->data;
-	tls_ctx = tls_get_ctx(skb->sk);
+	sk = (struct sock *)req->data;
+	tls_ctx = tls_get_ctx(sk);
 	ctx = tls_sw_ctx_rx(tls_ctx);
-	prot = &tls_ctx->prot_info;
 
 	/* Propagate if there was an err */
 	if (err) {
 		if (err == -EBADMSG)
-			TLS_INC_STATS(sock_net(skb->sk),
-				      LINUX_MIB_TLSDECRYPTERROR);
+			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR);
 		ctx->async_wait.err = err;
-		tls_err_abort(skb->sk, err);
-	} else {
-		struct strp_msg *rxm = strp_msg(skb);
-
-		/* No TLS 1.3 support with async crypto */
-		WARN_ON(prot->tail_size);
-
-		rxm->offset += prot->prepend_size;
-		rxm->full_len -= prot->overhead_size;
+		tls_err_abort(sk, err);
 	}
 
-	/* After using skb->sk to propagate sk through crypto async callback
-	 * we need to NULL it again.
-	 */
-	skb->sk = NULL;
-
-	/* Free the destination pages if skb was not decrypted inplace */
 	if (sgout != sgin) {
 		/* Skip the first S/G entry as it points to AAD */
@@ -232,7 +219,6 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
 }
 
 static int tls_do_decryption(struct sock *sk,
-			     struct sk_buff *skb,
 			     struct scatterlist *sgin,
 			     struct scatterlist *sgout,
 			     char *iv_recv,
@@ -252,16 +238,9 @@ static int tls_do_decryption(struct sock *sk,
 			       (u8 *)iv_recv);
 
 	if (darg->async) {
-		/* Using skb->sk to push sk through to crypto async callback
-		 * handler. This allows propagating errors up to the socket
-		 * if needed. It _must_ be cleared in the async handler
-		 * before consume_skb is called. We _know_ skb->sk is NULL
-		 * because it is a clone from strparser.
-		 */
-		skb->sk = sk;
 		aead_request_set_callback(aead_req,
 					  CRYPTO_TFM_REQ_MAY_BACKLOG,
-					  tls_decrypt_done, skb);
+					  tls_decrypt_done, sk);
 		atomic_inc(&ctx->decrypt_pending);
 	} else {
 		aead_request_set_callback(aead_req,
@@ -1404,51 +1383,90 @@ out:
 	return rc;
 }
 
+static struct sk_buff *
+tls_alloc_clrtxt_skb(struct sock *sk, struct sk_buff *skb,
+		     unsigned int full_len)
+{
+	struct strp_msg *clr_rxm;
+	struct sk_buff *clr_skb;
+	int err;
+
+	clr_skb = alloc_skb_with_frags(0, full_len, TLS_PAGE_ORDER,
+				       &err, sk->sk_allocation);
+	if (!clr_skb)
+		return NULL;
+
+	skb_copy_header(clr_skb, skb);
+	clr_skb->len = full_len;
+	clr_skb->data_len = full_len;
+
+	clr_rxm = strp_msg(clr_skb);
+	clr_rxm->offset = 0;
+
+	return clr_skb;
+}
+
+/* Decrypt handlers
+ *
+ * tls_decrypt_sg() and tls_decrypt_device() are decrypt handlers.
+ * They must transform the darg in/out argument as follows:
+ *       |          Input            |         Output
+ * -------------------------------------------------------------------
+ *    zc | Zero-copy decrypt allowed | Zero-copy performed
+ * async | Async decrypt allowed     | Async crypto used / in progress
+ *   skb |            *              | Output skb
+ */
+
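To make the in/out contract above concrete, here is a sketch of how a caller is expected to drive a decrypt handler (a hypothetical function mirroring what tls_sw_recvmsg() does further down; note that only the struct_group()ed input fields get cleared):

/* Sketch, not part of the patch: drive one record through a handler. */
static int example_drive_decrypt(struct sock *sk, struct msghdr *msg)
{
	struct tls_decrypt_arg darg;
	int err;

	memset(&darg.inargs, 0, sizeof(darg.inargs));	/* inputs only */
	darg.zc = true;		/* in: zero-copy may be attempted */
	darg.async = false;	/* in: synchronous crypto only */

	err = tls_rx_one_record(sk, &msg->msg_iter, &darg);
	if (err < 0)
		return err;

	/* out: darg.zc now reports whether zero-copy was performed,
	 * and darg.skb carries the clear-text record.
	 */
	return 0;
}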
 /* This function decrypts the input skb into either out_iov or in out_sg
- * or in skb buffers itself. The input parameter 'zc' indicates if
+ * or in skb buffers itself. The input parameter 'darg->zc' indicates if
  * zero-copy mode needs to be tried or not. With zero-copy mode, either
  * out_iov or out_sg must be non-NULL. In case both out_iov and out_sg are
  * NULL, then the decryption happens inside skb buffers itself, i.e.
- * zero-copy gets disabled and 'zc' is updated.
+ * zero-copy gets disabled and 'darg->zc' is updated.
  */
-
-static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
-			    struct iov_iter *out_iov,
-			    struct scatterlist *out_sg,
-			    struct tls_decrypt_arg *darg)
+static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov,
+			  struct scatterlist *out_sg,
+			  struct tls_decrypt_arg *darg)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
 	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	int n_sgin, n_sgout, aead_size, err, pages = 0;
-	struct strp_msg *rxm = strp_msg(skb);
-	struct tls_msg *tlm = tls_msg(skb);
+	struct sk_buff *skb = tls_strp_msg(ctx);
+	const struct strp_msg *rxm = strp_msg(skb);
+	const struct tls_msg *tlm = tls_msg(skb);
 	struct aead_request *aead_req;
-	struct sk_buff *unused;
 	struct scatterlist *sgin = NULL;
 	struct scatterlist *sgout = NULL;
 	const int data_len = rxm->full_len - prot->overhead_size;
 	int tail_pages = !!prot->tail_size;
 	struct tls_decrypt_ctx *dctx;
+	struct sk_buff *clear_skb;
 	int iv_offset = 0;
 	u8 *mem;
 
+	n_sgin = skb_nsg(skb, rxm->offset + prot->prepend_size,
+			 rxm->full_len - prot->prepend_size);
+	if (n_sgin < 1)
+		return n_sgin ?: -EBADMSG;
+
 	if (darg->zc && (out_iov || out_sg)) {
+		clear_skb = NULL;
+
 		if (out_iov)
 			n_sgout = 1 + tail_pages +
 				iov_iter_npages_cap(out_iov, INT_MAX, data_len);
 		else
 			n_sgout = sg_nents(out_sg);
-		n_sgin = skb_nsg(skb, rxm->offset + prot->prepend_size,
-				 rxm->full_len - prot->prepend_size);
 	} else {
-		n_sgout = 0;
 		darg->zc = false;
-		n_sgin = skb_cow_data(skb, 0, &unused);
-	}
-
-	if (n_sgin < 1)
-		return -EBADMSG;
+
+		clear_skb = tls_alloc_clrtxt_skb(sk, skb, rxm->full_len);
+		if (!clear_skb)
+			return -ENOMEM;
+
+		n_sgout = 1 + skb_shinfo(clear_skb)->nr_frags;
+	}
 
 	/* Increment to accommodate AAD */
 	n_sgin = n_sgin + 1;
@@ -1460,8 +1478,10 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 	aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
 	mem = kmalloc(aead_size + struct_size(dctx, sg, n_sgin + n_sgout),
 		      sk->sk_allocation);
-	if (!mem)
-		return -ENOMEM;
+	if (!mem) {
+		err = -ENOMEM;
+		goto exit_free_skb;
+	}
 
 	/* Segment the allocated memory */
 	aead_req = (struct aead_request *)mem;
@@ -1510,88 +1530,107 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 	if (err < 0)
 		goto exit_free;
 
-	if (n_sgout) {
-		if (out_iov) {
-			sg_init_table(sgout, n_sgout);
-			sg_set_buf(&sgout[0], dctx->aad, prot->aad_size);
+	if (clear_skb) {
+		sg_init_table(sgout, n_sgout);
+		sg_set_buf(&sgout[0], dctx->aad, prot->aad_size);
 
-			err = tls_setup_from_iter(out_iov, data_len,
-						  &pages, &sgout[1],
-						  (n_sgout - 1 - tail_pages));
-			if (err < 0)
-				goto fallback_to_reg_recv;
+		err = skb_to_sgvec(clear_skb, &sgout[1], prot->prepend_size,
+				   data_len + prot->tail_size);
+		if (err < 0)
+			goto exit_free;
+	} else if (out_iov) {
+		sg_init_table(sgout, n_sgout);
+		sg_set_buf(&sgout[0], dctx->aad, prot->aad_size);
 
-			if (prot->tail_size) {
-				sg_unmark_end(&sgout[pages]);
-				sg_set_buf(&sgout[pages + 1], &dctx->tail,
-					   prot->tail_size);
-				sg_mark_end(&sgout[pages + 1]);
-			}
-		} else if (out_sg) {
-			memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
-		} else {
-			goto fallback_to_reg_recv;
+		err = tls_setup_from_iter(out_iov, data_len, &pages, &sgout[1],
+					  (n_sgout - 1 - tail_pages));
+		if (err < 0)
+			goto exit_free_pages;
+
+		if (prot->tail_size) {
+			sg_unmark_end(&sgout[pages]);
+			sg_set_buf(&sgout[pages + 1], &dctx->tail,
+				   prot->tail_size);
+			sg_mark_end(&sgout[pages + 1]);
 		}
-	} else {
-fallback_to_reg_recv:
-		sgout = sgin;
-		pages = 0;
-		darg->zc = false;
+	} else if (out_sg) {
+		memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
 	}
 
 	/* Prepare and submit AEAD request */
-	err = tls_do_decryption(sk, skb, sgin, sgout, dctx->iv,
+	err = tls_do_decryption(sk, sgin, sgout, dctx->iv,
 				data_len + prot->tail_size, aead_req, darg);
-	if (darg->async)
-		return 0;
+	if (err)
+		goto exit_free_pages;
+
+	darg->skb = clear_skb ?: tls_strp_msg(ctx);
+	clear_skb = NULL;
+
+	if (unlikely(darg->async)) {
+		err = tls_strp_msg_hold(sk, skb, &ctx->async_hold);
+		if (err)
+			__skb_queue_tail(&ctx->async_hold, darg->skb);
+		return err;
+	}
 
 	if (prot->tail_size)
 		darg->tail = dctx->tail;
 
+exit_free_pages:
 	/* Release the pages in case iov was mapped to pages */
 	for (; pages > 0; pages--)
 		put_page(sg_page(&sgout[pages]));
 exit_free:
 	kfree(mem);
+exit_free_skb:
+	consume_skb(clear_skb);
 	return err;
 }
 
-static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
-			      struct iov_iter *dest,
-			      struct tls_decrypt_arg *darg)
+static int
+tls_decrypt_device(struct sock *sk, struct tls_context *tls_ctx,
+		   struct tls_decrypt_arg *darg)
+{
+	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+	int err;
+
+	if (tls_ctx->rx_conf != TLS_HW)
+		return 0;
+
+	err = tls_device_decrypted(sk, tls_ctx);
+	if (err <= 0)
+		return err;
+
+	darg->zc = false;
+	darg->async = false;
+	darg->skb = tls_strp_msg(ctx);
+	ctx->recv_pkt = NULL;
+	return 1;
+}
+
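tls_decrypt_device() above establishes a small return-value convention for the handlers: negative is an error, 0 means "not handled here", and 1 means the device already produced the clear text. A hypothetical caller spelling that out (a sketch; tls_rx_one_record() below is the real consumer):

/* Sketch: how the 0 / 1 / negative convention composes. */
static int example_decrypt(struct sock *sk, struct tls_context *tls_ctx,
			   struct tls_decrypt_arg *darg)
{
	int err;

	err = tls_decrypt_device(sk, tls_ctx, darg);
	if (err < 0)
		return err;	/* hard failure */
	if (err)
		return 0;	/* device did it all, darg->skb is set */

	/* fall back to the software AEAD path */
	return tls_decrypt_sg(sk, NULL, NULL, darg);
}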
+static int tls_rx_one_record(struct sock *sk, struct iov_iter *dest,
+			     struct tls_decrypt_arg *darg)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
 	struct tls_prot_info *prot = &tls_ctx->prot_info;
-	struct strp_msg *rxm = strp_msg(skb);
-	struct tls_msg *tlm = tls_msg(skb);
+	struct strp_msg *rxm;
 	int pad, err;
 
-	if (tlm->decrypted) {
-		darg->zc = false;
-		darg->async = false;
-		return 0;
-	}
+	err = tls_decrypt_device(sk, tls_ctx, darg);
+	if (err < 0)
+		return err;
+	if (err)
+		goto decrypt_done;
 
-	if (tls_ctx->rx_conf == TLS_HW) {
-		err = tls_device_decrypted(sk, tls_ctx, skb, rxm);
-		if (err < 0)
-			return err;
-		if (err > 0) {
-			tlm->decrypted = 1;
-			darg->zc = false;
-			darg->async = false;
-			goto decrypt_done;
-		}
-	}
-
-	err = decrypt_internal(sk, skb, dest, NULL, darg);
+	err = tls_decrypt_sg(sk, dest, NULL, darg);
 	if (err < 0) {
 		if (err == -EBADMSG)
 			TLS_INC_STATS(sock_net(sk),
 				      LINUX_MIB_TLSDECRYPTERROR);
 		return err;
 	}
 	if (darg->async)
-		goto decrypt_next;
+		goto decrypt_done;
 
 	/* If opportunistic TLS 1.3 ZC failed retry without ZC */
 	if (unlikely(darg->zc && prot->version == TLS_1_3_VERSION &&
 		     darg->tail != TLS_RECORD_TYPE_DATA)) {
@@ -1599,30 +1638,33 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
 		if (!darg->tail)
 			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXNOPADVIOL);
 		TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTRETRY);
-		return decrypt_skb_update(sk, skb, dest, darg);
+		return tls_rx_one_record(sk, dest, darg);
 	}
 
 decrypt_done:
-	pad = tls_padding_length(prot, skb, darg);
-	if (pad < 0)
-		return pad;
+	if (darg->skb == ctx->recv_pkt)
+		ctx->recv_pkt = NULL;
 
+	pad = tls_padding_length(prot, darg->skb, darg);
+	if (pad < 0) {
+		consume_skb(darg->skb);
+		return pad;
+	}
+
+	rxm = strp_msg(darg->skb);
 	rxm->full_len -= pad;
 	rxm->offset += prot->prepend_size;
 	rxm->full_len -= prot->overhead_size;
-	tlm->decrypted = 1;
-decrypt_next:
 	tls_advance_record_sn(sk, prot, &tls_ctx->rx);
 
 	return 0;
 }
 
-int decrypt_skb(struct sock *sk, struct sk_buff *skb,
-		struct scatterlist *sgout)
+int decrypt_skb(struct sock *sk, struct scatterlist *sgout)
 {
 	struct tls_decrypt_arg darg = { .zc = true, };
 
-	return decrypt_internal(sk, skb, NULL, sgout, &darg);
+	return tls_decrypt_sg(sk, NULL, sgout, &darg);
 }
 
 static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm,
@@ -1648,6 +1690,13 @@ static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm,
 	return 1;
 }
 
+static void tls_rx_rec_done(struct tls_sw_context_rx *ctx)
+{
+	consume_skb(ctx->recv_pkt);
+	ctx->recv_pkt = NULL;
+	__strp_unpause(&ctx->strp);
+}
+
 /* This function traverses the rx_list in tls receive context to copies the
  * decrypted records into the buffer provided by caller zero copy is not
  * true. Further, the records are removed from the rx_list if it is not a peek
@@ -1658,7 +1707,6 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,
 			   u8 *control,
 			   size_t skip,
 			   size_t len,
-			   bool zc,
 			   bool is_peek)
 {
 	struct sk_buff *skb = skb_peek(&ctx->rx_list);
@@ -1692,12 +1740,10 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,
 			if (err <= 0)
 				goto out;
 
-			if (!zc || (rxm->full_len - skip) > len) {
-				err = skb_copy_datagram_msg(skb, rxm->offset + skip,
-							    msg, chunk);
-				if (err < 0)
-					goto out;
-			}
+			err = skb_copy_datagram_msg(skb, rxm->offset + skip,
+						    msg, chunk);
+			if (err < 0)
+				goto out;
 
 			len = len - chunk;
 			copied = copied + chunk;
@@ -1753,6 +1799,51 @@ tls_read_flush_backlog(struct sock *sk, struct tls_prot_info *prot,
 		sk_flush_backlog(sk);
 }
 
+static long tls_rx_reader_lock(struct sock *sk, struct tls_sw_context_rx *ctx,
+			       bool nonblock)
+{
+	long timeo;
+
+	lock_sock(sk);
+
+	timeo = sock_rcvtimeo(sk, nonblock);
+
+	while (unlikely(ctx->reader_present)) {
+		DEFINE_WAIT_FUNC(wait, woken_wake_function);
+
+		ctx->reader_contended = 1;
+
+		add_wait_queue(&ctx->wq, &wait);
+		sk_wait_event(sk, &timeo,
+			      !READ_ONCE(ctx->reader_present), &wait);
+		remove_wait_queue(&ctx->wq, &wait);
+
+		if (!timeo)
+			return -EAGAIN;
+		if (signal_pending(current))
+			return sock_intr_errno(timeo);
+	}
+
+	WRITE_ONCE(ctx->reader_present, 1);
+
+	return timeo;
+}
+
+static void tls_rx_reader_unlock(struct sock *sk, struct tls_sw_context_rx *ctx)
+{
+	if (unlikely(ctx->reader_contended)) {
+		if (wq_has_sleeper(&ctx->wq))
+			wake_up(&ctx->wq);
+		else
+			ctx->reader_contended = 0;
+
+		WARN_ON_ONCE(!ctx->reader_present);
+	}
+
+	WRITE_ONCE(ctx->reader_present, 0);
+	release_sock(sk);
+}
+
 int tls_sw_recvmsg(struct sock *sk,
 		   struct msghdr *msg,
 		   size_t len,
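The pair above serializes all rx-path readers (recvmsg and splice) while still allowing the socket lock to be dropped and re-taken inside the read. Usage is strictly bracketed; a sketch of the pattern the two entry points below follow (example_reader() is hypothetical):

/* Sketch: bracketed use of the reader lock. The return value of the
 * lock helper doubles as the receive timeout, so a negative value
 * already encodes the -EAGAIN / interrupted cases.
 */
static int example_reader(struct sock *sk, struct tls_sw_context_rx *ctx,
			  int flags)
{
	long timeo;

	timeo = tls_rx_reader_lock(sk, ctx, flags & MSG_DONTWAIT);
	if (timeo < 0)
		return timeo;

	/* ... consume records; may sleep and release the socket lock ... */

	tls_rx_reader_unlock(sk, ctx);
	return 0;
}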
@@ -1762,9 +1853,9 @@ int tls_sw_recvmsg(struct sock *sk,
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
 	struct tls_prot_info *prot = &tls_ctx->prot_info;
+	ssize_t decrypted = 0, async_copy_bytes = 0;
 	struct sk_psock *psock;
 	unsigned char control = 0;
-	ssize_t decrypted = 0;
 	size_t flushed_at = 0;
 	struct strp_msg *rxm;
 	struct tls_msg *tlm;
@@ -1782,7 +1873,9 @@ int tls_sw_recvmsg(struct sock *sk,
 		return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);
 
 	psock = sk_psock_get(sk);
-	lock_sock(sk);
+	timeo = tls_rx_reader_lock(sk, ctx, flags & MSG_DONTWAIT);
+	if (timeo < 0)
+		return timeo;
 	bpf_strp_enabled = sk_psock_strp_enabled(psock);
 
 	/* If crypto failed the connection is broken */
@@ -1791,7 +1884,7 @@ int tls_sw_recvmsg(struct sock *sk,
 		goto end;
 
 	/* Process pending decrypted records. It must be non-zero-copy */
-	err = process_rx_list(ctx, msg, &control, 0, len, false, is_peek);
+	err = process_rx_list(ctx, msg, &control, 0, len, is_peek);
 	if (err < 0)
 		goto end;
 
@@ -1801,13 +1894,12 @@ int tls_sw_recvmsg(struct sock *sk,
 
 	target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
 	len = len - copied;
-	timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
 
 	zc_capable = !bpf_strp_enabled && !is_kvec && !is_peek &&
 		     ctx->zc_capable;
 	decrypted = 0;
 	while (len && (decrypted + copied < target || ctx->recv_pkt)) {
-		struct tls_decrypt_arg darg = {};
+		struct tls_decrypt_arg darg;
 		int to_decrypt, chunk;
 
 		err = tls_rx_rec_wait(sk, psock, flags & MSG_DONTWAIT, timeo);
@@ -1815,15 +1907,19 @@ int tls_sw_recvmsg(struct sock *sk,
 		if (err <= 0) {
 			if (psock) {
 				chunk = sk_msg_recvmsg(sk, psock, msg, len,
 						       flags);
-				if (chunk > 0)
-					goto leave_on_list;
+				if (chunk > 0) {
+					decrypted += chunk;
+					len -= chunk;
+					continue;
+				}
 			}
 			goto recv_end;
 		}
 
-		skb = ctx->recv_pkt;
-		rxm = strp_msg(skb);
-		tlm = tls_msg(skb);
+		memset(&darg.inargs, 0, sizeof(darg.inargs));
+
+		rxm = strp_msg(ctx->recv_pkt);
+		tlm = tls_msg(ctx->recv_pkt);
 
 		to_decrypt = rxm->full_len - prot->overhead_size;
@@ -1837,12 +1933,16 @@ int tls_sw_recvmsg(struct sock *sk,
 		else
 			darg.async = false;
 
-		err = decrypt_skb_update(sk, skb, &msg->msg_iter, &darg);
+		err = tls_rx_one_record(sk, &msg->msg_iter, &darg);
 		if (err < 0) {
 			tls_err_abort(sk, -EBADMSG);
 			goto recv_end;
 		}
 
+		skb = darg.skb;
+		rxm = strp_msg(skb);
+		tlm = tls_msg(skb);
+
 		async |= darg.async;
 
 		/* If the type of records being processed is not known yet,
@@ -1853,34 +1953,36 @@ int tls_sw_recvmsg(struct sock *sk,
 		 * For tls1.3, we disable async.
 		 */
 		err = tls_record_content_type(msg, tlm, &control);
-		if (err <= 0)
+		if (err <= 0) {
+			tls_rx_rec_done(ctx);
+put_on_rx_list_err:
+			__skb_queue_tail(&ctx->rx_list, skb);
 			goto recv_end;
+		}
 
 		/* periodically flush backlog, and feed strparser */
 		tls_read_flush_backlog(sk, prot, len, to_decrypt,
 				       decrypted + copied, &flushed_at);
 
-		ctx->recv_pkt = NULL;
-		__strp_unpause(&ctx->strp);
-		__skb_queue_tail(&ctx->rx_list, skb);
-
-		if (async) {
-			/* TLS 1.2-only, to_decrypt must be text length */
-			chunk = min_t(int, to_decrypt, len);
-leave_on_list:
-			decrypted += chunk;
-			len -= chunk;
-			continue;
-		}
 		/* TLS 1.3 may have updated the length by more than overhead */
 		chunk = rxm->full_len;
+		tls_rx_rec_done(ctx);
 
 		if (!darg.zc) {
 			bool partially_consumed = chunk > len;
 
+			if (async) {
+				/* TLS 1.2-only, to_decrypt must be text len */
+				chunk = min_t(int, to_decrypt, len);
+				async_copy_bytes += chunk;
+put_on_rx_list:
+				decrypted += chunk;
+				len -= chunk;
+				__skb_queue_tail(&ctx->rx_list, skb);
+				continue;
+			}
+
 			if (bpf_strp_enabled) {
-				/* BPF may try to queue the skb */
-				__skb_unlink(skb, &ctx->rx_list);
 				err = sk_psock_tls_strp_read(psock, skb);
 				if (err != __SK_PASS) {
 					rxm->offset = rxm->offset + rxm->full_len;
@@ -1889,7 +1991,6 @@ leave_on_list:
 				consume_skb(skb);
 				continue;
 			}
-			__skb_queue_tail(&ctx->rx_list, skb);
 		}
 
 		if (partially_consumed)
@@ -1898,22 +1999,21 @@ leave_on_list:
 			err = skb_copy_datagram_msg(skb, rxm->offset,
 						    msg, chunk);
 			if (err < 0)
-				goto recv_end;
+				goto put_on_rx_list_err;
 
 			if (is_peek)
-				goto leave_on_list;
+				goto put_on_rx_list;
 
 			if (partially_consumed) {
 				rxm->offset += chunk;
 				rxm->full_len -= chunk;
-				goto leave_on_list;
+				goto put_on_rx_list;
 			}
 		}
 
 		decrypted += chunk;
 		len -= chunk;
 
-		__skb_unlink(skb, &ctx->rx_list);
 		consume_skb(skb);
 
 		/* Return full control message to userspace before trying
@@ -1933,30 +2033,32 @@ recv_end:
 		reinit_completion(&ctx->async_wait.completion);
 		pending = atomic_read(&ctx->decrypt_pending);
 		spin_unlock_bh(&ctx->decrypt_compl_lock);
-		if (pending) {
+		ret = 0;
+		if (pending)
 			ret = crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
-			if (ret) {
-				if (err >= 0 || err == -EINPROGRESS)
-					err = ret;
-				decrypted = 0;
-				goto end;
-			}
+		__skb_queue_purge(&ctx->async_hold);
+
+		if (ret) {
+			if (err >= 0 || err == -EINPROGRESS)
+				err = ret;
+			decrypted = 0;
+			goto end;
 		}
 
 		/* Drain records from the rx_list & copy if required */
 		if (is_peek || is_kvec)
 			err = process_rx_list(ctx, msg, &control, copied,
-					      decrypted, false, is_peek);
+					      decrypted, is_peek);
 		else
 			err = process_rx_list(ctx, msg, &control, 0,
-					      decrypted, true, is_peek);
+					      async_copy_bytes, is_peek);
 		decrypted = max(err, 0);
 	}
 
 	copied += decrypted;
 
 end:
-	release_sock(sk);
+	tls_rx_reader_unlock(sk, ctx);
 	if (psock)
 		sk_psock_put(sk, psock);
 	return copied ? : err;
@@ -1973,33 +2075,34 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
 	struct tls_msg *tlm;
 	struct sk_buff *skb;
 	ssize_t copied = 0;
-	bool from_queue;
 	int err = 0;
 	long timeo;
 	int chunk;
 
-	lock_sock(sk);
+	timeo = tls_rx_reader_lock(sk, ctx, flags & SPLICE_F_NONBLOCK);
+	if (timeo < 0)
+		return timeo;
 
-	timeo = sock_rcvtimeo(sk, flags & SPLICE_F_NONBLOCK);
-
-	from_queue = !skb_queue_empty(&ctx->rx_list);
-	if (from_queue) {
+	if (!skb_queue_empty(&ctx->rx_list)) {
 		skb = __skb_dequeue(&ctx->rx_list);
 	} else {
-		struct tls_decrypt_arg darg = {};
+		struct tls_decrypt_arg darg;
 
 		err = tls_rx_rec_wait(sk, NULL, flags & SPLICE_F_NONBLOCK,
 				      timeo);
 		if (err <= 0)
 			goto splice_read_end;
 
-		skb = ctx->recv_pkt;
+		memset(&darg.inargs, 0, sizeof(darg.inargs));
 
-		err = decrypt_skb_update(sk, skb, NULL, &darg);
+		err = tls_rx_one_record(sk, NULL, &darg);
 		if (err < 0) {
 			tls_err_abort(sk, -EBADMSG);
 			goto splice_read_end;
 		}
+
+		tls_rx_rec_done(ctx);
+		skb = darg.skb;
 	}
 
 	rxm = strp_msg(skb);
@@ -2008,29 +2111,29 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
 	/* splice does not support reading control messages */
 	if (tlm->control != TLS_RECORD_TYPE_DATA) {
 		err = -EINVAL;
-		goto splice_read_end;
+		goto splice_requeue;
 	}
 
 	chunk = min_t(unsigned int, rxm->full_len, len);
 	copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags);
 	if (copied < 0)
-		goto splice_read_end;
+		goto splice_requeue;
 
-	if (!from_queue) {
-		ctx->recv_pkt = NULL;
-		__strp_unpause(&ctx->strp);
-	}
 	if (chunk < rxm->full_len) {
-		__skb_queue_head(&ctx->rx_list, skb);
 		rxm->offset += len;
 		rxm->full_len -= len;
-	} else {
-		consume_skb(skb);
+		goto splice_requeue;
 	}
 
+	consume_skb(skb);
+
 splice_read_end:
-	release_sock(sk);
+	tls_rx_reader_unlock(sk, ctx);
 	return copied ? : err;
+
+splice_requeue:
+	__skb_queue_head(&ctx->rx_list, skb);
+	goto splice_read_end;
 }
 
 bool tls_sw_sock_is_readable(struct sock *sk)
@@ -2076,7 +2179,6 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
 	if (ret < 0)
 		goto read_failure;
 
-	tlm->decrypted = 0;
 	tlm->control = header[0];
 
 	data_len = ((header[4] & 0xFF) | (header[3] << 8));
@@ -2371,9 +2473,11 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
 	} else {
 		crypto_init_wait(&sw_ctx_rx->async_wait);
 		spin_lock_init(&sw_ctx_rx->decrypt_compl_lock);
+		init_waitqueue_head(&sw_ctx_rx->wq);
 		crypto_info = &ctx->crypto_recv.info;
 		cctx = &ctx->rx;
 		skb_queue_head_init(&sw_ctx_rx->rx_list);
+		skb_queue_head_init(&sw_ctx_rx->async_hold);
 		aead = &sw_ctx_rx->aead_recv;
 	}
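For readers who want to poke at the reworked rx path from user space, here is a minimal sketch of enabling kTLS receive offload on an already-handshaked TCP socket. The key material is left zeroed and must come from a real TLS handshake; the SOL_TLS and TCP_ULP fallbacks are spelled out only for older libc headers:

#include <string.h>
#include <sys/socket.h>
#include <netinet/tcp.h>
#include <linux/tls.h>

#ifndef SOL_TLS
#define SOL_TLS 282	/* matches include/linux/socket.h */
#endif
#ifndef TCP_ULP
#define TCP_ULP 31	/* matches include/uapi/linux/tcp.h */
#endif

/* Sketch: once this succeeds, recv() and splice() on fd go through
 * the tls_sw_recvmsg() / tls_sw_splice_read() paths patched above.
 */
static int enable_ktls_rx(int fd)
{
	struct tls12_crypto_info_aes_gcm_128 ci;

	memset(&ci, 0, sizeof(ci));	/* dummy keys: fill from handshake */
	ci.info.version = TLS_1_2_VERSION;
	ci.info.cipher_type = TLS_CIPHER_AES_GCM_128;

	if (setsockopt(fd, SOL_TCP, TCP_ULP, "tls", sizeof("tls")))
		return -1;
	return setsockopt(fd, SOL_TLS, TLS_RX, &ci, sizeof(ci));
}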