diff --git a/server/src/tests/core/util/test_auto_serializable.py b/server/src/tests/core/util/test_auto_serializable.py index afda92e56..2a5e20e14 100644 --- a/server/src/tests/core/util/test_auto_serializable.py +++ b/server/src/tests/core/util/test_auto_serializable.py @@ -343,22 +343,69 @@ class AutoSerializable(UDSTestCase): def test_autoserializable_dirty(self) -> None: instance = AutoSerializableClass() - self.assertFalse(instance.is_dirty()) - + self.assertFalse(instance._dirty) + + # Test list field dirty flag + self.assertEqual(instance.list_field[0], 1) + # First access sets default value, so it's dirty + self.assertTrue(instance._dirty) + + instance._dirty = False + self.assertEqual(instance.list_field[1], 2) + # Second access to ANY value does not set dirty flag because + self.assertFalse(instance._dirty) + + instance.list_field = [3, 5, 7] + self.assertTrue(instance._dirty) + + instance._dirty = False + instance.list_field[0] = 1 + self.assertTrue(instance._dirty) + + instance._dirty = False + instance.list_field.append(9) + self.assertTrue(instance._dirty) + + instance._dirty = False + instance.list_field.remove(5) + self.assertTrue(instance._dirty) + + instance._dirty = False + instance.list_field.pop() + self.assertTrue(instance._dirty) + + instance._dirty = False + instance.list_field.insert(1, 4) + self.assertTrue(instance._dirty) + + instance._dirty = False + instance.list_field.clear() + self.assertTrue(instance._dirty) + + instance.list_field = [1, 2, 3] + instance._dirty = False + del instance.list_field[1] + self.assertTrue(instance._dirty) + + instance._dirty = False + instance.list_field.extend([4, 5]) + self.assertTrue(instance._dirty) + + instance._dirty = False instance.int_field = 1 - self.assertTrue(instance.is_dirty()) + self.assertTrue(instance._dirty) instance.marshal() # should reset dirty flag - self.assertFalse(instance.is_dirty()) + self.assertFalse(instance._dirty) instance.int_field = 1 - self.assertTrue(instance.is_dirty()) + self.assertTrue(instance._dirty) instance2 = AutoSerializableClass() - self.assertFalse(instance2.is_dirty()) + self.assertFalse(instance2._dirty) instance2.int_field = 22 - self.assertTrue(instance2.is_dirty()) + self.assertTrue(instance2._dirty) instance2.unmarshal(instance.marshal()) - self.assertFalse(instance2.is_dirty()) + self.assertFalse(instance2._dirty) diff --git a/server/src/uds/core/util/autoserializable.py b/server/src/uds/core/util/autoserializable.py index 024575d56..19e4d5d19 100644 --- a/server/src/uds/core/util/autoserializable.py +++ b/server/src/uds/core/util/autoserializable.py @@ -68,7 +68,7 @@ class _Unassigned: # means field has no default value -UNASSIGNED = _Unassigned() +UNASSIGNED: typing.Final[_Unassigned] = _Unassigned() T = typing.TypeVar('T') V = typing.TypeVar('V') @@ -95,6 +95,59 @@ VERSION_SIZE: typing.Final[int] = 2 # 2 bytes for version PACKED_LENGHS: typing.Final[struct.Struct] = struct.Struct(' None: + self._owner._dirty = True + super().__setitem__(key, value) # type: ignore + + def append(self, object: T, /) -> None: + self._owner._dirty = True + super().append(object) + + def extend(self, iterable: collections.abc.Iterable[T], /) -> None: + self._owner._dirty = True + super().extend(iterable) + + def clear(self) -> None: + self._owner._dirty = True + super().clear() + + def pop(self, index: typing.SupportsIndex = -1, /) -> T: + self._owner._dirty = True + return super().pop(index) + + def insert(self, index: typing.SupportsIndex, object: T, /) -> None: + self._owner._dirty = True + super().insert(index, object) + + def remove(self, value: T, /) -> None: + self._owner._dirty = True + super().remove(value) + + def sort(self, *args: typing.Any, **kwargs: typing.Any) -> None: + self._owner._dirty = True + super().sort(*args, **kwargs) + + def __delitem__(self, key: typing.SupportsIndex | slice, /) -> None: + self._owner._dirty = True + super().__delitem__(key) + + def __iadd__(self, value: collections.abc.Iterable[T], /) -> typing.Self: + self._owner._dirty = True + return super().__iadd__(value) + + def __imul__(self, value: typing.SupportsIndex, /) -> typing.Self: + self._owner._dirty = True + return super().__imul__(value) + + # Helper functions def fernet_key(crypt_key: bytes) -> str: """Generate fermet key a crypt key @@ -212,38 +265,40 @@ class _SerializableField(typing.Generic[T]): instance {SerializableFields} -- Instance of class with field """ - if hasattr(instance, '_fields'): - if self.name in getattr(instance, '_fields'): - return getattr(instance, '_fields')[self.name] + if self.name not in instance._fields: + # Set default using setter + self.__set__(instance, self._default()) - if self.default is None: - raise AttributeError(f"Field {self.name} is not set") - # Set default using setter - self.__set__(instance, self._default()) - return getattr(instance, '_fields')[self.name] + return instance._fields[self.name] def __set__(self, instance: 'AutoSerializable', value: T) -> None: # If type is float and value is int, convert it # Or if type is int and value is float, convert it - if typing.cast(typing.Type[typing.Any], self.obj_type) in (float, int) and isinstance( - value, (float, int) - ): - value = self.obj_type(value) # type: ignore + # if self.obj_type == int and isinstance(value, float): + # value = int(value) + # elif self.obj_type == float and isinstance(value, int): + # value = float(value) + instance._dirty = True # Mark as dirty if not isinstance(value, self.obj_type): # Try casting to load values (maybe a namedtuple, i.e.) try: if isinstance(value, collections.abc.Mapping): - value = self.obj_type(**value) # If a dict, try to cast it to the object - elif isinstance(value, collections.abc.Iterable): # IF a list, tuple, etc... try to cast it - value = self.obj_type(*value) + value = self.obj_type(**value) # If a dict, try to convert + elif isinstance(value, collections.abc.Iterable): # IF a list, tuple, etc... try to convert + # If inner type is an ObservableList, ensure to provider owner + # so dirty can be controlled on list modifications + if self.obj_type == _ObservableList: + value = typing.cast(T, _ObservableList(instance, value)) + else: + value = self.obj_type(*value) # Hope that obj_type knows how to convert else: # Maybe it has a constructor that accepts a single value or is a callable... value = typing.cast(typing.Callable[..., typing.Any], self.obj_type)(value) - except Exception: + except Exception as e: # Allow int to float conversion and viceversa - raise TypeError(f"Field {self.name} cannot be set to {value} (type {self.obj_type.__name__})") - if not hasattr(instance, '_fields'): - setattr(instance, '_fields', {}) - getattr(instance, '_fields')[self.name] = value + raise ValueError( + f"Field {self.name} cannot be set to {value} (type {self.obj_type.__name__})" + ) from e + instance._fields[self.name] = value def marshal(self, instance: 'AutoSerializable') -> bytes: """Basic marshalling of field @@ -328,7 +383,7 @@ class ListField(_SerializableField[list[T]], list[T]): default: typing.Union[list[T], collections.abc.Callable[[], list[T]]] = lambda: [], cast: typing.Optional[typing.Callable[[typing.Any], T]] = None, ): - super().__init__(list, default) + super().__init__(_ObservableList, default) self._cast = cast def marshal(self, instance: 'AutoSerializable') -> bytes: @@ -505,12 +560,14 @@ class AutoSerializable(Serializable, metaclass=_FieldNameSetter): """ _fields: dict[str, typing.Any] # Values for the fields (serializable fields only ofc) + _dirty: bool serialization_version: int = 0 # So autoserializable classes can keep their version if needed def __init__(self): super().__init__() self._fields = {} + self._dirty = False def _autoserializable_fields(self) -> collections.abc.Iterator[tuple[str, _SerializableField[typing.Any]]]: """Returns an iterator over all fields in the class, including inherited ones @@ -569,7 +626,8 @@ class AutoSerializable(Serializable, metaclass=_FieldNameSetter): _MarshalInfo(name=v.name, type_name=str(v.__class__.__name__), value=v.marshal(self)) for _, v in self._autoserializable_fields() ] - + self._dirty = False # Marshal resets dirty flag + # Serialized data is: # 2 bytes -> name length # 2 bytes -> type name length @@ -633,6 +691,8 @@ class AutoSerializable(Serializable, metaclass=_FieldNameSetter): else: logger.debug('Field %s not found in unmarshalled data', v.name) v.__set__(self, v._default()) # Set default value + + self._dirty = False # Reset dirty flag after unmarshalling def as_dict(self) -> dict[str, typing.Any]: return {k: v.__get__(self) for k, v in self._autoserializable_fields()}