diff --git a/pyk/src/pyk/kast/outer.py b/pyk/src/pyk/kast/outer.py index 774a004e5c..563eec0c52 100644 --- a/pyk/src/pyk/kast/outer.py +++ b/pyk/src/pyk/kast/outer.py @@ -17,6 +17,7 @@ from .att import EMPTY_ATT, Atts, Format, KAst, KAtt, WithKAtt from .inner import ( KApply, + KAs, KInner, KLabel, KRewrite, @@ -1327,6 +1328,8 @@ def sort(self, kast: KInner) -> KSort | None: match kast: case KToken(_, sort) | KVariable(_, sort): return sort + case KAs(alias=KVariable(sort=sort)): + return sort case KRewrite(lhs, rhs): lhs_sort = self.sort(lhs) rhs_sort = self.sort(rhs) @@ -1336,8 +1339,11 @@ def sort(self, kast: KInner) -> KSort | None: case KSequence(_): return KSort('K') case KApply(label, _): - sort, _ = self.resolve_sorts(label) - return sort + try: + sort, _ = self.resolve_sorts(label) + return sort + except (KeyError, ValueError): + return None case _: return None @@ -1354,7 +1360,13 @@ def resolve_sorts(self, label: KLabel) -> tuple[KSort, tuple[KSort, ...]]: sorts = dict(zip(prod.params, label.params, strict=True)) def resolve(sort: KSort) -> KSort: - return sorts.get(sort, sort) + # Direct match: sort IS one of the sort parameters. + if sort in sorts: + return sorts[sort] + # Recursive substitution: sort params may appear nested (e.g. MInt{Width} → MInt{8}). + if sort.params: + return KSort(sort.name, tuple(resolve(p) for p in sort.params)) + return sort return resolve(prod.sort), tuple(resolve(sort) for sort in prod.argument_sorts) @@ -1483,28 +1495,98 @@ def transform( # Best-effort addition of sort parameters to klabels, context insensitive def add_sort_params(self, kast: KInner) -> KInner: """Return a given term with the sort parameters on the `KLabel` filled in (which may be missing because of how the frontend works), best effort.""" + # ML predicate labels whose result sort (Sort2) is context-dependent and not inferable + # from the arguments alone. When Sort1 can be determined but Sort2 cannot, we fill Sort2 + # with the sentinel KSort('#SortParam') so that downstream Kore emission can introduce a + # universally-quantified sort variable (Q0) in the axiom. + _ML_PRED_RESULT_SORT_PARAM = KSort('#SortParam') # noqa: N806 + _ML_PRED_LABELS = frozenset({'#Equals', '#Ceil', '#Floor', '#In'}) # noqa: N806 + + def _unify_sort_params(parametric: KSort, actual: KSort, params: frozenset[KSort]) -> dict[KSort, KSort]: + """Match parametric sort against actual, extracting bindings for sort params. + + Handles both direct (parametric IS a sort param) and nested + (parametric = MInt{Width}, actual = MInt{8}) cases. + Returns empty dict when no bindings could be extracted (no match). + """ + if parametric in params: + return {parametric: actual} + if parametric.name != actual.name or len(parametric.params) != len(actual.params): + return {} + result: dict[KSort, KSort] = {} + for p_sub, a_sub in zip(parametric.params, actual.params, strict=True): + sub_bindings = _unify_sort_params(p_sub, a_sub, params) + for k, v in sub_bindings.items(): + if k in result and result[k] != v: + return {} # Conflicting bindings + result[k] = v + return result + + def _merge_binding(sort_dict: dict[KSort, KSort], k: KSort, v: KSort) -> bool: + """Merge one binding into sort_dict in place. Returns False on irreconcilable conflict.""" + if k in sort_dict: + existing = sort_dict[k] + if existing == _ML_PRED_RESULT_SORT_PARAM: + sort_dict[k] = v # Concrete sort overrides sentinel. + elif existing != v: + lub = self.least_common_supersort(existing, v) + if lub is None: + _LOGGER.warning(f'Failed to add sort parameter, sort mismatch: {(k, existing, v)}') + return False + sort_dict[k] = lub + else: + sort_dict[k] = v + return True def _add_sort_params(_k: KInner) -> KInner: if type(_k) is KApply: prod = self.symbols[_k.label.name] if len(_k.label.params) == 0 and len(prod.params) > 0: + param_set = frozenset(prod.params) sort_dict: dict[KSort, KSort] = {} for psort, asort in zip(prod.argument_sorts, map(self.sort, _k.args), strict=True): + if asort == _ML_PRED_RESULT_SORT_PARAM: + # #SortParam is the sentinel for an ML pred result sort that cannot be + # inferred bottom-up (e.g. #Equals result sort depends on outer context). + # It propagates upward into ML connectives (#And, #Or, #Not) as a + # placeholder for the axiom sort variable Q0, but a concrete sort takes + # precedence when one is available. + bindings = _unify_sort_params(psort, asort, param_set) + for k, v in bindings.items(): + if k not in sort_dict: # sentinel fills only empty slots + sort_dict[k] = v + continue if asort is None: _LOGGER.warning( f'Failed to add sort parameter, unable to determine sort for argument in production: {(prod, psort, asort)}' ) return _k - if psort in prod.params: - if psort in sort_dict and sort_dict[psort] != asort: - _LOGGER.warning( - f'Failed to add sort parameter, sort mismatch between different occurances of sort parameter: {(prod, psort, sort_dict[psort], asort)}' - ) + # Unify psort with asort to extract bindings for sort params. + # Handles both direct (psort=Width) and nested (psort=MInt{Width}) cases. + bindings = _unify_sort_params(psort, asort, param_set) + for k, v in bindings.items(): + if not _merge_binding(sort_dict, k, v): return _k - elif psort not in sort_dict: - sort_dict[psort] = asort if all(p in sort_dict for p in prod.params): return _k.let(label=KLabel(_k.label.name, [sort_dict[p] for p in prod.params])) + # ML predicates have a context-dependent result sort (Sort2) that cannot be + # inferred from arguments. Fill it with the sentinel so that krule_to_kore can + # introduce a universally-quantified sort variable for the axiom. + if _k.label.name in _ML_PRED_LABELS: + unbound = [p for p in prod.params if p not in sort_dict] + # The single sentinel KSort('#SortParam') is only unambiguous when at most + # one parameter is unresolvable bottom-up. All current ML predicates + # (#Equals, #Ceil, #Floor, #In) have exactly two sort params {Sort1, + # Sort2}: Sort1 is always determined by the arguments, Sort2 (the result + # sort) is the one remaining unbound param. If this assertion ever fires, + # the sentinel scheme needs to be replaced with unique fresh params, as + # Java does with #SortParam{Q0}, #SortParam{Q1}, .... + assert len(unbound) <= 1, ( + f'Expected at most one unbound sort parameter for {_k.label.name!r}, ' + f'got {len(unbound)}: {unbound}' + ) + filled = {p: sort_dict.get(p, _ML_PRED_RESULT_SORT_PARAM) for p in prod.params} + return _k.let(label=KLabel(_k.label.name, [filled[p] for p in prod.params])) return _k return bottom_up(_add_sort_params, kast) @@ -1515,15 +1597,35 @@ def add_cell_map_items(self, kast: KInner) -> KInner: # syntax AccountCellMap [cellCollection, hook(MAP.Map)] # syntax AccountCellMap ::= AccountCellMap AccountCellMap [assoc, avoid, cellCollection, comm, element(AccountCellMapItem), function, hook(MAP.concat), unit(.AccountCellMap), wrapElement()] - cell_wrappers = {} + # Maps cell label -> (element_constructor, cell_map_sort). + # Wrapping is correct only when the parent production expects the cell MAP sort (e.g. + # EntryCellMap), not when it expects the individual cell element sort (e.g. EntryCell). + # For example, EntryCellMapKey((...)) takes EntryCell — the must NOT be + # wrapped, whereas _EntryCellMap_((...), ...) expects EntryCellMap — wrapping is needed. + cell_wrappers: dict[str, tuple[str, KSort]] = {} for ccp in self.cell_collection_productions: if Atts.ELEMENT in ccp.att and Atts.WRAP_ELEMENT in ccp.att: - cell_wrappers[ccp.att[Atts.WRAP_ELEMENT]] = ccp.att[Atts.ELEMENT] + cell_label = ccp.att[Atts.WRAP_ELEMENT] + element_ctor = ccp.att[Atts.ELEMENT] + if element_ctor in self.symbols: + cell_wrappers[cell_label] = (element_ctor, self.symbols[element_ctor].sort) def _wrap_elements(_k: KInner) -> KInner: - if type(_k) is KApply and _k.label.name in cell_wrappers: - return KApply(cell_wrappers[_k.label.name], [_k.args[0], _k]) - return _k + if not isinstance(_k, KApply) or _k.label.name not in self.symbols: + return _k + prod = self.symbols[_k.label.name] + arg_sorts = prod.argument_sorts + if not arg_sorts or len(arg_sorts) != _k.arity: + return _k + new_args: list[KInner] = list(_k.args) + changed = False + for i, (arg_sort, arg) in enumerate(zip(arg_sorts, _k.args, strict=True)): + if isinstance(arg, KApply) and arg.label.name in cell_wrappers: + element_ctor, cell_map_sort = cell_wrappers[arg.label.name] + if arg_sort == cell_map_sort: + new_args[i] = KApply(element_ctor, [arg.args[0], arg]) + changed = True + return _k.let(args=new_args) if changed else _k # To ensure we don't get duplicate wrappers. _kast = self.remove_cell_map_items(kast) diff --git a/pyk/src/tests/unit/kast/test_definition.py b/pyk/src/tests/unit/kast/test_definition.py new file mode 100644 index 0000000000..da397e3916 --- /dev/null +++ b/pyk/src/tests/unit/kast/test_definition.py @@ -0,0 +1,250 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +from pyk.kast.att import Atts, KAtt +from pyk.kast.inner import KApply, KAs, KLabel, KSequence, KSort, KToken, KVariable +from pyk.kast.outer import KDefinition, KFlatModule, KNonTerminal, KProduction, KTerminal + +if TYPE_CHECKING: + from typing import Final + + from pyk.kast.inner import KInner + + +# --------------------------------------------------------------------------- +# Minimal test definition +# +# bar: syntax N ::= bar(N) -- result sort is the param directly +# foo: syntax MInt{N} ::= foo(MInt{N}) -- result/arg sorts nest the param +# #Equals: syntax S2 ::= #Equals{S1,S2}(S1, S1) -- ML pred, result sort context-dependent +# +# Cell map fragment: +# AccountCellMap ::= AccountCellMap AccountCellMap [cellCollection, element(AccountCellMapItem), wrapElement()] +# AccountCellMap ::= AccountCellMapItem(Int, AccountCell) +# AccountCell ::= (Int, Int) +# AccountCell ::= getEntry(AccountCell) -- takes element sort, NOT map sort +# --------------------------------------------------------------------------- + +INT: Final = KSort('Int') +N: Final = KSort('N') +S1: Final = KSort('S1') +S2: Final = KSort('S2') +MINT_N: Final = KSort('MInt', (N,)) +MINT_INT: Final = KSort('MInt', (INT,)) +SORT_PARAM: Final = KSort('#SortParam') +ACCOUNT_CELL_MAP: Final = KSort('AccountCellMap') +ACCOUNT_CELL: Final = KSort('AccountCell') + +_BAR_PROD: Final = KProduction( + sort=N, + items=[KTerminal('bar'), KTerminal('('), KNonTerminal(N), KTerminal(')')], + params=[N], + klabel='bar', +) + +_FOO_PROD: Final = KProduction( + sort=MINT_N, + items=[KTerminal('foo'), KTerminal('('), KNonTerminal(MINT_N), KTerminal(')')], + params=[N], + klabel='foo', +) + +_EQUALS_PROD: Final = KProduction( + sort=S2, + items=[KNonTerminal(S1), KNonTerminal(S1)], + params=[S1, S2], + klabel='#Equals', +) + +_ACCT_MAP_CONCAT: Final = KProduction( + sort=ACCOUNT_CELL_MAP, + items=[KNonTerminal(ACCOUNT_CELL_MAP), KNonTerminal(ACCOUNT_CELL_MAP)], + klabel='_AccountCellMap_', + att=KAtt(entries=[Atts.CELL_COLLECTION(None), Atts.ELEMENT('AccountCellMapItem'), Atts.WRAP_ELEMENT('')]), +) + +_ACCT_MAP_ITEM: Final = KProduction( + sort=ACCOUNT_CELL_MAP, + items=[ + KTerminal('AccountCellMapItem'), + KTerminal('('), + KNonTerminal(INT), + KTerminal(','), + KNonTerminal(ACCOUNT_CELL), + KTerminal(')'), + ], + klabel='AccountCellMapItem', +) + +_ACCOUNT_CELL: Final = KProduction( + sort=ACCOUNT_CELL, + items=[ + KTerminal(''), + KTerminal('('), + KNonTerminal(INT), + KTerminal(','), + KNonTerminal(INT), + KTerminal(')'), + ], + klabel='', +) + +_GET_ENTRY: Final = KProduction( + sort=ACCOUNT_CELL, + items=[KTerminal('getEntry'), KTerminal('('), KNonTerminal(ACCOUNT_CELL), KTerminal(')')], + klabel='getEntry', +) + +DEFN: Final = KDefinition( + 'TEST', + [ + KFlatModule( + 'TEST', [_BAR_PROD, _FOO_PROD, _EQUALS_PROD, _ACCT_MAP_CONCAT, _ACCT_MAP_ITEM, _ACCOUNT_CELL, _GET_ENTRY] + ) + ], +) + + +# --------------------------------------------------------------------------- +# KDefinition.sort +# --------------------------------------------------------------------------- + +SORT_DATA: Final = ( + # Basic leaf terms + ('ktoken', KToken('42', INT), INT), + ('kvariable_with_sort', KVariable('X', sort=INT), INT), + ('ksequence', KSequence([]), KSort('K')), + # KApply: result sort substituted directly from param + ('kapply_direct_result', KApply(KLabel('bar', [INT]), [KVariable('X', sort=INT)]), INT), + # KApply: result sort nests the param (MInt{N} with N→Int → MInt{Int}) + ('kapply_nested_result', KApply(KLabel('foo', [INT]), [KVariable('X', sort=MINT_INT)]), MINT_INT), + # KApply with unfilled sort params: sort() returns None rather than raising + ('kapply_unfilled_params', KApply(KLabel('foo'), [KVariable('X', sort=MINT_INT)]), None), + # KApply with unknown label: KeyError from symbols lookup → None + ('kapply_unknown_label', KApply(KLabel('nonexistent'), []), None), + # KAs: sort of the alias variable + ('kas_sorted_alias', KAs(KVariable('X', sort=MINT_INT), KVariable('Y', sort=MINT_INT)), MINT_INT), + # KAs whose alias has no sort annotation: returns None + ('kas_unsorted_alias', KAs(KVariable('X', sort=MINT_INT), KVariable('Y')), None), +) + + +@pytest.mark.parametrize( + 'test_id,term,expected', + SORT_DATA, + ids=[test_id for test_id, *_ in SORT_DATA], +) +def test_sort(test_id: str, term: KInner, expected: KSort | None) -> None: + assert DEFN.sort(term) == expected + + +# --------------------------------------------------------------------------- +# KDefinition.resolve_sorts +# --------------------------------------------------------------------------- + +RESOLVE_SORTS_DATA: Final = ( + # Direct substitution: result sort IS the param (N → Int) + ('direct_bar', KLabel('bar', [INT]), INT, (INT,)), + # Recursive substitution: result/arg sort nests the param (MInt{N} with N → Int → MInt{Int}) + ('nested_foo', KLabel('foo', [INT]), MINT_INT, (MINT_INT,)), +) + + +@pytest.mark.parametrize( + 'test_id,label,expected_result,expected_args', + RESOLVE_SORTS_DATA, + ids=[test_id for test_id, *_ in RESOLVE_SORTS_DATA], +) +def test_resolve_sorts(test_id: str, label: KLabel, expected_result: KSort, expected_args: tuple[KSort, ...]) -> None: + result, args = DEFN.resolve_sorts(label) + assert result == expected_result + assert args == expected_args + + +# --------------------------------------------------------------------------- +# KDefinition.add_sort_params +# --------------------------------------------------------------------------- + +ADD_SORT_PARAMS_DATA: Final = ( + # Label already has params filled: leave unchanged + ( + 'already_filled', + KApply(KLabel('bar', [INT]), [KVariable('X', sort=INT)]), + KApply(KLabel('bar', [INT]), [KVariable('X', sort=INT)]), + ), + # Direct sort param: psort IS the param (N ~ Int → N=Int) + ( + 'direct_param', + KApply(KLabel('bar'), [KVariable('X', sort=INT)]), + KApply(KLabel('bar', [INT]), [KVariable('X', sort=INT)]), + ), + # Nested sort param: psort = MInt{N}, asort = MInt{Int} → N=Int via unification + ( + 'nested_param', + KApply(KLabel('foo'), [KVariable('X', sort=MINT_INT)]), + KApply(KLabel('foo', [INT]), [KVariable('X', sort=MINT_INT)]), + ), + # ML pred: S1 inferred from args, S2 (result sort) filled with #SortParam sentinel + ( + 'ml_pred_sentinel', + KApply('#Equals', [KVariable('X', sort=INT), KVariable('Y', sort=INT)]), + KApply(KLabel('#Equals', [INT, SORT_PARAM]), [KVariable('X', sort=INT), KVariable('Y', sort=INT)]), + ), + # Unsortable argument (no sort annotation): cannot fill params, term returned unchanged + ( + 'unsortable_arg_unchanged', + KApply(KLabel('foo'), [KVariable('X')]), + KApply(KLabel('foo'), [KVariable('X')]), + ), +) + + +@pytest.mark.parametrize( + 'test_id,term,expected', + ADD_SORT_PARAMS_DATA, + ids=[test_id for test_id, *_ in ADD_SORT_PARAMS_DATA], +) +def test_add_sort_params(test_id: str, term: KInner, expected: KInner) -> None: + assert DEFN.add_sort_params(term) == expected + + +# --------------------------------------------------------------------------- +# KDefinition.add_cell_map_items +# --------------------------------------------------------------------------- + +_ACCT_1: Final = KApply('', [KVariable('X', sort=INT), KVariable('Y', sort=INT)]) +_ACCT_2: Final = KApply('', [KVariable('A', sort=INT), KVariable('B', sort=INT)]) + +ADD_CELL_MAP_ITEMS_DATA: Final = ( + # Parent expects AccountCellMap (the map sort) — children are wrapped in AccountCellMapItem. + ( + 'wraps_when_parent_expects_cell_map_sort', + KApply('_AccountCellMap_', [_ACCT_1, _ACCT_2]), + KApply( + '_AccountCellMap_', + [ + KApply('AccountCellMapItem', [KVariable('X', sort=INT), _ACCT_1]), + KApply('AccountCellMapItem', [KVariable('A', sort=INT), _ACCT_2]), + ], + ), + ), + # Parent expects AccountCell (the element sort) — the child must NOT be wrapped. + # Before the guard fix, _wrap_elements would incorrectly wrap here too. + ( + 'no_wrap_when_parent_expects_cell_element_sort', + KApply('getEntry', [_ACCT_1]), + KApply('getEntry', [_ACCT_1]), + ), +) + + +@pytest.mark.parametrize( + 'test_id,term,expected', + ADD_CELL_MAP_ITEMS_DATA, + ids=[test_id for test_id, *_ in ADD_CELL_MAP_ITEMS_DATA], +) +def test_add_cell_map_items(test_id: str, term: KInner, expected: KInner) -> None: + assert DEFN.add_cell_map_items(term) == expected