Merge branch 'tls-fixes-for-record-type-handling-with-peek'

Sabrina Dubroca says:

====================
tls: fixes for record type handling with PEEK

There are multiple bugs in tls_sw_recvmsg's handling of record types
when the MSG_PEEK flag is used, which can lead to incorrectly merging
two records:
 - consecutive non-DATA records shouldn't be merged, even if they're
   the same type (partly handled by the test at the end of the main
   loop)
 - records of the same type (even DATA) shouldn't be merged if one
   record of a different type comes in between
====================

Link: https://lore.kernel.org/r/cover.1708007371.git.sd@queasysnail.net
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
This commit is contained in:
Jakub Kicinski 2024-02-21 14:25:53 -08:00
commit f76d5f6580
2 changed files with 61 additions and 8 deletions

View File

@ -1772,7 +1772,8 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,
u8 *control, u8 *control,
size_t skip, size_t skip,
size_t len, size_t len,
bool is_peek) bool is_peek,
bool *more)
{ {
struct sk_buff *skb = skb_peek(&ctx->rx_list); struct sk_buff *skb = skb_peek(&ctx->rx_list);
struct tls_msg *tlm; struct tls_msg *tlm;
@ -1785,7 +1786,7 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,
err = tls_record_content_type(msg, tlm, control); err = tls_record_content_type(msg, tlm, control);
if (err <= 0) if (err <= 0)
goto out; goto more;
if (skip < rxm->full_len) if (skip < rxm->full_len)
break; break;
@ -1803,12 +1804,12 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,
err = tls_record_content_type(msg, tlm, control); err = tls_record_content_type(msg, tlm, control);
if (err <= 0) if (err <= 0)
goto out; goto more;
err = skb_copy_datagram_msg(skb, rxm->offset + skip, err = skb_copy_datagram_msg(skb, rxm->offset + skip,
msg, chunk); msg, chunk);
if (err < 0) if (err < 0)
goto out; goto more;
len = len - chunk; len = len - chunk;
copied = copied + chunk; copied = copied + chunk;
@ -1844,6 +1845,10 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,
out: out:
return copied ? : err; return copied ? : err;
more:
if (more)
*more = true;
goto out;
} }
static bool static bool
@ -1947,6 +1952,7 @@ int tls_sw_recvmsg(struct sock *sk,
int target, err; int target, err;
bool is_kvec = iov_iter_is_kvec(&msg->msg_iter); bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
bool is_peek = flags & MSG_PEEK; bool is_peek = flags & MSG_PEEK;
bool rx_more = false;
bool released = true; bool released = true;
bool bpf_strp_enabled; bool bpf_strp_enabled;
bool zc_capable; bool zc_capable;
@ -1966,12 +1972,12 @@ int tls_sw_recvmsg(struct sock *sk,
goto end; goto end;
/* Process pending decrypted records. It must be non-zero-copy */ /* Process pending decrypted records. It must be non-zero-copy */
err = process_rx_list(ctx, msg, &control, 0, len, is_peek); err = process_rx_list(ctx, msg, &control, 0, len, is_peek, &rx_more);
if (err < 0) if (err < 0)
goto end; goto end;
copied = err; copied = err;
if (len <= copied) if (len <= copied || (copied && control != TLS_RECORD_TYPE_DATA) || rx_more)
goto end; goto end;
target = sock_rcvlowat(sk, flags & MSG_WAITALL, len); target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
@ -2064,6 +2070,8 @@ put_on_rx_list:
decrypted += chunk; decrypted += chunk;
len -= chunk; len -= chunk;
__skb_queue_tail(&ctx->rx_list, skb); __skb_queue_tail(&ctx->rx_list, skb);
if (unlikely(control != TLS_RECORD_TYPE_DATA))
break;
continue; continue;
} }
@ -2128,10 +2136,10 @@ recv_end:
/* Drain records from the rx_list & copy if required */ /* Drain records from the rx_list & copy if required */
if (is_peek || is_kvec) if (is_peek || is_kvec)
err = process_rx_list(ctx, msg, &control, copied, err = process_rx_list(ctx, msg, &control, copied,
decrypted, is_peek); decrypted, is_peek, NULL);
else else
err = process_rx_list(ctx, msg, &control, 0, err = process_rx_list(ctx, msg, &control, 0,
async_copy_bytes, is_peek); async_copy_bytes, is_peek, NULL);
} }
copied += decrypted; copied += decrypted;

