From e1ac39e67ff29fc9dd40f361c249656d416ab7d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adolfo=20G=C3=B3mez=20Garc=C3=ADa?= Date: Sat, 19 Oct 2024 22:35:20 +0200 Subject: [PATCH] Refactor storage to use dictionary views for improved performance and convenience --- .../src/uds/core/types/deferred_deletion.py | 2 +- server/src/uds/core/util/storage.py | 97 +++++++++++++++---- server/src/uds/models/service_pool.py | 6 +- .../generics/test_dynamic_deferred_delete.py | 11 ++- server/tests/core/util/test_storage.py | 34 ++++++- 5 files changed, 124 insertions(+), 26 deletions(-) diff --git a/server/src/uds/core/types/deferred_deletion.py b/server/src/uds/core/types/deferred_deletion.py index 53984087d..932ae58d4 100644 --- a/server/src/uds/core/types/deferred_deletion.py +++ b/server/src/uds/core/types/deferred_deletion.py @@ -132,7 +132,7 @@ class DeletionInfo: now = sql_now() with DeletionInfo.deferred_storage.as_dict(group, atomic=True) as storage_dict: for key, info in sorted( - typing.cast(collections.abc.Iterable[tuple[str, DeletionInfo]], storage_dict.items()), + typing.cast(collections.abc.Iterable[tuple[str, DeletionInfo]], storage_dict.unlocked_items()), key=lambda x: x[1].next_check, ): # if max retries reached, remove it diff --git a/server/src/uds/core/util/storage.py b/server/src/uds/core/util/storage.py index e8289fe71..02350f6a4 100644 --- a/server/src/uds/core/util/storage.py +++ b/server/src/uds/core/util/storage.py @@ -81,11 +81,56 @@ def _decode_value(dbk: str, value: typing.Optional[str]) -> tuple[str, typing.An return ('', None) -class StorageAsDict(dict[str, typing.Any]): +class StorageAsDict(collections.abc.MutableMapping[str, typing.Any]): """ Accesses storage as dictionary. 
Much more convenient that old method """ + class DBBasedItemsView(collections.abc.ItemsView[str, typing.Any]): + _storage: 'StorageAsDict' + + def __init__(self, storage: 'StorageAsDict'): + self._storage = storage + + def __contains__(self, key: object) -> bool: + return key in self._storage + + def __iter__(self) -> typing.Iterator[tuple[str, typing.Any]]: + return self._storage.items_iter() + + def __len__(self) -> int: + return len(self._storage) + + class DBBasedKeysView(collections.abc.KeysView[str]): + _storage: 'StorageAsDict' + + def __init__(self, storage: 'StorageAsDict'): + self._storage = storage + + def __contains__(self, key: object) -> bool: + return key in self._storage + + def __iter__(self) -> typing.Iterator[str]: + return self._storage.keys_iter() + + def __len__(self) -> int: + return len(self._storage) + + class DBBasedValuesView(collections.abc.ValuesView[typing.Any]): + _storage: 'StorageAsDict' + + def __init__(self, storage: 'StorageAsDict'): + self._storage = storage + + def __contains__(self, value: object) -> bool: + return value in self._storage.values_iter() + + def __iter__(self) -> typing.Iterator[typing.Any]: + return self._storage.values_iter() + + def __len__(self) -> int: + return len(self._storage) + _group: str _owner: str _atomic: bool @@ -109,20 +154,24 @@ class StorageAsDict(dict[str, typing.Any]): self._owner = owner self._atomic = atomic # Not used right now, maybe removed - @property - def _db(self) -> typing.Union[models.QuerySet[DBStorage], models.Manager[DBStorage]]: + def _db( + self, *, skip_locked: bool = False + ) -> typing.Union[models.QuerySet[DBStorage], models.Manager[DBStorage]]: if self._atomic: - # TODO: add skip_locked ASAP (mariadb 10.6+) + if skip_locked: + # TODO: add skip_locked (as select_for_update(...) 
argument) ASAP (mariadb 10.6+) + pass return DBStorage.objects.select_for_update() return DBStorage.objects - @property - def _filtered(self) -> 'models.QuerySet[DBStorage]': + def _filtered(self, *, skip_locked: bool = False) -> 'models.QuerySet[DBStorage]': fltr_params = {'owner': self._owner} if self._group: fltr_params['attr1'] = self._group - return typing.cast('models.QuerySet[DBStorage]', self._db.filter(**fltr_params)) + return typing.cast( + 'models.QuerySet[DBStorage]', self._db(skip_locked=skip_locked).filter(**fltr_params) + ) def _key(self, key: str, old_method: bool = False) -> str: if key[0] == '#': @@ -141,7 +190,7 @@ class StorageAsDict(dict[str, typing.Any]): for use_old_method in (False, True): db_key = self._key(key, old_method=use_old_method) try: - c: DBStorage = self._db.get(pk=db_key) + c: DBStorage = self._db().get(pk=db_key) if c.owner != self._owner: # Maybe a key collision, logger.error('Key collision detected for key %s', key) return None @@ -174,25 +223,37 @@ class StorageAsDict(dict[str, typing.Any]): """ Iterates through keys """ - return iter(_decode_value(i.key, i.data)[0] for i in self._filtered) + return iter(_decode_value(i.key, i.data)[0] for i in self._filtered()) def __contains__(self, key: object) -> bool: if isinstance(key, str): - return self._filtered.filter(key=self._key(key)).exists() + return self._filtered().filter(key=self._key(key)).exists() return False def __len__(self) -> int: - return self._filtered.count() + return self._filtered().count() + + def items_iter(self) -> typing.Iterator[tuple[str, typing.Any]]: + return iter(_decode_value(i.key, i.data) for i in self._filtered()) + + def keys_iter(self) -> typing.Iterator[str]: + return iter(_decode_value(i.key, i.data)[0] for i in self._filtered()) + + def values_iter(self) -> typing.Iterator[typing.Any]: + return iter(_decode_value(i.key, i.data)[1] for i in self._filtered()) + + def unlocked_items(self) -> typing.Iterator[tuple[str, typing.Any]]: + return 
iter(_decode_value(i.key, i.data) for i in self._filtered(skip_locked=True)) # Optimized methods, avoid re-reading from DB - def items(self) -> typing.Iterator[tuple[str, typing.Any]]: # type: ignore # compatible type - return iter(_decode_value(i.key, i.data) for i in self._filtered) + def items(self) -> collections.abc.ItemsView[str, typing.Any]: + return StorageAsDict.DBBasedItemsView(self) - def keys(self) -> typing.Iterator[str]: # type: ignore # compatible type - return iter(_decode_value(i.key, i.data)[0] for i in self._filtered) + def keys(self) -> collections.abc.KeysView[str]: + return StorageAsDict.DBBasedKeysView(self) - def values(self) -> typing.Iterator[typing.Any]: # type: ignore # compatible type - return iter(_decode_value(i.key, i.data)[1] for i in self._filtered) + def values(self) -> collections.abc.ValuesView[typing.Any]: + return StorageAsDict.DBBasedValuesView(self) def get(self, key: str, default: typing.Any = None) -> typing.Any: return self[key] or default @@ -201,7 +262,7 @@ class StorageAsDict(dict[str, typing.Any]): self.__delitem__(key) # pylint: disable=unnecessary-dunder-call def clear(self) -> None: - self._filtered.delete() # Removes all keys + self._filtered().delete() # Removes all keys # Custom utility methods @property diff --git a/server/src/uds/models/service_pool.py b/server/src/uds/models/service_pool.py index 843332476..d898087fe 100644 --- a/server/src/uds/models/service_pool.py +++ b/server/src/uds/models/service_pool.py @@ -404,7 +404,8 @@ class ServicePool(UUIDModel, TaggingMixin): name: Name of the value to store value: Value of the value to store """ - self.get_environment().storage.put(name, value) + with self.get_environment().storage.as_dict() as storage: + storage[name] = value def get_value(self, name: str) -> typing.Any: """ @@ -416,7 +417,8 @@ class ServicePool(UUIDModel, TaggingMixin): Returns: Stored value, None if no value was stored """ - return typing.cast(str, self.get_environment().storage.read(name)) 
+ with self.get_environment().storage.as_dict() as storage: + return storage.get(name) def set_state(self, state: str, save: bool = True) -> None: """ diff --git a/server/tests/core/services/generics/test_dynamic_deferred_delete.py b/server/tests/core/services/generics/test_dynamic_deferred_delete.py index 5c14504e7..6c80cd3b6 100644 --- a/server/tests/core/services/generics/test_dynamic_deferred_delete.py +++ b/server/tests/core/services/generics/test_dynamic_deferred_delete.py @@ -47,6 +47,10 @@ from ....utils.test import UDSTransactionTestCase from ....utils import helpers from . import fixtures +class TestDict(dict[str, deferred_types.DeletionInfo]): + def unlocked_items(self) -> typing.ItemsView[str, deferred_types.DeletionInfo]: + return self.items() + class DynamicDeferredDeleteTest(UDSTransactionTestCase): def setUp(self) -> None: @@ -74,13 +78,14 @@ class DynamicDeferredDeleteTest(UDSTransactionTestCase): is_running: typing.Union[None, typing.Callable[..., bool]] = None, must_stop_before_deletion: bool = False, should_try_soft_shutdown: bool = False, - ) -> typing.Iterator[tuple[mock.MagicMock, dict[str, dict[str, deferred_types.DeletionInfo]]]]: + ) -> typing.Iterator[tuple[mock.MagicMock, dict[str, TestDict]]]: """ Patch the storage to use a dict instead of the real storage This is useful to test the worker without touching the real storage """ - dct: dict[str, dict[str, deferred_types.DeletionInfo]] = {} + + dct: dict[str, TestDict] = {} instance = mock.MagicMock() instance_db_obj = mock.MagicMock(uuid='service1') instance_db_obj.get_instance.return_value = instance @@ -106,7 +111,7 @@ class DynamicDeferredDeleteTest(UDSTransactionTestCase): group: str, *args: typing.Any, **kwargs: typing.Any ) -> typing.Iterator[dict[str, deferred_types.DeletionInfo]]: if group not in dct: - dct[group] = {} + dct[group] = TestDict() yield dct[group] storage.as_dict.side_effect = _as_dict diff --git a/server/tests/core/util/test_storage.py 
b/server/tests/core/util/test_storage.py index fa543237f..7aa657645 100644 --- a/server/tests/core/util/test_storage.py +++ b/server/tests/core/util/test_storage.py @@ -84,17 +84,47 @@ class StorageTest(UDSTestCase): self.assertEqual(d[UNICODE_CHARS], 'chars') self.assertEqual(d['test_key'], UNICODE_CHARS_2) - + # Assert that UNICODE_CHARS is in the dict d['test_key2'] = 0 d['test_key2'] += 1 - + self.assertEqual(d['test_key2'], 1) # The values set inside the "with" are not available "outside" # because the format is not compatible (with the dict, the values are stored as a tuple, with the original key stored # and with old format, only the value is stored + def test_storage_as_dict_views(self) -> None: + strg = storage.Storage(UNICODE_CHARS) + + items = {f'key_{i}': f'value_{i}' for i in range(32)} + + with strg.as_dict() as dct: + # Store all items + for k, v in items.items(): + dct[k] = v + + for k, v in dct.items(): + self.assertEqual(v, items[k]) + + for k in dct.keys(): + self.assertIn(k, items) + + for v in dct.values(): + self.assertIn(v, items.values()) + + self.assertEqual(len(dct), len(items)) + self.assertEqual(len(dct.items()), len(items)) + self.assertEqual(len(dct.keys()), len(items)) + self.assertEqual(len(dct.values()), len(items)) + + # Contains for items, keys and values + for k in items: + self.assertIn(k, dct) + self.assertIn(k, dct.keys()) + self.assertIn(items[k], dct.values()) + def test_old_storage_compat(self) -> None: models.Storage.objects.create( owner=UNICODE_CHARS,