1
0
mirror of https://github.com/dkmstr/openuds.git synced 2025-12-11 20:32:27 +03:00

Refactor query filter grammar and update test cases: enhance query parsing logic, add support for new functions, and improve test descriptions for clarity.

This commit is contained in:
Adolfo Gómez García
2025-08-13 07:25:46 +02:00
parent 2454af7ec1
commit 355362956f
2 changed files with 248 additions and 131 deletions

View File

@@ -1,178 +1,218 @@
# pyright: reportUnknownMemberType=false
import typing import typing
import collections.abc import re
import contextvars import contextvars
import collections.abc
import logging
import lark
from lark import Lark, Transformer, Token logger = logging.getLogger(__name__)
# Lark grammar for the query filter language (a simplified, OData-like filter syntax).
#
# Precedence is encoded by rule nesting, from loosest to tightest binding:
#   "or"  ->  "and"  ->  "not"  ->  comparison
# A comparison is either a lone value (e.g. a boolean function call), two values
# joined by one of the OP keywords, or a parenthesised sub-expression.
# Field names (NAME) may be dotted paths such as "user.name".
# String literals may use single or double quotes; no escape sequences are defined.
_QUERY_GRAMMAR: typing.Final[str] = r"""?start: expr

?expr: or_expr

?or_expr: and_expr
    | or_expr "or" and_expr -> logical_or

?and_expr: not_expr
    | and_expr "and" not_expr -> logical_and

?not_expr: comparison
    | "not" not_expr -> unary_not

?comparison: value
    | value OP value -> binary_expr
    | "(" expr ")" -> paren_expr

value: field | ESCAPED_STRING | NUMBER | boolean | func_call

field: NAME

func_call: NAME "(" [ value ("," value)* ] ")"

boolean: "true" -> true
    | "false" -> false

OP: "eq" | "gt" | "lt" | "ne" | "ge" | "le"

ESCAPED_STRING: /'[^']*'/ | /"[^"]*"/

NAME: CNAME ("." CNAME)*

%import common.CNAME
%import common.SIGNED_NUMBER -> NUMBER
%import common.WS
%ignore WS
"""
# with open("lark1.lark", "r") as f:
# _QUERY_GRAMMAR = f.read()
_ODATA_PARSER_VAR: typing.Final[contextvars.ContextVar[Lark]] = contextvars.ContextVar("odata_parser") # The idea is that parser returns a function that can be used to filter a list of dictionaries
# So we ensure all returned functions have the same signature and can be composed together
# Note that the "value" rule can receive either functions or final values, as it is composed of
# both terminals (strings, numbers) and rules (field, boolean, func_call)
_T_Result: typing.TypeAlias = collections.abc.Callable[[dict[str, typing.Any]], typing.Any]
_QUERY_PARSER_VAR: typing.Final[contextvars.ContextVar[lark.Lark]] = contextvars.ContextVar("query_parser")
_REMOVE_QUOTES_RE: typing.Final[typing.Pattern[str]] = re.compile(r"^(['\"])(.*)\1$")
# Supported query functions, keyed by lower-case name, mapped to the exact
# number of arguments each one requires. func_call rejects any function name
# that is not listed here, and any call whose argument count differs.
_FUNCTIONS_PARAMS_NUM: dict[str, int] = {
    # Two-argument string functions
    'substringof': 2,
    'startswith': 2,
    'endswith': 2,
    'indexof': 2,
    'concat': 2,
    # One-argument string functions
    'tolower': 1,
    'toupper': 1,
    'length': 1,
    # One-argument date-part functions (operate on '-' separated strings)
    'year': 1,
    'month': 1,
    'day': 1,
}
# 🧠 Transformer: convert the tree into Python functions class QueryTransformer(lark.Transformer[typing.Any, _T_Result]):
class ODataTransformer(Transformer[typing.Any, typing.Any]): @lark.visitors.v_args(inline=True) # pyright: ignore
def value(self, token: list[Token]) -> typing.Any: def value(self, arg: lark.Token | str | int | float) -> _T_Result:
val = token[0] value: typing.Any = arg
if val.type == "ESCAPED_STRING": if isinstance(arg, lark.Token):
raw = val.value match arg.type:
if raw.startswith("'") and raw.endswith("'"): case 'ESCAPED_STRING':
return raw[1:-1] match = _REMOVE_QUOTES_RE.match(arg.value)
elif raw.startswith('"') and raw.endswith('"'): if not match:
return raw[1:-1] return arg.value
else: value = match.group(2)
raise ValueError(f"Formato de cadena no válido: {raw}") case 'NUMBER':
elif val.type == "SIGNED_NUMBER": value = float(arg.value) if '.' in arg.value else int(arg.value)
return float(val.value) if '.' in val.value else int(val.value) case 'BOOLEAN':
value = typing.cast(str, arg.value).lower() == 'true'
case _:
raise ValueError(f"Unexpected token type: {arg.type}")
elif isinstance(arg, typing.Callable):
return lambda obj: typing.cast(_T_Result, arg)(obj)
def field(self, token: list[Token]) -> collections.abc.Callable[[dict[str, typing.Any]], typing.Any]: return lambda _obj: value
field_name = token[0].value
return lambda item: item.get(field_name)
def binary_op(self, items: list[typing.Any]) -> collections.abc.Callable[[dict[str, typing.Any]], bool]: @lark.visitors.v_args(inline=True)
left, op_token, right = items def true(self) -> _T_Result:
return lambda obj: True
op = op_token.value if isinstance(op_token, Token) else op_token @lark.visitors.v_args(inline=True)
def false(self) -> _T_Result:
return lambda obj: False
def resolve(expr: typing.Any) -> typing.Callable[[dict[str, typing.Any]], typing.Any]: @lark.visitors.v_args(inline=True)
if callable(expr): def field(self, arg: lark.Token) -> _T_Result:
return expr def getter(obj: dict[str, typing.Any]) -> typing.Any:
else: for part in arg.value.split('.'):
return lambda _: expr # fixed value obj = obj.get(part, {})
return obj
left_fn = resolve(left) return getter
right_fn = resolve(right)
@lark.visitors.v_args(inline=True)
def binary_expr(self, left: _T_Result, op: typing.Any, right: _T_Result) -> _T_Result:
def _compare(left: str | int | float, right: str | int | float) -> int:
if type(left) != type(right):
# Convert both to string and compare
left = str(left)
right = str(right)
# 0 -> are equal
# <0 -> left is less than right
# >0 -> left is greater than right
if typing.cast(typing.Any, left) < typing.cast(typing.Any, right):
return -1
elif typing.cast(typing.Any, left) > typing.cast(typing.Any, right):
return 1
return 0
match op: match op:
case "eq": case "eq":
return lambda item: left_fn(item) == right_fn(item) return lambda item: _compare(left(item), right(item)) == 0
case "gt": case "gt":
return lambda item: left_fn(item) > right_fn(item) return lambda item: _compare(left(item), right(item)) > 0
case "lt": case "lt":
return lambda item: left_fn(item) < right_fn(item) return lambda item: _compare(left(item), right(item)) < 0
case "ne": case "ne":
return lambda item: left_fn(item) != right_fn(item) return lambda item: _compare(left(item), right(item)) != 0
case "ge": case "ge":
return lambda item: left_fn(item) >= right_fn(item) return lambda item: _compare(left(item), right(item)) >= 0
case "le": case "le":
return lambda item: left_fn(item) <= right_fn(item) return lambda item: _compare(left(item), right(item)) <= 0
case _: case _:
raise ValueError(f"Operador desconocido: {op}") raise ValueError(f"Unknown operator: {op}")
def func_op(self, items: list[typing.Any]) -> collections.abc.Callable[[dict[str, typing.Any]], typing.Any]: @lark.visitors.v_args(inline=True)
func_token, field_fn, value = items def logical_and(self, left: _T_Result, right: _T_Result) -> _T_Result:
func = func_token.value if isinstance(func_token, Token) else func_token
match func:
case "startswith":
return lambda item: str(field_fn(item)).startswith(value)
case "endswith":
return lambda item: str(field_fn(item)).endswith(value)
case "contains":
# TODO: allow dicts and lists?
return lambda item: str(value) in str(field_fn(item))
case _:
raise ValueError(f"Función desconocida: {func}")
def value_func(
self, items: list[typing.Any]
) -> collections.abc.Callable[[dict[str, typing.Any]], typing.Any]:
func_token, field_fn = items
func = func_token.value if isinstance(func_token, Token) else func_token
match func:
case "length":
return lambda item: len(str(field_fn(item)))
case "tolower":
return lambda item: str(field_fn(item)).lower()
case "toupper":
return lambda item: str(field_fn(item)).upper()
case "trim":
return lambda item: str(field_fn(item)).strip()
case _:
raise ValueError(f"Value disconnected: {func}")
def and_expr(
self, items: list[collections.abc.Callable[..., typing.Any]]
) -> collections.abc.Callable[[dict[str, typing.Any]], bool]:
left, right = items
return lambda item: left(item) and right(item) return lambda item: left(item) and right(item)
def or_expr( @lark.visitors.v_args(inline=True)
self, items: list[collections.abc.Callable[..., typing.Any]] def logical_or(self, left: _T_Result, right: _T_Result) -> _T_Result:
) -> collections.abc.Callable[[dict[str, typing.Any]], bool]:
left, right = items
return lambda item: left(item) or right(item) return lambda item: left(item) or right(item)
def not_expr( @lark.visitors.v_args(inline=True)
self, items: list[collections.abc.Callable[..., typing.Any]] def unary_not(self, expr: _T_Result) -> _T_Result:
) -> collections.abc.Callable[[dict[str, typing.Any]], bool]:
expr = items[0]
return lambda item: not expr(item) return lambda item: not expr(item)
def expr( @lark.visitors.v_args(inline=True)
self, items: list[collections.abc.Callable[[dict[str, typing.Any]], bool]] def paren_expr(self, expr: _T_Result) -> _T_Result:
) -> collections.abc.Callable[[dict[str, typing.Any]], bool]: return expr
return items[0]
def paren_expr( @lark.visitors.v_args(inline=True)
self, items: list[collections.abc.Callable[[dict[str, typing.Any]], bool]] def func_call(self, func: lark.Token, *args: _T_Result) -> _T_Result:
) -> collections.abc.Callable[[dict[str, typing.Any]], bool]: func_name = func.value.lower()
return items[0] # If unknown function, raise an error
if func_name not in _FUNCTIONS_PARAMS_NUM:
raise ValueError(f"Unknown function: {func.value}")
if len(args) != _FUNCTIONS_PARAMS_NUM[func_name]:
raise ValueError(
f"{func_name} function requires exactly {_FUNCTIONS_PARAMS_NUM[func_name]} arguments"
)
match func_name:
case 'substringof':
return lambda obj: str(args[1](obj)).find(str(args[0](obj))) != -1
case 'startswith':
return lambda obj: str(args[0](obj)).startswith(str(args[1](obj)))
case 'endswith':
return lambda obj: str(args[0](obj)).endswith(str(args[1](obj)))
case 'indexof':
return lambda obj: str(args[0](obj)).find(str(args[1](obj)))
case 'concat':
return lambda obj: str(args[0](obj)) + str(args[1](obj))
case 'length':
return lambda obj: len(str(args[0](obj)))
case 'tolower':
return lambda obj: str(args[0](obj)).lower()
case 'toupper':
return lambda obj: str(args[0](obj)).upper()
case 'year':
return lambda obj: str(args[0](obj)).split('-')[0] if isinstance(args[0](obj), str) else ''
case 'month':
return lambda obj: str(args[0](obj)).split('-')[1] if isinstance(args[0](obj), str) else ''
case 'day':
return lambda obj: str(args[0](obj)).split('-')[2] if isinstance(args[0](obj), str) else ''
case _:
# Will never reach this, as it has been already
raise ValueError(f"Unknown function: {func.value}")
def get_parser() -> lark.Lark:
    """Return the query parser stored in the current context, creating it on first use.

    The parser is built with the LALR algorithm and an embedded
    QueryTransformer, so ``parse()`` directly returns the filter callable.
    """
    cached = _QUERY_PARSER_VAR.get(None)
    if cached is not None:
        return cached

    # Nothing stored in this context yet: build the parser and remember it
    created = lark.Lark(_QUERY_GRAMMAR, parser="lalr", transformer=QueryTransformer())
    _QUERY_PARSER_VAR.set(created)
    return created
