1
0
mirror of https://github.com/dkmstr/openuds.git synced 2025-02-08 05:57:39 +03:00

Adding support for controlling dirty state for autoserializable.

Added support for list fields
This commit is contained in:
Adolfo Gómez García 2024-07-11 10:26:47 +02:00
parent b0cf8c5ddf
commit 8c5e2d6552
No known key found for this signature in database
GPG Key ID: DD1ABF20724CDA23
2 changed files with 138 additions and 31 deletions

View File

@ -343,22 +343,69 @@ class AutoSerializable(UDSTestCase):
def test_autoserializable_dirty(self) -> None:
instance = AutoSerializableClass()
self.assertFalse(instance.is_dirty())
self.assertFalse(instance._dirty)
# Test list field dirty flag
self.assertEqual(instance.list_field[0], 1)
# First access sets default value, so it's dirty
self.assertTrue(instance._dirty)
instance._dirty = False
self.assertEqual(instance.list_field[1], 2)
# Second access to ANY value does not set dirty flag because
self.assertFalse(instance._dirty)
instance.list_field = [3, 5, 7]
self.assertTrue(instance._dirty)
instance._dirty = False
instance.list_field[0] = 1
self.assertTrue(instance._dirty)
instance._dirty = False
instance.list_field.append(9)
self.assertTrue(instance._dirty)
instance._dirty = False
instance.list_field.remove(5)
self.assertTrue(instance._dirty)
instance._dirty = False
instance.list_field.pop()
self.assertTrue(instance._dirty)
instance._dirty = False
instance.list_field.insert(1, 4)
self.assertTrue(instance._dirty)
instance._dirty = False
instance.list_field.clear()
self.assertTrue(instance._dirty)
instance.list_field = [1, 2, 3]
instance._dirty = False
del instance.list_field[1]
self.assertTrue(instance._dirty)
instance._dirty = False
instance.list_field.extend([4, 5])
self.assertTrue(instance._dirty)
instance._dirty = False
instance.int_field = 1
self.assertTrue(instance.is_dirty())
self.assertTrue(instance._dirty)
instance.marshal() # should reset dirty flag
self.assertFalse(instance.is_dirty())
self.assertFalse(instance._dirty)
instance.int_field = 1
self.assertTrue(instance.is_dirty())
self.assertTrue(instance._dirty)
instance2 = AutoSerializableClass()
self.assertFalse(instance2.is_dirty())
self.assertFalse(instance2._dirty)
instance2.int_field = 22
self.assertTrue(instance2.is_dirty())
self.assertTrue(instance2._dirty)
instance2.unmarshal(instance.marshal())
self.assertFalse(instance2.is_dirty())
self.assertFalse(instance2._dirty)

View File

