mirror of
https://github.com/dkmstr/openuds.git
synced 2025-02-08 05:57:39 +03:00
Adding support for controlling dirty state for autoserializable.
Added support for list fields
This commit is contained in:
parent
b0cf8c5ddf
commit
8c5e2d6552
@ -343,22 +343,69 @@ class AutoSerializable(UDSTestCase):
|
||||
|
||||
def test_autoserializable_dirty(self) -> None:
|
||||
instance = AutoSerializableClass()
|
||||
self.assertFalse(instance.is_dirty())
|
||||
|
||||
self.assertFalse(instance._dirty)
|
||||
|
||||
# Test list field dirty flag
|
||||
self.assertEqual(instance.list_field[0], 1)
|
||||
# First access sets default value, so it's dirty
|
||||
self.assertTrue(instance._dirty)
|
||||
|
||||
instance._dirty = False
|
||||
self.assertEqual(instance.list_field[1], 2)
|
||||
# Second access to ANY value does not set dirty flag because
|
||||
self.assertFalse(instance._dirty)
|
||||
|
||||
instance.list_field = [3, 5, 7]
|
||||
self.assertTrue(instance._dirty)
|
||||
|
||||
instance._dirty = False
|
||||
instance.list_field[0] = 1
|
||||
self.assertTrue(instance._dirty)
|
||||
|
||||
instance._dirty = False
|
||||
instance.list_field.append(9)
|
||||
self.assertTrue(instance._dirty)
|
||||
|
||||
instance._dirty = False
|
||||
instance.list_field.remove(5)
|
||||
self.assertTrue(instance._dirty)
|
||||
|
||||
instance._dirty = False
|
||||
instance.list_field.pop()
|
||||
self.assertTrue(instance._dirty)
|
||||
|
||||
instance._dirty = False
|
||||
instance.list_field.insert(1, 4)
|
||||
self.assertTrue(instance._dirty)
|
||||
|
||||
instance._dirty = False
|
||||
instance.list_field.clear()
|
||||
self.assertTrue(instance._dirty)
|
||||
|
||||
instance.list_field = [1, 2, 3]
|
||||
instance._dirty = False
|
||||
del instance.list_field[1]
|
||||
self.assertTrue(instance._dirty)
|
||||
|
||||
instance._dirty = False
|
||||
instance.list_field.extend([4, 5])
|
||||
self.assertTrue(instance._dirty)
|
||||
|
||||
instance._dirty = False
|
||||
instance.int_field = 1
|
||||
self.assertTrue(instance.is_dirty())
|
||||
self.assertTrue(instance._dirty)
|
||||
|
||||
instance.marshal() # should reset dirty flag
|
||||
self.assertFalse(instance.is_dirty())
|
||||
self.assertFalse(instance._dirty)
|
||||
|
||||
instance.int_field = 1
|
||||
self.assertTrue(instance.is_dirty())
|
||||
self.assertTrue(instance._dirty)
|
||||
|
||||
instance2 = AutoSerializableClass()
|
||||
self.assertFalse(instance2.is_dirty())
|
||||
self.assertFalse(instance2._dirty)
|
||||
|
||||
instance2.int_field = 22
|
||||
self.assertTrue(instance2.is_dirty())
|
||||
self.assertTrue(instance2._dirty)
|
||||
|
||||
instance2.unmarshal(instance.marshal())
|
||||
self.assertFalse(instance2.is_dirty())
|
||||
self.assertFalse(instance2._dirty)
|
||||
|
@ -68,7 +68,7 @@ class _Unassigned:
|
||||
|
||||
|
||||
# means field has no default value
|
||||
UNASSIGNED = _Unassigned()
|
||||
UNASSIGNED: typing.Final[_Unassigned] = _Unassigned()
|
||||
|
||||
T = typing.TypeVar('T')
|
||||
V = typing.TypeVar('V')
|
||||
@ -95,6 +95,59 @@ VERSION_SIZE: typing.Final[int] = 2 # 2 bytes for version
|
||||
PACKED_LENGHS: typing.Final[struct.Struct] = struct.Struct('<HHI')
|
||||
|
||||
|
||||
class _ObservableList(list[T]):
|
||||
_owner: 'AutoSerializable'
|
||||
|
||||
def __init__(self, owner: 'AutoSerializable', *args: typing.Any):
|
||||
self._owner = owner
|
||||
self._owner._dirty = True
|
||||
super().__init__(*args)
|
||||
|
||||
def __setitem__(self, key: typing.SupportsIndex | slice, value: T | collections.abc.Iterable[T]) -> None:
|
||||
self._owner._dirty = True
|
||||
super().__setitem__(key, value) # type: ignore
|
||||
|
||||
def append(self, object: T, /) -> None:
|
||||
self._owner._dirty = True
|
||||
super().append(object)
|
||||
|
||||
def extend(self, iterable: collections.abc.Iterable[T], /) -> None:
|
||||
self._owner._dirty = True
|
||||
super().extend(iterable)
|
||||
|
||||
def clear(self) -> None:
|
||||
self._owner._dirty = True
|
||||
super().clear()
|
||||
|
||||
def pop(self, index: typing.SupportsIndex = -1, /) -> T:
|
||||
self._owner._dirty = True
|
||||
return super().pop(index)
|
||||
|
||||
def insert(self, index: typing.SupportsIndex, object: T, /) -> None:
|
||||
self._owner._dirty = True
|
||||
super().insert(index, object)
|
||||
|
||||
def remove(self, value: T, /) -> None:
|
||||
self._owner._dirty = True
|
||||
super().remove(value)
|
||||
|
||||
def sort(self, *args: typing.Any, **kwargs: typing.Any) -> None:
|
||||
self._owner._dirty = True
|
||||
super().sort(*args, **kwargs)
|
||||
|
||||
def __delitem__(self, key: typing.SupportsIndex | slice, /) -> None:
|
||||
self._owner._dirty = True
|
||||
super().__delitem__(key)
|
||||
|
||||
def __iadd__(self, value: collections.abc.Iterable[T], /) -> typing.Self:
|
||||
self._owner._dirty = True
|
||||
return super().__iadd__(value)
|
||||
|
||||
def __imul__(self, value: typing.SupportsIndex, /) -> typing.Self:
|
||||
self._owner._dirty = True
|
||||
return super().__imul__(value)
|
||||
|
||||
|
||||
# Helper functions
|
||||
def fernet_key(crypt_key: bytes) -> str:
|
||||
"""Generate fermet key a crypt key
|
||||
@ -212,38 +265,40 @@ class _SerializableField(typing.Generic[T]):
|
||||
instance {SerializableFields} -- Instance of class with field
|
||||
|
||||
"""
|
||||
if hasattr(instance, '_fields'):
|
||||
if self.name in getattr(instance, '_fields'):
|
||||
return getattr(instance, '_fields')[self.name]
|
||||
if self.name not in instance._fields:
|
||||
# Set default using setter
|
||||
self.__set__(instance, self._default())
|
||||
|
||||
if self.default is None:
|
||||
raise AttributeError(f"Field {self.name} is not set")
|
||||
# Set default using setter
|
||||
self.__set__(instance, self._default())
|
||||
return getattr(instance, '_fields')[self.name]
|
||||
return instance._fields[self.name]
|
||||
|
||||
def __set__(self, instance: 'AutoSerializable', value: T) -> None:
|
||||
# If type is float and value is int, convert it
|
||||
# Or if type is int and value is float, convert it
|
||||
if typing.cast(typing.Type[typing.Any], self.obj_type) in (float, int) and isinstance(
|
||||
value, (float, int)
|
||||
):
|
||||
value = self.obj_type(value) # type: ignore
|
||||
# if self.obj_type == int and isinstance(value, float):
|
||||
# value = int(value)
|
||||
# elif self.obj_type == float and isinstance(value, int):
|
||||
# value = float(value)
|
||||
instance._dirty = True # Mark as dirty
|
||||
if not isinstance(value, self.obj_type):
|
||||
# Try casting to load values (maybe a namedtuple, i.e.)
|
||||
try:
|
||||
if isinstance(value, collections.abc.Mapping):
|
||||
value = self.obj_type(**value) # If a dict, try to cast it to the object
|
||||
elif isinstance(value, collections.abc.Iterable): # IF a list, tuple, etc... try to cast it
|
||||
value = self.obj_type(*value)
|
||||
value = self.obj_type(**value) # If a dict, try to convert
|
||||
elif isinstance(value, collections.abc.Iterable): # IF a list, tuple, etc... try to convert
|
||||
# If inner type is an ObservableList, ensure to provider owner
|
||||
# so dirty can be controlled on list modifications
|
||||
if self.obj_type == _ObservableList:
|
||||
value = typing.cast(T, _ObservableList(instance, value))
|
||||
else:
|
||||
value = self.obj_type(*value) # Hope that obj_type knows how to convert
|
||||
else: # Maybe it has a constructor that accepts a single value or is a callable...
|
||||
value = typing.cast(typing.Callable[..., typing.Any], self.obj_type)(value)
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
# Allow int to float conversion and viceversa
|
||||
raise TypeError(f"Field {self.name} cannot be set to {value} (type {self.obj_type.__name__})")
|
||||
if not hasattr(instance, '_fields'):
|
||||
setattr(instance, '_fields', {})
|
||||
getattr(instance, '_fields')[self.name] = value
|
||||
raise ValueError(
|
||||
f"Field {self.name} cannot be set to {value} (type {self.obj_type.__name__})"
|
||||
) from e
|
||||
instance._fields[self.name] = value
|
||||
|
||||
def marshal(self, instance: 'AutoSerializable') -> bytes:
|
||||
"""Basic marshalling of field
|
||||
@ -328,7 +383,7 @@ class ListField(_SerializableField[list[T]], list[T]):
|
||||
default: typing.Union[list[T], collections.abc.Callable[[], list[T]]] = lambda: [],
|
||||
cast: typing.Optional[typing.Callable[[typing.Any], T]] = None,
|
||||
):
|
||||
super().__init__(list, default)
|
||||
super().__init__(_ObservableList, default)
|
||||
self._cast = cast
|
||||
|
||||
def marshal(self, instance: 'AutoSerializable') -> bytes:
|
||||
@ -505,12 +560,14 @@ class AutoSerializable(Serializable, metaclass=_FieldNameSetter):
|
||||
"""
|
||||
|
||||
_fields: dict[str, typing.Any] # Values for the fields (serializable fields only ofc)
|
||||
_dirty: bool
|
||||
|
||||
serialization_version: int = 0 # So autoserializable classes can keep their version if needed
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._fields = {}
|
||||
self._dirty = False
|
||||
|
||||
def _autoserializable_fields(self) -> collections.abc.Iterator[tuple[str, _SerializableField[typing.Any]]]:
|
||||
"""Returns an iterator over all fields in the class, including inherited ones
|
||||
@ -569,7 +626,8 @@ class AutoSerializable(Serializable, metaclass=_FieldNameSetter):
|
||||
_MarshalInfo(name=v.name, type_name=str(v.__class__.__name__), value=v.marshal(self))
|
||||
for _, v in self._autoserializable_fields()
|
||||
]
|
||||
|
||||
self._dirty = False # Marshal resets dirty flag
|
||||
|
||||
# Serialized data is:
|
||||
# 2 bytes -> name length
|
||||
# 2 bytes -> type name length
|
||||
@ -633,6 +691,8 @@ class AutoSerializable(Serializable, metaclass=_FieldNameSetter):
|
||||
else:
|
||||
logger.debug('Field %s not found in unmarshalled data', v.name)
|
||||
v.__set__(self, v._default()) # Set default value
|
||||
|
||||
self._dirty = False # Reset dirty flag after unmarshalling
|
||||
|
||||
def as_dict(self) -> dict[str, typing.Any]:
|
||||
return {k: v.__get__(self) for k, v in self._autoserializable_fields()}
|
||||
|
Loading…
x
Reference in New Issue
Block a user