1
0
mirror of https://github.com/dkmstr/openuds.git synced 2025-10-07 15:33:51 +03:00

Refactor query filter grammar and update test cases: enhance query parsing logic, add support for new functions, and improve test descriptions for clarity.

This commit is contained in:
Adolfo Gómez García
2025-08-13 07:25:46 +02:00
parent 2454af7ec1
commit 355362956f
2 changed files with 248 additions and 131 deletions

View File

@@ -1,178 +1,218 @@
# pyright: reportUnknownMemberType=false
import typing
import collections.abc
import re
import contextvars
import collections.abc
import logging
import lark
from lark import Lark, Transformer, Token
logger = logging.getLogger(__name__)
# Simplified Odata grammar
_ODATA_GRAMMAR: typing.Final[
_QUERY_GRAMMAR: typing.Final[
str
] = r"""
?start: expr
] = r"""?start: expr
?expr: "not" expr -> not_expr
| expr "and" expr -> and_expr
| expr "or" expr -> or_expr
| "(" expr ")" -> paren_expr
| comparison
?expr: or_expr
?func_name: CNAME
?or_expr: and_expr
| or_expr "or" and_expr -> logical_or
?and_expr: not_expr
| and_expr "and" not_expr -> logical_and
?not_expr: comparison
| "not" not_expr -> unary_not
?comparison: value
| value OP value -> binary_expr
| "(" expr ")" -> paren_expr
value: field | ESCAPED_STRING | NUMBER | boolean | func_call
field: NAME
func_call: NAME "(" [ value ("," value)* ] ")"
boolean: "true" -> true
| "false" -> false
OP: "eq" | "gt" | "lt" | "ne" | "ge" | "le"
?comparison: operand OP operand -> binary_op
| func_name "(" field "," value ")" -> func_op
?operand: value_expr
| value
?value_expr: field
| func_name "(" field ")" -> value_func
field: CNAME
value: ESCAPED_STRING | SIGNED_NUMBER
ESCAPED_STRING: /'[^']*'/ | /"[^"]*"/
NAME: CNAME ("." CNAME)*
%import common.CNAME
%import common.SIGNED_NUMBER
%import common.SIGNED_NUMBER -> NUMBER
%import common.WS
%ignore WS
"""
# with open("lark1.lark", "r") as f:
# _QUERY_GRAMMAR = f.read()
_ODATA_PARSER_VAR: typing.Final[contextvars.ContextVar[Lark]] = contextvars.ContextVar("odata_parser")
# The idea is that parser returns a function that can be used to filter a list of dictionaries
# So we ensure all returned functions have the same signature and can be composed together
# Note that value can receive function or final values, as it is composed of
# terminals and
_T_Result: typing.TypeAlias = collections.abc.Callable[[dict[str, typing.Any]], typing.Any]
_QUERY_PARSER_VAR: typing.Final[contextvars.ContextVar[lark.Lark]] = contextvars.ContextVar("query_parser")
_REMOVE_QUOTES_RE: typing.Final[typing.Pattern[str]] = re.compile(r"^(['\"])(.*)\1$")
_FUNCTIONS_PARAMS_NUM: dict[str, int] = {
'substringof': 2,
'startswith': 2,
'endswith': 2,
'indexof': 2,
'concat': 2,
'tolower': 1,
'toupper': 1,
'length': 1,
'year': 1,
'month': 1,
'day': 1,
}
# 🧠 Transformer: convert the tree into Python functions
class ODataTransformer(Transformer[typing.Any, typing.Any]):
def value(self, token: list[Token]) -> typing.Any:
val = token[0]
if val.type == "ESCAPED_STRING":
raw = val.value
if raw.startswith("'") and raw.endswith("'"):
return raw[1:-1]
elif raw.startswith('"') and raw.endswith('"'):
return raw[1:-1]
else:
raise ValueError(f"Formato de cadena no válido: {raw}")
elif val.type == "SIGNED_NUMBER":
return float(val.value) if '.' in val.value else int(val.value)
class QueryTransformer(lark.Transformer[typing.Any, _T_Result]):
@lark.visitors.v_args(inline=True) # pyright: ignore
def value(self, arg: lark.Token | str | int | float) -> _T_Result:
value: typing.Any = arg
if isinstance(arg, lark.Token):
match arg.type:
case 'ESCAPED_STRING':
match = _REMOVE_QUOTES_RE.match(arg.value)
if not match:
return arg.value
value = match.group(2)
case 'NUMBER':
value = float(arg.value) if '.' in arg.value else int(arg.value)
case 'BOOLEAN':
value = typing.cast(str, arg.value).lower() == 'true'
case _:
raise ValueError(f"Unexpected token type: {arg.type}")
elif isinstance(arg, typing.Callable):
return lambda obj: typing.cast(_T_Result, arg)(obj)
def field(self, token: list[Token]) -> collections.abc.Callable[[dict[str, typing.Any]], typing.Any]:
field_name = token[0].value
return lambda item: item.get(field_name)
return lambda _obj: value
def binary_op(self, items: list[typing.Any]) -> collections.abc.Callable[[dict[str, typing.Any]], bool]:
left, op_token, right = items
@lark.visitors.v_args(inline=True)
def true(self) -> _T_Result:
return lambda obj: True
op = op_token.value if isinstance(op_token, Token) else op_token
@lark.visitors.v_args(inline=True)
def false(self) -> _T_Result:
return lambda obj: False
def resolve(expr: typing.Any) -> typing.Callable[[dict[str, typing.Any]], typing.Any]:
if callable(expr):
return expr
else:
return lambda _: expr # fixed value
@lark.visitors.v_args(inline=True)
def field(self, arg: lark.Token) -> _T_Result:
def getter(obj: dict[str, typing.Any]) -> typing.Any:
for part in arg.value.split('.'):
obj = obj.get(part, {})
return obj
left_fn = resolve(left)
right_fn = resolve(right)
return getter
@lark.visitors.v_args(inline=True)
def binary_expr(self, left: _T_Result, op: typing.Any, right: _T_Result) -> _T_Result:
def _compare(left: str | int | float, right: str | int | float) -> int:
if type(left) != type(right):
# Convert both to string and compare
left = str(left)
right = str(right)
# 0 -> are equal
# <0 -> left is less than right
# >0 -> left is greater than right
if typing.cast(typing.Any, left) < typing.cast(typing.Any, right):
return -1
elif typing.cast(typing.Any, left) > typing.cast(typing.Any, right):
return 1
return 0
match op:
case "eq":
return lambda item: left_fn(item) == right_fn(item)
return lambda item: _compare(left(item), right(item)) == 0
case "gt":
return lambda item: left_fn(item) > right_fn(item)
return lambda item: _compare(left(item), right(item)) > 0
case "lt":
return lambda item: left_fn(item) < right_fn(item)
return lambda item: _compare(left(item), right(item)) < 0
case "ne":
return lambda item: left_fn(item) != right_fn(item)
return lambda item: _compare(left(item), right(item)) != 0
case "ge":
return lambda item: left_fn(item) >= right_fn(item)
return lambda item: _compare(left(item), right(item)) >= 0
case "le":
return lambda item: left_fn(item) <= right_fn(item)
return lambda item: _compare(left(item), right(item)) <= 0
case _:
raise ValueError(f"Operador desconocido: {op}")
raise ValueError(f"Unknown operator: {op}")
def func_op(self, items: list[typing.Any]) -> collections.abc.Callable[[dict[str, typing.Any]], typing.Any]:
func_token, field_fn, value = items
func = func_token.value if isinstance(func_token, Token) else func_token
match func:
case "startswith":
return lambda item: str(field_fn(item)).startswith(value)
case "endswith":
return lambda item: str(field_fn(item)).endswith(value)
case "contains":
# TODO: allow dicts and lists?
return lambda item: str(value) in str(field_fn(item))
case _:
raise ValueError(f"Función desconocida: {func}")
def value_func(
self, items: list[typing.Any]
) -> collections.abc.Callable[[dict[str, typing.Any]], typing.Any]:
func_token, field_fn = items
func = func_token.value if isinstance(func_token, Token) else func_token
match func:
case "length":
return lambda item: len(str(field_fn(item)))
case "tolower":
return lambda item: str(field_fn(item)).lower()
case "toupper":
return lambda item: str(field_fn(item)).upper()
case "trim":
return lambda item: str(field_fn(item)).strip()
case _:
raise ValueError(f"Value disconnected: {func}")
def and_expr(
self, items: list[collections.abc.Callable[..., typing.Any]]
) -> collections.abc.Callable[[dict[str, typing.Any]], bool]:
left, right = items
@lark.visitors.v_args(inline=True)
def logical_and(self, left: _T_Result, right: _T_Result) -> _T_Result:
return lambda item: left(item) and right(item)
def or_expr(
self, items: list[collections.abc.Callable[..., typing.Any]]
) -> collections.abc.Callable[[dict[str, typing.Any]], bool]:
left, right = items
@lark.visitors.v_args(inline=True)
def logical_or(self, left: _T_Result, right: _T_Result) -> _T_Result:
return lambda item: left(item) or right(item)
def not_expr(
self, items: list[collections.abc.Callable[..., typing.Any]]
) -> collections.abc.Callable[[dict[str, typing.Any]], bool]:
expr = items[0]
@lark.visitors.v_args(inline=True)
def unary_not(self, expr: _T_Result) -> _T_Result:
return lambda item: not expr(item)
def expr(
self, items: list[collections.abc.Callable[[dict[str, typing.Any]], bool]]
) -> collections.abc.Callable[[dict[str, typing.Any]], bool]:
return items[0]
@lark.visitors.v_args(inline=True)
def paren_expr(self, expr: _T_Result) -> _T_Result:
return expr
def paren_expr(
self, items: list[collections.abc.Callable[[dict[str, typing.Any]], bool]]
) -> collections.abc.Callable[[dict[str, typing.Any]], bool]:
return items[0]
@lark.visitors.v_args(inline=True)
def func_call(self, func: lark.Token, *args: _T_Result) -> _T_Result:
func_name = func.value.lower()
# If unknown function, raise an error
if func_name not in _FUNCTIONS_PARAMS_NUM:
raise ValueError(f"Unknown function: {func.value}")
if len(args) != _FUNCTIONS_PARAMS_NUM[func_name]:
raise ValueError(
f"{func_name} function requires exactly {_FUNCTIONS_PARAMS_NUM[func_name]} arguments"
)
match func_name:
case 'substringof':
return lambda obj: str(args[1](obj)).find(str(args[0](obj))) != -1
case 'startswith':
return lambda obj: str(args[0](obj)).startswith(str(args[1](obj)))
case 'endswith':
return lambda obj: str(args[0](obj)).endswith(str(args[1](obj)))
case 'indexof':
return lambda obj: str(args[0](obj)).find(str(args[1](obj)))
case 'concat':
return lambda obj: str(args[0](obj)) + str(args[1](obj))
case 'length':
return lambda obj: len(str(args[0](obj)))
case 'tolower':
return lambda obj: str(args[0](obj)).lower()
case 'toupper':
return lambda obj: str(args[0](obj)).upper()
case 'year':
return lambda obj: str(args[0](obj)).split('-')[0] if isinstance(args[0](obj), str) else ''
case 'month':
return lambda obj: str(args[0](obj)).split('-')[1] if isinstance(args[0](obj), str) else ''
case 'day':
return lambda obj: str(args[0](obj)).split('-')[2] if isinstance(args[0](obj), str) else ''
case _:
# Will never reach this, as it has been already
raise ValueError(f"Unknown function: {func.value}")
# _odata_parser: typing.Final[Lark] = Lark(_ODATA_GRAMMAR, parser="lalr", transformer=ODataTransformer())
def get_parser() -> Lark:
def get_parser() -> lark.Lark:
try:
return _ODATA_PARSER_VAR.get()
return _QUERY_PARSER_VAR.get()
except LookupError:
parser = Lark(_ODATA_GRAMMAR, parser="lalr", transformer=ODataTransformer())
_ODATA_PARSER_VAR.set(parser)
parser = lark.Lark(_QUERY_GRAMMAR, parser="lalr", transformer=QueryTransformer())
_QUERY_PARSER_VAR.set(parser)
return parser
# filter
def exec_filter(data: list[dict[str, typing.Any]], query: str) -> typing.Iterable[dict[str, typing.Any]]:
try:
filter_func = typing.cast(
collections.abc.Callable[[dict[str, typing.Any]], bool], get_parser().parse(query)
)
filter_func = typing.cast(_T_Result, get_parser().parse(query))
return filter(filter_func, data)
except Exception as e:
raise ValueError(f"Error al procesar la query OData: {e}")
raise ValueError(f"Error processing query: {e}") from None

View File

@@ -53,10 +53,10 @@ class TestQueryFilter(unittest.TestCase):
def test_grouped_expression_with_parentheses(self):
query = "not (age gt 30 or name eq 'Bob')"
result = list(exec_filter(self.data, query))
# Esperamos solo a Alice y David, porque:
# - Charlie tiene age > 30 → excluido
# - Bob tiene name eq 'Bob' → excluido
# - Alice y David tienen age == 30 y name != 'Bob' → incluidos
# We expect:
# - Charlie has age > 30 → excluded
# - Bob has name eq 'Bob' → excluded
# - Alice and David have age == 30 and name != 'Bob' → included
expected = [
{"name": "Alice", "age": 30},
{"name": "David", "age": 30},
@@ -73,3 +73,80 @@ class TestQueryFilter(unittest.TestCase):
result = list(exec_filter(self.data, "length(name) eq 5"))
expected = [{"name": "Alice", "age": 30}, {"name": "David", "age": 30}]
self.assertEqual(result, expected)
def test_toupper_function(self):
result = list(exec_filter(self.data, "toupper(name) eq 'ALICE'"))
expected = [{"name": "Alice", "age": 30}]
self.assertEqual(result, expected)
def test_tolower_function(self):
result = list(exec_filter(self.data, "tolower(name) eq 'david'"))
expected = [{"name": "David", "age": 30}]
self.assertEqual(result, expected)
def test_concat_function(self):
data = [
{"first": "John", "last": "Doe"},
{"first": "Jane", "last": "Smith"},
]
result = list(exec_filter(data, "concat(first,last) eq 'JohnDoe'"))
expected = [{"first": "John", "last": "Doe"}]
self.assertEqual(result, expected)
def test_indexof_function(self):
result = list(exec_filter(self.data, "indexof(name,'a') ge 0"))
expected = [
{"name": "Charlie", "age": 35},
{"name": "David", "age": 30},
]
self.assertEqual(result, expected)
def test_substringof_function(self):
result = list(exec_filter(self.data, "substringof('li',name)"))
expected = [{"name": "Alice", "age": 30}, {"name": "Charlie", "age": 35}]
self.assertEqual(result, expected)
def test_year_function(self):
data = [
{"dob": "1990-05-12"},
{"dob": "1985-11-30"},
]
result = list(exec_filter(data, "year(dob) eq '1990'"))
expected = [{"dob": "1990-05-12"}]
self.assertEqual(result, expected)
def test_month_function(self):
data = [
{"dob": "1990-05-12"},
{"dob": "1985-11-30"},
]
result = list(exec_filter(data, "month(dob) eq '11'"))
expected = [{"dob": "1985-11-30"}]
self.assertEqual(result, expected)
def test_day_function(self):
data = [
{"dob": "1990-05-12"},
{"dob": "1985-11-30"},
]
result = list(exec_filter(data, "day(dob) eq '12'"))
expected = [{"dob": "1990-05-12"}]
self.assertEqual(result, expected)
def test_nested_field_access(self):
data = [
{"user": {"name": "Alice"}},
{"user": {"name": "Bob"}},
]
result = list(exec_filter(data, "user.name eq 'Bob'"))
expected = [{"user": {"name": "Bob"}}]
self.assertEqual(result, expected)
def test_not_with_parentheses(self):
result = list(exec_filter(self.data, "not (name eq 'Alice')"))
expected = [
{"name": "Bob", "age": 25},
{"name": "Charlie", "age": 35},
{"name": "David", "age": 30},
]
self.assertEqual(result, expected)