net/tls: use RCU protection on icsk->icsk_ulp_data

We need to make sure context does not get freed while diag
code is interrogating it. Free struct tls_context with
kfree_rcu().

We add the __rcu annotation directly in icsk, and cast it
away in the datapath accessor. Presumably all ULPs will
do a similar thing.

Signed-off-by: Jakub Kicinski <jakub.kicinski@netronome.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
This commit is contained in:
Jakub Kicinski 2019-08-30 12:25:47 +02:00 committed by David S. Miller
parent ed6e8103ba
commit 15a7dea750
5 changed files with 29 additions and 12 deletions

View File

@ -97,7 +97,7 @@ struct inet_connection_sock {
const struct tcp_congestion_ops *icsk_ca_ops;
const struct inet_connection_sock_af_ops *icsk_af_ops;
const struct tcp_ulp_ops *icsk_ulp_ops;
void *icsk_ulp_data;
void __rcu *icsk_ulp_data;
void (*icsk_clean_acked)(struct sock *sk, u32 acked_seq);
struct hlist_node icsk_listen_portaddr_node;
unsigned int (*icsk_sync_mss)(struct sock *sk, u32 pmtu);

View File

@ -41,6 +41,7 @@
#include <linux/tcp.h>
#include <linux/skmsg.h>
#include <linux/netdevice.h>
#include <linux/rcupdate.h>
#include <net/tcp.h>
#include <net/strparser.h>
@ -290,6 +291,7 @@ struct tls_context {
struct list_head list;
refcount_t refcount;
struct rcu_head rcu;
};
enum tls_offload_ctx_dir {
@ -348,7 +350,7 @@ struct tls_offload_context_rx {
#define TLS_OFFLOAD_CONTEXT_SIZE_RX \
(sizeof(struct tls_offload_context_rx) + TLS_DRIVER_STATE_SIZE_RX)
void tls_ctx_free(struct tls_context *ctx);
void tls_ctx_free(struct sock *sk, struct tls_context *ctx);
int wait_on_pending_writer(struct sock *sk, long *timeo);
int tls_sk_query(struct sock *sk, int optname, char __user *optval,
int __user *optlen);
@ -467,7 +469,10 @@ static inline struct tls_context *tls_get_ctx(const struct sock *sk)
{
struct inet_connection_sock *icsk = inet_csk(sk);
return icsk->icsk_ulp_data;
/* Use RCU on icsk_ulp_data only for sock diag code,
* TLS data path doesn't need rcu_dereference().
*/
return (__force void *)icsk->icsk_ulp_data;
}
static inline void tls_advance_record_sn(struct sock *sk,

View File

@ -345,7 +345,7 @@ static int sock_map_update_common(struct bpf_map *map, u32 idx,
return -EINVAL;
if (unlikely(idx >= map->max_entries))
return -E2BIG;
if (unlikely(icsk->icsk_ulp_data))
if (unlikely(rcu_access_pointer(icsk->icsk_ulp_data)))
return -EINVAL;
link = sk_psock_init_link();

View File

@ -61,7 +61,7 @@ static void tls_device_free_ctx(struct tls_context *ctx)
if (ctx->rx_conf == TLS_HW)
kfree(tls_offload_ctx_rx(ctx));
tls_ctx_free(ctx);
tls_ctx_free(NULL, ctx);
}
static void tls_device_gc_task(struct work_struct *work)

View File

@ -251,14 +251,26 @@ static void tls_write_space(struct sock *sk)
ctx->sk_write_space(sk);
}
void tls_ctx_free(struct tls_context *ctx)
/**
* tls_ctx_free() - free TLS ULP context
* @sk: socket to with @ctx is attached
* @ctx: TLS context structure
*
* Free TLS context. If @sk is %NULL caller guarantees that the socket
* to which @ctx was attached has no outstanding references.
*/
void tls_ctx_free(struct sock *sk, struct tls_context *ctx)
{
if (!ctx)
return;
memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send));
memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv));
kfree(ctx);
if (sk)
kfree_rcu(ctx, rcu);
else
kfree(ctx);
}
static void tls_sk_proto_cleanup(struct sock *sk,
@ -306,7 +318,7 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
write_lock_bh(&sk->sk_callback_lock);
if (free_ctx)
icsk->icsk_ulp_data = NULL;
rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
sk->sk_prot = ctx->sk_proto;
if (sk->sk_write_space == tls_write_space)
sk->sk_write_space = ctx->sk_write_space;
@ -321,7 +333,7 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
ctx->sk_proto_close(sk, timeout);
if (free_ctx)
tls_ctx_free(ctx);
tls_ctx_free(sk, ctx);
}
static int do_tls_getsockopt_tx(struct sock *sk, char __user *optval,
@ -610,7 +622,7 @@ static struct tls_context *create_ctx(struct sock *sk)
if (!ctx)
return NULL;
icsk->icsk_ulp_data = ctx;
rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
ctx->setsockopt = sk->sk_prot->setsockopt;
ctx->getsockopt = sk->sk_prot->getsockopt;
ctx->sk_proto_close = sk->sk_prot->close;
@ -651,8 +663,8 @@ static void tls_hw_sk_destruct(struct sock *sk)
ctx->sk_destruct(sk);
/* Free ctx */
tls_ctx_free(ctx);
icsk->icsk_ulp_data = NULL;
rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
tls_ctx_free(sk, ctx);
}
static int tls_hw_prot(struct sock *sk)