diff --git a/src/basic/hexdecoct.c b/src/basic/hexdecoct.c index dc3b948d8e0..898ed83f862 100644 --- a/src/basic/hexdecoct.c +++ b/src/basic/hexdecoct.c @@ -110,12 +110,17 @@ static int unhex_next(const char **p, size_t *l) { return r; } -int unhexmem_full(const char *p, size_t l, bool secure, void **ret, size_t *ret_len) { +int unhexmem_full( + const char *p, + size_t l, + bool secure, + void **ret, + size_t *ret_len) { + _cleanup_free_ uint8_t *buf = NULL; size_t buf_size; const char *x; uint8_t *z; - int r; assert(p || l == 0); @@ -128,22 +133,20 @@ int unhexmem_full(const char *p, size_t l, bool secure, void **ret, size_t *ret_ if (!buf) return -ENOMEM; + CLEANUP_ERASE_PTR(secure ? &buf : NULL, buf_size); + for (x = p, z = buf;;) { int a, b; a = unhex_next(&x, &l); if (a == -EPIPE) /* End of string */ break; - if (a < 0) { - r = a; - goto on_failure; - } + if (a < 0) + return a; b = unhex_next(&x, &l); - if (b < 0) { - r = b; - goto on_failure; - } + if (b < 0) + return b; *(z++) = (uint8_t) a << 4 | (uint8_t) b; } @@ -156,12 +159,6 @@ int unhexmem_full(const char *p, size_t l, bool secure, void **ret, size_t *ret_ *ret = TAKE_PTR(buf); return 0; - -on_failure: - if (secure) - explicit_bzero_safe(buf, buf_size); - - return r; } /* https://tools.ietf.org/html/rfc4648#section-6 @@ -765,12 +762,17 @@ static int unbase64_next(const char **p, size_t *l) { return ret; } -int unbase64mem_full(const char *p, size_t l, bool secure, void **ret, size_t *ret_size) { +int unbase64mem_full( + const char *p, + size_t l, + bool secure, + void **ret, + size_t *ret_size) { + _cleanup_free_ uint8_t *buf = NULL; const char *x; uint8_t *z; size_t len; - int r; assert(p || l == 0); @@ -785,60 +787,44 @@ int unbase64mem_full(const char *p, size_t l, bool secure, void **ret, size_t *r if (!buf) return -ENOMEM; + CLEANUP_ERASE_PTR(secure ? &buf : NULL, len); + for (x = p, z = buf;;) { int a, b, c, d; /* a == 00XXXXXX; b == 00YYYYYY; c == 00ZZZZZZ; d == 00WWWWWW */ a = unbase64_next(&x, &l); if (a == -EPIPE) /* End of string */ break; - if (a < 0) { - r = a; - goto on_failure; - } - if (a == INT_MAX) { /* Padding is not allowed at the beginning of a 4ch block */ - r = -EINVAL; - goto on_failure; - } + if (a < 0) + return a; + if (a == INT_MAX) /* Padding is not allowed at the beginning of a 4ch block */ + return -EINVAL; b = unbase64_next(&x, &l); - if (b < 0) { - r = b; - goto on_failure; - } - if (b == INT_MAX) { /* Padding is not allowed at the second character of a 4ch block either */ - r = -EINVAL; - goto on_failure; - } + if (b < 0) + return b; + if (b == INT_MAX) /* Padding is not allowed at the second character of a 4ch block either */ + return -EINVAL; c = unbase64_next(&x, &l); - if (c < 0) { - r = c; - goto on_failure; - } + if (c < 0) + return c; d = unbase64_next(&x, &l); - if (d < 0) { - r = d; - goto on_failure; - } + if (d < 0) + return d; if (c == INT_MAX) { /* Padding at the third character */ - if (d != INT_MAX) { /* If the third character is padding, the fourth must be too */ - r = -EINVAL; - goto on_failure; - } + if (d != INT_MAX) /* If the third character is padding, the fourth must be too */ + return -EINVAL; /* b == 00YY0000 */ - if (b & 15) { - r = -EINVAL; - goto on_failure; - } + if (b & 15) + return -EINVAL; - if (l > 0) { /* Trailing rubbish? */ - r = -ENAMETOOLONG; - goto on_failure; - } + if (l > 0) /* Trailing rubbish? */ + return -ENAMETOOLONG; *(z++) = (uint8_t) a << 2 | (uint8_t) (b >> 4); /* XXXXXXYY */ break; @@ -846,15 +832,11 @@ int unbase64mem_full(const char *p, size_t l, bool secure, void **ret, size_t *r if (d == INT_MAX) { /* c == 00ZZZZ00 */ - if (c & 3) { - r = -EINVAL; - goto on_failure; - } + if (c & 3) + return -EINVAL; - if (l > 0) { /* Trailing rubbish? */ - r = -ENAMETOOLONG; - goto on_failure; - } + if (l > 0) /* Trailing rubbish? */ + return -ENAMETOOLONG; *(z++) = (uint8_t) a << 2 | (uint8_t) b >> 4; /* XXXXXXYY */ *(z++) = (uint8_t) b << 4 | (uint8_t) c >> 2; /* YYYYZZZZ */ @@ -868,18 +850,14 @@ int unbase64mem_full(const char *p, size_t l, bool secure, void **ret, size_t *r *z = 0; + assert((size_t) (z - buf) <= len); + if (ret_size) *ret_size = (size_t) (z - buf); if (ret) *ret = TAKE_PTR(buf); return 0; - -on_failure: - if (secure) - explicit_bzero_safe(buf, len); - - return r; } void hexdump(FILE *f, const void *p, size_t s) { diff --git a/src/fundamental/memory-util-fundamental.h b/src/fundamental/memory-util-fundamental.h index 67621fdb424..78e2dbec598 100644 --- a/src/fundamental/memory-util-fundamental.h +++ b/src/fundamental/memory-util-fundamental.h @@ -29,6 +29,8 @@ static inline void *explicit_bzero_safe(void *p, size_t l) { #endif struct VarEraser { + /* NB: This is a pointer to memory to erase in case of CLEANUP_ERASE(). Pointer to pointer to memory + * to erase in case of CLEANUP_ERASE_PTR() */ void *p; size_t size; }; @@ -38,5 +40,27 @@ static inline void erase_var(struct VarEraser *e) { } /* Mark var to be erased when leaving scope. */ -#define CLEANUP_ERASE(var) \ - _cleanup_(erase_var) _unused_ struct VarEraser CONCATENATE(_eraser_, UNIQ) = { .p = &var, .size = sizeof(var) } +#define CLEANUP_ERASE(var) \ + _cleanup_(erase_var) _unused_ struct VarEraser CONCATENATE(_eraser_, UNIQ) = { \ + .p = &(var), \ + .size = sizeof(var), \ + } + +static inline void erase_varp(struct VarEraser *e) { + + /* Very similar to erase_var(), but assumes `p` is a pointer to a pointer whose memory shall be destructed. */ + if (!e->p) + return; + + explicit_bzero_safe(*(void**) e->p, e->size); +} + +/* Mark pointer so that memory pointed to is erased when leaving scope. Note: this takes a pointer to the + * specified pointer, instead of just a copy of it. This is to allow callers to invalidate the pointer after + * use, if they like, disabling our automatic erasure (for example because they succeeded with whatever they + * wanted to do and now intend to return the allocated buffer to their caller without it being erased). */ +#define CLEANUP_ERASE_PTR(ptr, sz) \ + _cleanup_(erase_varp) _unused_ struct VarEraser CONCATENATE(_eraser_, UNIQ) = { \ + .p = (ptr), \ + .size = (sz), \ + } diff --git a/src/test/test-hexdecoct.c b/src/test/test-hexdecoct.c index afdc3b54368..9d71db6ae19 100644 --- a/src/test/test-hexdecoct.c +++ b/src/test/test-hexdecoct.c @@ -322,6 +322,13 @@ TEST(base64mem_linebreak) { assert_se(decoded_size == n); assert_se(memcmp(data, decoded, n) == 0); + /* Also try in secure mode */ + decoded = mfree(decoded); + decoded_size = 0; + assert_se(unbase64mem_full(encoded, SIZE_MAX, /* secure= */ true, &decoded, &decoded_size) >= 0); + assert_se(decoded_size == n); + assert_se(memcmp(data, decoded, n) == 0); + for (size_t j = 0; j < (size_t) l; j++) assert_se((encoded[j] == '\n') == (j % (m + 1) == m)); } @@ -446,7 +453,17 @@ static void test_unbase64mem_one(const char *input, const char *output, int ret) size_t size = 0; assert_se(unbase64mem(input, SIZE_MAX, &buffer, &size) == ret); + if (ret >= 0) { + assert_se(size == strlen(output)); + assert_se(memcmp(buffer, output, size) == 0); + assert_se(((char*) buffer)[size] == 0); + } + /* also try in secure mode */ + buffer = mfree(buffer); + size = 0; + + assert_se(unbase64mem_full(input, SIZE_MAX, /* secure=*/ true, &buffer, &size) == ret); if (ret >= 0) { assert_se(size == strlen(output)); assert_se(memcmp(buffer, output, size) == 0);