mirror of
https://github.com/dkmstr/openuds.git
synced 2025-10-07 15:33:51 +03:00
Refactor query filter grammar and update test cases: enhance query parsing logic, add support for new functions, and improve test descriptions for clarity.
This commit is contained in:
@@ -1,178 +1,218 @@
|
||||
# pyright: reportUnknownMemberType=false
|
||||
import typing
|
||||
import collections.abc
|
||||
import re
|
||||
import contextvars
|
||||
import collections.abc
|
||||
import logging
|
||||
|
||||
import lark
|
||||
|
||||
|
||||
from lark import Lark, Transformer, Token
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Simplified Odata grammar
|
||||
_ODATA_GRAMMAR: typing.Final[
|
||||
_QUERY_GRAMMAR: typing.Final[
|
||||
str
|
||||
] = r"""
|
||||
?start: expr
|
||||
] = r"""?start: expr
|
||||
|
||||
?expr: "not" expr -> not_expr
|
||||
| expr "and" expr -> and_expr
|
||||
| expr "or" expr -> or_expr
|
||||
| "(" expr ")" -> paren_expr
|
||||
| comparison
|
||||
?expr: or_expr
|
||||
|
||||
?func_name: CNAME
|
||||
?or_expr: and_expr
|
||||
| or_expr "or" and_expr -> logical_or
|
||||
|
||||
?and_expr: not_expr
|
||||
| and_expr "and" not_expr -> logical_and
|
||||
|
||||
?not_expr: comparison
|
||||
| "not" not_expr -> unary_not
|
||||
|
||||
?comparison: value
|
||||
| value OP value -> binary_expr
|
||||
| "(" expr ")" -> paren_expr
|
||||
|
||||
value: field | ESCAPED_STRING | NUMBER | boolean | func_call
|
||||
|
||||
field: NAME
|
||||
|
||||
func_call: NAME "(" [ value ("," value)* ] ")"
|
||||
|
||||
boolean: "true" -> true
|
||||
| "false" -> false
|
||||
|
||||
OP: "eq" | "gt" | "lt" | "ne" | "ge" | "le"
|
||||
|
||||
?comparison: operand OP operand -> binary_op
|
||||
| func_name "(" field "," value ")" -> func_op
|
||||
|
||||
?operand: value_expr
|
||||
| value
|
||||
|
||||
?value_expr: field
|
||||
| func_name "(" field ")" -> value_func
|
||||
|
||||
field: CNAME
|
||||
value: ESCAPED_STRING | SIGNED_NUMBER
|
||||
|
||||
ESCAPED_STRING: /'[^']*'/ | /"[^"]*"/
|
||||
NAME: CNAME ("." CNAME)*
|
||||
|
||||
%import common.CNAME
|
||||
%import common.SIGNED_NUMBER
|
||||
%import common.SIGNED_NUMBER -> NUMBER
|
||||
%import common.WS
|
||||
%ignore WS
|
||||
"""
|
||||
# with open("lark1.lark", "r") as f:
|
||||
# _QUERY_GRAMMAR = f.read()
|
||||
|
||||
_ODATA_PARSER_VAR: typing.Final[contextvars.ContextVar[Lark]] = contextvars.ContextVar("odata_parser")
|
||||
# The idea is that parser returns a function that can be used to filter a list of dictionaries
|
||||
# So we ensure all returned functions have the same signature and can be composed together
|
||||
# Note that value can receive function or final values, as it is composed of
|
||||
# terminals and the results of already-transformed rules (field, boolean, func_call)
|
||||
_T_Result: typing.TypeAlias = collections.abc.Callable[[dict[str, typing.Any]], typing.Any]
|
||||
|
||||
_QUERY_PARSER_VAR: typing.Final[contextvars.ContextVar[lark.Lark]] = contextvars.ContextVar("query_parser")
|
||||
|
||||
_REMOVE_QUOTES_RE: typing.Final[typing.Pattern[str]] = re.compile(r"^(['\"])(.*)\1$")
|
||||
|
||||
_FUNCTIONS_PARAMS_NUM: dict[str, int] = {
|
||||
'substringof': 2,
|
||||
'startswith': 2,
|
||||
'endswith': 2,
|
||||
'indexof': 2,
|
||||
'concat': 2,
|
||||
'tolower': 1,
|
||||
'toupper': 1,
|
||||
'length': 1,
|
||||
'year': 1,
|
||||
'month': 1,
|
||||
'day': 1,
|
||||
}
|
||||
|
||||
|
||||
# 🧠 Transformer: convert the tree into Python functions
|
||||
class ODataTransformer(Transformer[typing.Any, typing.Any]):
|
||||
def value(self, token: list[Token]) -> typing.Any:
|
||||
val = token[0]
|
||||
if val.type == "ESCAPED_STRING":
|
||||
raw = val.value
|
||||
if raw.startswith("'") and raw.endswith("'"):
|
||||
return raw[1:-1]
|
||||
elif raw.startswith('"') and raw.endswith('"'):
|
||||
return raw[1:-1]
|
||||
else:
|
||||
raise ValueError(f"Formato de cadena no válido: {raw}")
|
||||
elif val.type == "SIGNED_NUMBER":
|
||||
return float(val.value) if '.' in val.value else int(val.value)
|
||||
class QueryTransformer(lark.Transformer[typing.Any, _T_Result]):
|
||||
@lark.visitors.v_args(inline=True) # pyright: ignore
|
||||
def value(self, arg: lark.Token | str | int | float) -> _T_Result:
|
||||
value: typing.Any = arg
|
||||
if isinstance(arg, lark.Token):
|
||||
match arg.type:
|
||||
case 'ESCAPED_STRING':
|
||||
match = _REMOVE_QUOTES_RE.match(arg.value)
|
||||
if not match:
|
||||
return arg.value
|
||||
value = match.group(2)
|
||||
case 'NUMBER':
|
||||
value = float(arg.value) if '.' in arg.value else int(arg.value)
|
||||
case 'BOOLEAN':
|
||||
value = typing.cast(str, arg.value).lower() == 'true'
|
||||
case _:
|
||||
raise ValueError(f"Unexpected token type: {arg.type}")
|
||||
elif isinstance(arg, typing.Callable):
|
||||
return lambda obj: typing.cast(_T_Result, arg)(obj)
|
||||
|
||||
def field(self, token: list[Token]) -> collections.abc.Callable[[dict[str, typing.Any]], typing.Any]:
|
||||
field_name = token[0].value
|
||||
return lambda item: item.get(field_name)
|
||||
return lambda _obj: value
|
||||
|
||||
def binary_op(self, items: list[typing.Any]) -> collections.abc.Callable[[dict[str, typing.Any]], bool]:
|
||||
left, op_token, right = items
|
||||
@lark.visitors.v_args(inline=True)
def true(self) -> _T_Result:
    """Handle the ``true`` literal: the evaluator ignores the record and yields True."""

    def always_true(obj: dict[str, typing.Any]) -> bool:
        return True

    return always_true
|
||||
|
||||
op = op_token.value if isinstance(op_token, Token) else op_token
|
||||
@lark.visitors.v_args(inline=True)
def false(self) -> _T_Result:
    """Handle the ``false`` literal: the evaluator ignores the record and yields False."""

    def always_false(obj: dict[str, typing.Any]) -> bool:
        return False

    return always_false
|
||||
|
||||
def resolve(expr: typing.Any) -> typing.Callable[[dict[str, typing.Any]], typing.Any]:
|
||||
if callable(expr):
|
||||
return expr
|
||||
else:
|
||||
return lambda _: expr # fixed value
|
||||
@lark.visitors.v_args(inline=True)
|
||||
def field(self, arg: lark.Token) -> _T_Result:
|
||||
def getter(obj: dict[str, typing.Any]) -> typing.Any:
|
||||
for part in arg.value.split('.'):
|
||||
obj = obj.get(part, {})
|
||||
return obj
|
||||
|
||||
left_fn = resolve(left)
|
||||
right_fn = resolve(right)
|
||||
return getter
|
||||
|
||||
@lark.visitors.v_args(inline=True)
|
||||
def binary_expr(self, left: _T_Result, op: typing.Any, right: _T_Result) -> _T_Result:
|
||||
def _compare(left: str | int | float, right: str | int | float) -> int:
|
||||
if type(left) != type(right):
|
||||
# Convert both to string and compare
|
||||
left = str(left)
|
||||
right = str(right)
|
||||
# 0 -> are equal
|
||||
# <0 -> left is less than right
|
||||
# >0 -> left is greater than right
|
||||
if typing.cast(typing.Any, left) < typing.cast(typing.Any, right):
|
||||
return -1
|
||||
elif typing.cast(typing.Any, left) > typing.cast(typing.Any, right):
|
||||
return 1
|
||||
return 0
|
||||
|
||||
match op:
|
||||
case "eq":
|
||||
return lambda item: left_fn(item) == right_fn(item)
|
||||
return lambda item: _compare(left(item), right(item)) == 0
|
||||
case "gt":
|
||||
return lambda item: left_fn(item) > right_fn(item)
|
||||
return lambda item: _compare(left(item), right(item)) > 0
|
||||
case "lt":
|
||||
return lambda item: left_fn(item) < right_fn(item)
|
||||
return lambda item: _compare(left(item), right(item)) < 0
|
||||
case "ne":
|
||||
return lambda item: left_fn(item) != right_fn(item)
|
||||
return lambda item: _compare(left(item), right(item)) != 0
|
||||
case "ge":
|
||||
return lambda item: left_fn(item) >= right_fn(item)
|
||||
return lambda item: _compare(left(item), right(item)) >= 0
|
||||
case "le":
|
||||
return lambda item: left_fn(item) <= right_fn(item)
|
||||
return lambda item: _compare(left(item), right(item)) <= 0
|
||||
case _:
|
||||
raise ValueError(f"Operador desconocido: {op}")
|
||||
raise ValueError(f"Unknown operator: {op}")
|
||||
|
||||
def func_op(self, items: list[typing.Any]) -> collections.abc.Callable[[dict[str, typing.Any]], typing.Any]:
|
||||
func_token, field_fn, value = items
|
||||
|
||||
func = func_token.value if isinstance(func_token, Token) else func_token
|
||||
|
||||
match func:
|
||||
case "startswith":
|
||||
return lambda item: str(field_fn(item)).startswith(value)
|
||||
case "endswith":
|
||||
return lambda item: str(field_fn(item)).endswith(value)
|
||||
case "contains":
|
||||
# TODO: allow dicts and lists?
|
||||
return lambda item: str(value) in str(field_fn(item))
|
||||
case _:
|
||||
raise ValueError(f"Función desconocida: {func}")
|
||||
|
||||
def value_func(
|
||||
self, items: list[typing.Any]
|
||||
) -> collections.abc.Callable[[dict[str, typing.Any]], typing.Any]:
|
||||
func_token, field_fn = items
|
||||
func = func_token.value if isinstance(func_token, Token) else func_token
|
||||
|
||||
match func:
|
||||
case "length":
|
||||
return lambda item: len(str(field_fn(item)))
|
||||
case "tolower":
|
||||
return lambda item: str(field_fn(item)).lower()
|
||||
case "toupper":
|
||||
return lambda item: str(field_fn(item)).upper()
|
||||
case "trim":
|
||||
return lambda item: str(field_fn(item)).strip()
|
||||
case _:
|
||||
raise ValueError(f"Value disconnected: {func}")
|
||||
|
||||
def and_expr(
|
||||
self, items: list[collections.abc.Callable[..., typing.Any]]
|
||||
) -> collections.abc.Callable[[dict[str, typing.Any]], bool]:
|
||||
left, right = items
|
||||
@lark.visitors.v_args(inline=True)
def logical_and(self, left: _T_Result, right: _T_Result) -> _T_Result:
    """Combine two evaluators with short-circuit ``and`` semantics.

    The right evaluator only runs when the left result is truthy, and the
    returned value is the raw operand value (as Python's ``and`` does).
    """

    def evaluate(item: dict[str, typing.Any]) -> typing.Any:
        first = left(item)
        if not first:
            return first
        return right(item)

    return evaluate
|
||||
|
||||
def or_expr(
|
||||
self, items: list[collections.abc.Callable[..., typing.Any]]
|
||||
) -> collections.abc.Callable[[dict[str, typing.Any]], bool]:
|
||||
left, right = items
|
||||
@lark.visitors.v_args(inline=True)
def logical_or(self, left: _T_Result, right: _T_Result) -> _T_Result:
    """Combine two evaluators with short-circuit ``or`` semantics.

    The right evaluator only runs when the left result is falsy, and the
    returned value is the raw operand value (as Python's ``or`` does).
    """

    def evaluate(item: dict[str, typing.Any]) -> typing.Any:
        first = left(item)
        if first:
            return first
        return right(item)

    return evaluate
|
||||
|
||||
def not_expr(
|
||||
self, items: list[collections.abc.Callable[..., typing.Any]]
|
||||
) -> collections.abc.Callable[[dict[str, typing.Any]], bool]:
|
||||
expr = items[0]
|
||||
@lark.visitors.v_args(inline=True)
def unary_not(self, expr: _T_Result) -> _T_Result:
    """Negate an evaluator: the result is the boolean inverse of ``expr(item)``."""

    def negated(item: dict[str, typing.Any]) -> bool:
        outcome = expr(item)
        return not outcome

    return negated
|
||||
|
||||
def expr(
|
||||
self, items: list[collections.abc.Callable[[dict[str, typing.Any]], bool]]
|
||||
) -> collections.abc.Callable[[dict[str, typing.Any]], bool]:
|
||||
return items[0]
|
||||
@lark.visitors.v_args(inline=True)
def paren_expr(self, expr: _T_Result) -> _T_Result:
    """Handle a parenthesised expression.

    Parentheses only affect parsing, so the evaluator built for the inner
    expression is handed back unchanged.
    """
    inner = expr
    return inner
|
||||
|
||||
def paren_expr(
|
||||
self, items: list[collections.abc.Callable[[dict[str, typing.Any]], bool]]
|
||||
) -> collections.abc.Callable[[dict[str, typing.Any]], bool]:
|
||||
return items[0]
|
||||
@lark.visitors.v_args(inline=True)
|
||||
def func_call(self, func: lark.Token, *args: _T_Result) -> _T_Result:
|
||||
func_name = func.value.lower()
|
||||
# If unknown function, raise an error
|
||||
if func_name not in _FUNCTIONS_PARAMS_NUM:
|
||||
raise ValueError(f"Unknown function: {func.value}")
|
||||
|
||||
if len(args) != _FUNCTIONS_PARAMS_NUM[func_name]:
|
||||
raise ValueError(
|
||||
f"{func_name} function requires exactly {_FUNCTIONS_PARAMS_NUM[func_name]} arguments"
|
||||
)
|
||||
match func_name:
|
||||
case 'substringof':
|
||||
return lambda obj: str(args[1](obj)).find(str(args[0](obj))) != -1
|
||||
case 'startswith':
|
||||
return lambda obj: str(args[0](obj)).startswith(str(args[1](obj)))
|
||||
case 'endswith':
|
||||
return lambda obj: str(args[0](obj)).endswith(str(args[1](obj)))
|
||||
case 'indexof':
|
||||
return lambda obj: str(args[0](obj)).find(str(args[1](obj)))
|
||||
case 'concat':
|
||||
return lambda obj: str(args[0](obj)) + str(args[1](obj))
|
||||
case 'length':
|
||||
return lambda obj: len(str(args[0](obj)))
|
||||
case 'tolower':
|
||||
return lambda obj: str(args[0](obj)).lower()
|
||||
case 'toupper':
|
||||
return lambda obj: str(args[0](obj)).upper()
|
||||
case 'year':
|
||||
return lambda obj: str(args[0](obj)).split('-')[0] if isinstance(args[0](obj), str) else ''
|
||||
case 'month':
|
||||
return lambda obj: str(args[0](obj)).split('-')[1] if isinstance(args[0](obj), str) else ''
|
||||
case 'day':
|
||||
return lambda obj: str(args[0](obj)).split('-')[2] if isinstance(args[0](obj), str) else ''
|
||||
case _:
|
||||
# Will never reach this, as it has been already
|
||||
raise ValueError(f"Unknown function: {func.value}")
|
||||
|
||||
|
||||
# _odata_parser: typing.Final[Lark] = Lark(_ODATA_GRAMMAR, parser="lalr", transformer=ODataTransformer())
|
||||
|
||||
|
||||
def get_parser() -> Lark:
|
||||
def get_parser() -> lark.Lark:
|
||||
try:
|
||||
return _ODATA_PARSER_VAR.get()
|
||||
return _QUERY_PARSER_VAR.get()
|
||||
except LookupError:
|
||||
parser = Lark(_ODATA_GRAMMAR, parser="lalr", transformer=ODataTransformer())
|
||||
_ODATA_PARSER_VAR.set(parser)
|
||||
parser = lark.Lark(_QUERY_GRAMMAR, parser="lalr", transformer=QueryTransformer())
|
||||
_QUERY_PARSER_VAR.set(parser)
|
||||
return parser
|
||||
|
||||
|
||||
# filter
|
||||
def exec_filter(data: list[dict[str, typing.Any]], query: str) -> typing.Iterable[dict[str, typing.Any]]:
|
||||
try:
|
||||
filter_func = typing.cast(
|
||||
collections.abc.Callable[[dict[str, typing.Any]], bool], get_parser().parse(query)
|
||||
)
|
||||
filter_func = typing.cast(_T_Result, get_parser().parse(query))
|
||||
return filter(filter_func, data)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error al procesar la query OData: {e}")
|
||||
raise ValueError(f"Error processing query: {e}") from None
|
||||
|
@@ -53,10 +53,10 @@ class TestQueryFilter(unittest.TestCase):
|
||||
def test_grouped_expression_with_parentheses(self):
|
||||
query = "not (age gt 30 or name eq 'Bob')"
|
||||
result = list(exec_filter(self.data, query))
|
||||
# Esperamos solo a Alice y David, porque:
|
||||
# - Charlie tiene age > 30 → excluido
|
||||
# - Bob tiene name eq 'Bob' → excluido
|
||||
# - Alice y David tienen age == 30 y name != 'Bob' → incluidos
|
||||
# We expect:
|
||||
# - Charlie has age > 30 → excluded
|
||||
# - Bob has name eq 'Bob' → excluded
|
||||
# - Alice and David have age == 30 and name != 'Bob' → included
|
||||
expected = [
|
||||
{"name": "Alice", "age": 30},
|
||||
{"name": "David", "age": 30},
|
||||
@@ -73,3 +73,80 @@ class TestQueryFilter(unittest.TestCase):
|
||||
result = list(exec_filter(self.data, "length(name) eq 5"))
|
||||
expected = [{"name": "Alice", "age": 30}, {"name": "David", "age": 30}]
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_toupper_function(self):
    """toupper(name) compared with an upper-case literal selects only Alice."""
    matched = list(exec_filter(self.data, "toupper(name) eq 'ALICE'"))
    self.assertEqual(matched, [{"name": "Alice", "age": 30}])
|
||||
|
||||
def test_tolower_function(self):
    """tolower(name) compared with a lower-case literal selects only David."""
    matched = list(exec_filter(self.data, "tolower(name) eq 'david'"))
    self.assertEqual(matched, [{"name": "David", "age": 30}])
|
||||
|
||||
def test_concat_function(self):
    """concat(first,last) joins both fields before the comparison."""
    people = [
        {"first": "John", "last": "Doe"},
        {"first": "Jane", "last": "Smith"},
    ]
    matched = list(exec_filter(people, "concat(first,last) eq 'JohnDoe'"))
    self.assertEqual(matched, [{"first": "John", "last": "Doe"}])
|
||||
|
||||
def test_indexof_function(self):
    """indexof(name,'a') ge 0 keeps the records expected to contain a lower-case 'a'."""
    wanted = [
        {"name": "Charlie", "age": 35},
        {"name": "David", "age": 30},
    ]
    matched = list(exec_filter(self.data, "indexof(name,'a') ge 0"))
    self.assertEqual(matched, wanted)
|
||||
|
||||
def test_substringof_function(self):
    """substringof('li',name) used as a bare condition selects Alice and Charlie."""
    wanted = [{"name": "Alice", "age": 30}, {"name": "Charlie", "age": 35}]
    matched = list(exec_filter(self.data, "substringof('li',name)"))
    self.assertEqual(matched, wanted)
|
||||
|
||||
def test_year_function(self):
    """year(dob) is compared as a string against '1990'."""
    birthdays = [
        {"dob": "1990-05-12"},
        {"dob": "1985-11-30"},
    ]
    matched = list(exec_filter(birthdays, "year(dob) eq '1990'"))
    self.assertEqual(matched, [{"dob": "1990-05-12"}])
|
||||
|
||||
def test_month_function(self):
    """month(dob) is compared as a string against '11'."""
    birthdays = [
        {"dob": "1990-05-12"},
        {"dob": "1985-11-30"},
    ]
    matched = list(exec_filter(birthdays, "month(dob) eq '11'"))
    self.assertEqual(matched, [{"dob": "1985-11-30"}])
|
||||
|
||||
def test_day_function(self):
    """day(dob) is compared as a string against '12'."""
    birthdays = [
        {"dob": "1990-05-12"},
        {"dob": "1985-11-30"},
    ]
    matched = list(exec_filter(birthdays, "day(dob) eq '12'"))
    self.assertEqual(matched, [{"dob": "1990-05-12"}])
|
||||
|
||||
def test_nested_field_access(self):
    """A dotted field name (user.name) reaches into a nested dictionary."""
    records = [
        {"user": {"name": "Alice"}},
        {"user": {"name": "Bob"}},
    ]
    matched = list(exec_filter(records, "user.name eq 'Bob'"))
    self.assertEqual(matched, [{"user": {"name": "Bob"}}])
|
||||
|
||||
def test_not_with_parentheses(self):
    """not (...) inverts the parenthesised comparison, dropping only Alice."""
    wanted = [
        {"name": "Bob", "age": 25},
        {"name": "Charlie", "age": 35},
        {"name": "David", "age": 30},
    ]
    matched = list(exec_filter(self.data, "not (name eq 'Alice')"))
    self.assertEqual(matched, wanted)
|
||||
|
Reference in New Issue
Block a user