Skip to content
120 changes: 105 additions & 15 deletions pyk/src/pyk/kast/outer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .att import EMPTY_ATT, Atts, Format, KAst, KAtt, WithKAtt
from .inner import (
KApply,
KAs,
KInner,
KLabel,
KRewrite,
Expand Down Expand Up @@ -1327,6 +1328,8 @@ def sort(self, kast: KInner) -> KSort | None:
match kast:
case KToken(_, sort) | KVariable(_, sort):
return sort
case KAs(alias=KVariable(sort=sort)):
return sort
case KRewrite(lhs, rhs):
lhs_sort = self.sort(lhs)
rhs_sort = self.sort(rhs)
Expand All @@ -1336,8 +1339,11 @@ def sort(self, kast: KInner) -> KSort | None:
case KSequence(_):
return KSort('K')
case KApply(label, _):
sort, _ = self.resolve_sorts(label)
return sort
try:
sort, _ = self.resolve_sorts(label)
return sort
except (KeyError, ValueError):
return None
case _:
return None

Expand All @@ -1354,7 +1360,13 @@ def resolve_sorts(self, label: KLabel) -> tuple[KSort, tuple[KSort, ...]]:
sorts = dict(zip(prod.params, label.params, strict=True))

def resolve(sort: KSort) -> KSort:
return sorts.get(sort, sort)
# Direct match: sort IS one of the sort parameters.
if sort in sorts:
return sorts[sort]
# Recursive substitution: sort params may appear nested (e.g. MInt{Width} → MInt{8}).
if sort.params:
return KSort(sort.name, tuple(resolve(p) for p in sort.params))
return sort

return resolve(prod.sort), tuple(resolve(sort) for sort in prod.argument_sorts)

Expand Down Expand Up @@ -1483,28 +1495,86 @@ def transform(
# Best-effort addition of sort parameters to klabels, context insensitive
def add_sort_params(self, kast: KInner) -> KInner:
"""Return a given term with the sort parameters on the `KLabel` filled in (which may be missing because of how the frontend works), best effort."""
# ML predicate labels whose result sort (Sort2) is context-dependent and not inferable
# from the arguments alone. When Sort1 can be determined but Sort2 cannot, we fill Sort2
# with the sentinel KSort('#SortParam') so that downstream Kore emission can introduce a
# universally-quantified sort variable (Q0) in the axiom.
_ML_PRED_RESULT_SORT_PARAM = KSort('#SortParam') # noqa: N806
_ML_PRED_LABELS = frozenset({'#Equals', '#Ceil', '#Floor', '#In'}) # noqa: N806

Comment on lines 1496 to +1504
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unclear how this algorithm ends up handling cases with multiple sort parameters, since it uses just one sentinel value.

def _unify_sort_params(parametric: KSort, actual: KSort, params: frozenset[KSort]) -> dict[KSort, KSort]:
"""Match parametric sort against actual, extracting bindings for sort params.

Handles both direct (parametric IS a sort param) and nested
(parametric = MInt{Width}, actual = MInt{8}) cases.
Returns empty dict when no bindings could be extracted (no match).
"""
if parametric in params:
return {parametric: actual}
if parametric.name != actual.name or len(parametric.params) != len(actual.params):
return {}
result: dict[KSort, KSort] = {}
for p_sub, a_sub in zip(parametric.params, actual.params, strict=True):
sub_bindings = _unify_sort_params(p_sub, a_sub, params)
for k, v in sub_bindings.items():
if k in result and result[k] != v:
return {} # Conflicting bindings
result[k] = v
return result

def _merge_binding(sort_dict: dict[KSort, KSort], k: KSort, v: KSort) -> bool:
"""Merge one binding into sort_dict in place. Returns False on irreconcilable conflict."""
if k in sort_dict:
existing = sort_dict[k]
if existing == _ML_PRED_RESULT_SORT_PARAM:
sort_dict[k] = v # Concrete sort overrides sentinel.
elif existing != v:
lub = self.least_common_supersort(existing, v)
if lub is None:
_LOGGER.warning(f'Failed to add sort parameter, sort mismatch: {(k, existing, v)}')
return False
sort_dict[k] = lub
else:
sort_dict[k] = v
return True

def _add_sort_params(_k: KInner) -> KInner:
if type(_k) is KApply:
prod = self.symbols[_k.label.name]
if len(_k.label.params) == 0 and len(prod.params) > 0:
param_set = frozenset(prod.params)
sort_dict: dict[KSort, KSort] = {}
for psort, asort in zip(prod.argument_sorts, map(self.sort, _k.args), strict=True):
if asort == _ML_PRED_RESULT_SORT_PARAM:
# #SortParam is the sentinel for an ML pred result sort that cannot be
# inferred bottom-up (e.g. #Equals result sort depends on outer context).
# It propagates upward into ML connectives (#And, #Or, #Not) as a
# placeholder for the axiom sort variable Q0, but a concrete sort takes
# precedence when one is available.
bindings = _unify_sort_params(psort, asort, param_set)
for k, v in bindings.items():
if k not in sort_dict: # sentinel fills only empty slots
sort_dict[k] = v
continue
if asort is None:
_LOGGER.warning(
f'Failed to add sort parameter, unable to determine sort for argument in production: {(prod, psort, asort)}'
)
return _k
if psort in prod.params:
if psort in sort_dict and sort_dict[psort] != asort:
_LOGGER.warning(
f'Failed to add sort parameter, sort mismatch between different occurances of sort parameter: {(prod, psort, sort_dict[psort], asort)}'
)
# Unify psort with asort to extract bindings for sort params.
# Handles both direct (psort=Width) and nested (psort=MInt{Width}) cases.
bindings = _unify_sort_params(psort, asort, param_set)
for k, v in bindings.items():
if not _merge_binding(sort_dict, k, v):
return _k
elif psort not in sort_dict:
sort_dict[psort] = asort
if all(p in sort_dict for p in prod.params):
return _k.let(label=KLabel(_k.label.name, [sort_dict[p] for p in prod.params]))
# ML predicates have a context-dependent result sort (Sort2) that cannot be
# inferred from arguments. Fill it with the sentinel so that krule_to_kore can
# introduce a universally-quantified sort variable for the axiom.
if _k.label.name in _ML_PRED_LABELS:
filled = {p: sort_dict.get(p, _ML_PRED_RESULT_SORT_PARAM) for p in prod.params}
return _k.let(label=KLabel(_k.label.name, [filled[p] for p in prod.params]))
return _k

return bottom_up(_add_sort_params, kast)
Expand All @@ -1515,15 +1585,35 @@ def add_cell_map_items(self, kast: KInner) -> KInner:
# syntax AccountCellMap [cellCollection, hook(MAP.Map)]
# syntax AccountCellMap ::= AccountCellMap AccountCellMap [assoc, avoid, cellCollection, comm, element(AccountCellMapItem), function, hook(MAP.concat), unit(.AccountCellMap), wrapElement(<account>)]

cell_wrappers = {}
# Maps cell label -> (element_constructor, cell_map_sort).
# Wrapping is correct only when the parent production expects the cell MAP sort (e.g.
# EntryCellMap), not when it expects the individual cell element sort (e.g. EntryCell).
# For example, EntryCellMapKey(<entry>(...)) takes EntryCell — the <entry> must NOT be
# wrapped, whereas _EntryCellMap_(<entry>(...), ...) expects EntryCellMap — wrapping is needed.
cell_wrappers: dict[str, tuple[str, KSort]] = {}
for ccp in self.cell_collection_productions:
if Atts.ELEMENT in ccp.att and Atts.WRAP_ELEMENT in ccp.att:
cell_wrappers[ccp.att[Atts.WRAP_ELEMENT]] = ccp.att[Atts.ELEMENT]
cell_label = ccp.att[Atts.WRAP_ELEMENT]
element_ctor = ccp.att[Atts.ELEMENT]
if element_ctor in self.symbols:
cell_wrappers[cell_label] = (element_ctor, self.symbols[element_ctor].sort)

def _wrap_elements(_k: KInner) -> KInner:
if type(_k) is KApply and _k.label.name in cell_wrappers:
return KApply(cell_wrappers[_k.label.name], [_k.args[0], _k])
return _k
if not isinstance(_k, KApply) or _k.label.name not in self.symbols:
return _k
prod = self.symbols[_k.label.name]
arg_sorts = prod.argument_sorts
if not arg_sorts or len(arg_sorts) != _k.arity:
return _k
new_args: list[KInner] = list(_k.args)
changed = False
for i, (arg_sort, arg) in enumerate(zip(arg_sorts, _k.args, strict=True)):
if isinstance(arg, KApply) and arg.label.name in cell_wrappers:
element_ctor, cell_map_sort = cell_wrappers[arg.label.name]
if arg_sort == cell_map_sort:
new_args[i] = KApply(element_ctor, [arg.args[0], arg])
changed = True
return _k.let(args=new_args) if changed else _k

# To ensure we don't get duplicate wrappers.
_kast = self.remove_cell_map_items(kast)
Expand Down
Loading
Loading