@ -68,7 +68,7 @@ class _Unassigned:
# means field has no default value
UNASSIGNED = _Unassigned()
UNASSIGNED: typing.Final[_Unassigned] = _Unassigned()
T = typing.TypeVar('T')
V = typing.TypeVar('V')
@ -95,6 +95,59 @@ VERSION_SIZE: typing.Final[int] = 2 # 2 bytes for version
PACKED_LENGHS: typing.Final[struct.Struct] = struct.Struct('<HHI')
class _ObservableList(list[T]):
_owner: 'AutoSerializable'
def __init__(self, owner: 'AutoSerializable', *args: typing.Any):
self._owner = owner
self._owner._dirty = True
super().__init__(*args)
def __setitem__(self, key: typing.SupportsIndex | slice, value: T | collections.abc.Iterable[T]) -> None:
self._owner._dirty = True
super().__setitem__(key, value) # type: ignore
def append(self, object: T, /) -> None:
self._owner._dirty = True
super().append(object)
def extend(self, iterable: collections.abc.Iterable[T], /) -> None:
self._owner._dirty = True
super().extend(iterable)
def clear(self) -> None:
self._owner._dirty = True
super().clear()
def pop(self, index: typing.SupportsIndex = -1, /) -> T:
self._owner._dirty = True
return super().pop(index)
def insert(self, index: typing.SupportsIndex, object: T, /) -> None:
self._owner._dirty = True
super().insert(index, object)
def remove(self, value: T, /) -> None:
self._owner._dirty = True
super().remove(value)
def sort(self, *args: typing.Any, **kwargs: typing.Any) -> None:
self._owner._dirty = True
super().sort(*args, **kwargs)
def __delitem__(self, key: typing.SupportsIndex | slice, /) -> None:
self._owner._dirty = True
super().__delitem__(key)
def __iadd__(self, value: collections.abc.Iterable[T], /) -> typing.Self:
self._owner._dirty = True
return super().__iadd__(value)
def __imul__(self, value: typing.SupportsIndex, /) -> typing.Self:
self._owner._dirty = True
return super().__imul__(value)
# Helper functions
def fernet_key(crypt_key: bytes) -> str:
"""Generate fermet key a crypt key
@ -212,38 +265,40 @@ class _SerializableField(typing.Generic[T]):
instance {SerializableFields} -- Instance of class with field
"""
if hasattr(instance, '_fields'):
if self.name in getattr(instance, '_fields'):
return getattr(instance, '_fields')[self.name]
if self.name not in instance._fields:
# Set default using setter
self.__set__(instance, self._default())
if self.default is None:
raise AttributeError(f"Field {self.name} is not set")
# Set default using setter
self.__set__(instance, self._default())
return getattr(instance, '_fields')[self.name]
return instance._fields[self.name]
def __set__(self, instance: 'AutoSerializable', value: T) -> None:
# If type is float and value is int, convert it
# Or if type is int and value is float, convert it
if typing.cast(typing.Type[typing.Any], self.obj_type) in (float, int) and isinstance(
value, (float, int)
):
value = self.obj_type(value) # type: ignore
# if self.obj_type == int and isinstance(value, float):
# value = int(value)
# elif self.obj_type == float and isinstance(value, int):
# value = float(value)
instance._dirty = True # Mark as dirty
if not isinstance(value, self.obj_type):
# Try casting to load values (maybe a namedtuple, i.e.)
try:
if isinstance(value, collections.abc.Mapping):
value = self.obj_type(**value) # If a dict, try to cast it to the object
elif isinstance(value, collections.abc.Iterable): # IF a list, tuple, etc... try to cast it
value = self.obj_type(*value)
value = self.obj_type(**value) # If a dict, try to convert
elif isinstance(value, collections.abc.Iterable): # IF a list, tuple, etc... try to convert
# If inner type is an ObservableList, ensure to provider owner
# so dirty can be controlled on list modifications
if self.obj_type == _ObservableList:
value = typing.cast(T, _ObservableList(instance, value))
else:
value = self.obj_type(*value) # Hope that obj_type knows how to convert
else: # Maybe it has a constructor that accepts a single value or is a callable...
value = typing.cast(typing.Callable[..., typing.Any], self.obj_type)(value)
except Exception:
except Exception as e:
# Allow int to float conversion and viceversa
raise TypeError(f"Field {self.name} cannot be set to {value} (type {self.obj_type.__name__})")
if not hasattr(instance, '_fields'):
setattr(instance, '_fields', {})
getattr(instance, '_fields')[self.name] = value
raise ValueError(
f"Field {self.name} cannot be set to {value} (type {self.obj_type.__name__})"
) from e
instance._fields[self.name] = value
def marshal(self, instance: 'AutoSerializable') -> bytes:
"""Basic marshalling of field
@ -328,7 +383,7 @@ class ListField(_SerializableField[list[T]], list[T]):
default: typing.Union[list[T], collections.abc.Callable[[], list[T]]] = lambda: [],
cast: typing.Optional[typing.Callable[[typing.Any], T]] = None,
):
super().__init__(list, default)
super().__init__(_ObservableList, default)
self._cast = cast
def marshal(self, instance: 'AutoSerializable') -> bytes:
@ -505,12 +560,14 @@ class AutoSerializable(Serializable, metaclass=_FieldNameSetter):
"""
_fields: dict[str, typing.Any] # Values for the fields (serializable fields only ofc)
_dirty: bool
serialization_version: int = 0 # So autoserializable classes can keep their version if needed
def __init__(self):
super().__init__()
self._fields = {}
self._dirty = False
def _autoserializable_fields(self) -> collections.abc.Iterator[tuple[str, _SerializableField[typing.Any]]]:
"""Returns an iterator over all fields in the class, including inherited ones
@ -569,7 +626,8 @@ class AutoSerializable(Serializable, metaclass=_FieldNameSetter):
_MarshalInfo(name=v.name, type_name=str(v.__class__.__name__), value=v.marshal(self))
for _, v in self._autoserializable_fields()
]
self._dirty = False # Marshal resets dirty flag
# Serialized data is:
# 2 bytes -> name length
# 2 bytes -> type name length
@ -633,6 +691,8 @@ class AutoSerializable(Serializable, metaclass=_FieldNameSetter):
else:
logger.debug('Field %s not found in unmarshalled data', v.name)
v.__set__(self, v._default()) # Set default value
self._dirty = False # Reset dirty flag after unmarshalling
def as_dict(self) -> dict[str, typing.Any]:
return {k: v.__get__(self) for k, v in self._autoserializable_fields()}