View File

@ -1485,6 +1485,51 @@ TEST_F(tls, control_msg)
EXPECT_EQ(memcmp(buf, test_str, send_len), 0); EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
} }
/*
 * Send two records back to back with the same record type (100) and check
 * that every receive call, peeking or not, hands back exactly one record
 * even though the receive buffer has room for more than one.
 */
TEST_F(tls, control_msg_nomerge)
{
	const int rec_type = 100;
	char *first = "1111";
	char *second = "2222";
	int rec_len = 5;	/* four characters plus the terminating NUL */
	char rx[15];

	if (self->notls)
		SKIP(return, "no TLS support");

	/* Queue both records before reading anything. */
	EXPECT_EQ(tls_send_cmsg(self->fd, rec_type, first, rec_len, 0),
		  rec_len);
	EXPECT_EQ(tls_send_cmsg(self->fd, rec_type, second, rec_len, 0),
		  rec_len);

	/* Peeking twice must return the first record alone both times. */
	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, rec_type, rx,
				sizeof(rx), MSG_PEEK),
		  rec_len);
	EXPECT_EQ(memcmp(rx, first, rec_len), 0);

	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, rec_type, rx,
				sizeof(rx), MSG_PEEK),
		  rec_len);
	EXPECT_EQ(memcmp(rx, first, rec_len), 0);

	/* Consuming reads then yield the first and the second record in turn. */
	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, rec_type, rx,
				sizeof(rx), 0),
		  rec_len);
	EXPECT_EQ(memcmp(rx, first, rec_len), 0);

	EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, rec_type, rx,
				sizeof(rx), 0),
		  rec_len);
	EXPECT_EQ(memcmp(rx, second, rec_len), 0);
}
/*
 * Queue three records: plain data, a record sent with type 100 through
 * tls_send_cmsg(), then plain data again. Peek twice with recv() into a
 * buffer large enough for all three payloads.
 *
 * Each peek is expected to return only the first record: its length must be
 * send_len and its bytes must be rec1. The payload check guards against a
 * peek that returns the right number of bytes taken from the wrong record.
 */
TEST_F(tls, data_control_data)
{
	char *rec1 = "1111";
	char *rec2 = "2222";
	char *rec3 = "3333";
	int send_len = 5;	/* four characters plus the terminating NUL */
	char buf[15];

	if (self->notls)
		SKIP(return, "no TLS support");

	EXPECT_EQ(send(self->fd, rec1, send_len, 0), send_len);
	EXPECT_EQ(tls_send_cmsg(self->fd, 100, rec2, send_len, 0), send_len);
	EXPECT_EQ(send(self->fd, rec3, send_len, 0), send_len);

	/*
	 * Clear the buffer before each peek so the comparison cannot be
	 * satisfied by stale bytes (or read uninitialized memory if the
	 * receive fails).
	 */
	memset(buf, 0, sizeof(buf));
	EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_PEEK), send_len);
	EXPECT_EQ(memcmp(buf, rec1, send_len), 0);

	memset(buf, 0, sizeof(buf));
	EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_PEEK), send_len);
	EXPECT_EQ(memcmp(buf, rec1, send_len), 0);
}
TEST_F(tls, shutdown) TEST_F(tls, shutdown)
{ {
char const *test_str = "test_read"; char const *test_str = "test_read";