# filter
def exec_filter(data: list[dict[str, typing.Any]], query: str) -> typing.Iterable[dict[str, typing.Any]]:
    """Lazily yield the items of ``data`` that satisfy ``query``.

    Args:
        data: Records (dictionaries) to be filtered.
        query: Filter expression, e.g. ``"age gt 30 and name ne 'Bob'"``.

    Returns:
        A lazy iterable; the predicate is evaluated as the result is consumed.

    Raises:
        ValueError: if the query cannot be parsed or transformed. The original
            exception is included in the message but not chained.
    """
    try:
        predicate = typing.cast(_T_Result, get_parser().parse(query))
        matches = filter(predicate, data)
    except Exception as exc:
        raise ValueError(f"Error processing query: {exc}") from None
    return matches

View File

@@ -53,10 +53,10 @@ class TestQueryFilter(unittest.TestCase):
def test_grouped_expression_with_parentheses(self): def test_grouped_expression_with_parentheses(self):
query = "not (age gt 30 or name eq 'Bob')" query = "not (age gt 30 or name eq 'Bob')"
result = list(exec_filter(self.data, query)) result = list(exec_filter(self.data, query))
# Esperamos solo a Alice y David, porque: # We expect:
# - Charlie tiene age > 30 → excluido # - Charlie has age > 30 → excluded
# - Bob tiene name eq 'Bob' → excluido # - Bob has name eq 'Bob' → excluded
# - Alice y David tienen age == 30 y name != 'Bob' → incluidos # - Alice and David have age == 30 and name != 'Bob' → included
expected = [ expected = [
{"name": "Alice", "age": 30}, {"name": "Alice", "age": 30},
{"name": "David", "age": 30}, {"name": "David", "age": 30},
@@ -73,3 +73,80 @@ class TestQueryFilter(unittest.TestCase):
result = list(exec_filter(self.data, "length(name) eq 5")) result = list(exec_filter(self.data, "length(name) eq 5"))
expected = [{"name": "Alice", "age": 30}, {"name": "David", "age": 30}] expected = [{"name": "Alice", "age": 30}, {"name": "David", "age": 30}]
self.assertEqual(result, expected) self.assertEqual(result, expected)
def test_toupper_function(self):
    # toupper() is applied to the field before comparing with the literal
    query = "toupper(name) eq 'ALICE'"
    matches = list(exec_filter(self.data, query))
    self.assertEqual(matches, [{"name": "Alice", "age": 30}])
def test_tolower_function(self):
    # tolower() is applied to the field before comparing with the literal
    query = "tolower(name) eq 'david'"
    matches = list(exec_filter(self.data, query))
    self.assertEqual(matches, [{"name": "David", "age": 30}])
def test_concat_function(self):
    # concat() joins the two field values with no separator
    john = {"first": "John", "last": "Doe"}
    jane = {"first": "Jane", "last": "Smith"}
    matches = list(exec_filter([john, jane], "concat(first,last) eq 'JohnDoe'"))
    self.assertEqual(matches, [{"first": "John", "last": "Doe"}])
def test_indexof_function(self):
    # Records for which indexof(name,'a') is non-negative are expected back
    query = "indexof(name,'a') ge 0"
    charlie = {"name": "Charlie", "age": 35}
    david = {"name": "David", "age": 30}
    matches = list(exec_filter(self.data, query))
    self.assertEqual(matches, [charlie, david])
def test_substringof_function(self):
    # substringof(needle, haystack) used directly as a boolean expression
    query = "substringof('li',name)"
    alice = {"name": "Alice", "age": 30}
    charlie = {"name": "Charlie", "age": 35}
    matches = list(exec_filter(self.data, query))
    self.assertEqual(matches, [alice, charlie])
def test_year_function(self):
    # year() yields the first '-' separated component, compared here as a string
    records = [{"dob": "1990-05-12"}, {"dob": "1985-11-30"}]
    matches = list(exec_filter(records, "year(dob) eq '1990'"))
    self.assertEqual(matches, [{"dob": "1990-05-12"}])
def test_month_function(self):
    # month() yields the second '-' separated component, compared here as a string
    records = [{"dob": "1990-05-12"}, {"dob": "1985-11-30"}]
    matches = list(exec_filter(records, "month(dob) eq '11'"))
    self.assertEqual(matches, [{"dob": "1985-11-30"}])
def test_day_function(self):
    # day() yields the third '-' separated component, compared here as a string
    records = [{"dob": "1990-05-12"}, {"dob": "1985-11-30"}]
    matches = list(exec_filter(records, "day(dob) eq '12'"))
    self.assertEqual(matches, [{"dob": "1990-05-12"}])
def test_nested_field_access(self):
    # A dotted field name walks into nested dictionaries
    records = [
        {"user": {"name": person}}
        for person in ("Alice", "Bob")
    ]
    matches = list(exec_filter(records, "user.name eq 'Bob'"))
    self.assertEqual(matches, [{"user": {"name": "Bob"}}])
def test_not_with_parentheses(self):
    # Negating the parenthesised comparison keeps everybody except Alice
    query = "not (name eq 'Alice')"
    everyone_but_alice = [
        {"name": "Bob", "age": 25},
        {"name": "Charlie", "age": 35},
        {"name": "David", "age": 30},
    ]
    matches = list(exec_filter(self.data, query))
    self.assertEqual(matches, everyone_but_alice)