Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 21 additions & 11 deletions reframe/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,27 +319,30 @@ class RegressionTestDict(UserDict):

:user_dict: The user dictionary to be converted to a
:class:`RegressionTestDict`.
:value_type: The type of the dictionary values.
If :obj:`None`, the values can be of any type.
:protocol: The protocol to be used to handle missing keys.

.. versionadded:: 4.10

'''
def __init__(self, user_dict: dict = None, protocol: str = None):
def __init__(self, user_dict: dict = None,
value_type: type = None, protocol: str = None):
super().__init__(user_dict or {})
self._index = self.data.pop('$index', None)
self._protocol = protocol
self.validate()
self.validate(value_type)

def validate(self):
def validate(self, value_type):
'''Validate the type of the dictionary.

:raises TypeError: if the dictionary does not match the expected type
'''
ref3_type = typ.Tuple[~Deferrable, ~Deferrable, ~Deferrable]
ref4_type = typ.Tuple[~Deferrable, ~Deferrable,
~Deferrable, ~Deferrable]
reftuple_type = ref3_type | ref4_type | XfailRef
value_type = value_type or object
if self._index is None:
dict_type = typ.Dict[str, typ.Dict[str, reftuple_type]]
dict_type = typ.Dict[~Deferrable, value_type]
else:
dict_type = typ.Dict[str, reftuple_type]
dict_type = value_type
for _ in self._index:
dict_type = typ.Dict[~Deferrable, dict_type]

Expand Down Expand Up @@ -430,6 +433,13 @@ class _ReferenceDict(RegressionTestDict):
An external references file can be specified with the special key
``$ref``.
'''
# Reference dictionary value type
_REF3_TYPE = typ.Tuple[~Deferrable, ~Deferrable, ~Deferrable]
_REF4_TYPE = typ.Tuple[~Deferrable, ~Deferrable,
~Deferrable, ~Deferrable]
_REFTUPLE_TYPE = _REF3_TYPE | _REF4_TYPE | XfailRef
_VALUE_TYPE = typ.Dict[str, _REFTUPLE_TYPE]

def __init__(self, user_dict=None, *, test):
user_dict = user_dict or {}
self.__ref_file = user_dict.pop('$ref', None)
Expand All @@ -438,7 +448,7 @@ def __init__(self, user_dict=None, *, test):
# external reference file
user_dict = {}

super().__init__(user_dict, protocol='ref')
super().__init__(user_dict, value_type=self._VALUE_TYPE, protocol='ref')
if self.__ref_file and test is not None:
self.resolve_external_references(test)

Expand All @@ -460,7 +470,7 @@ def resolve_external_references(self, test):
self._index = user_dict.pop('$index', None)
self.data = user_dict
try:
self.validate()
self.validate(self._VALUE_TYPE)
except TypeError as err:
raise ReferenceParseError(f'{self.__ref_file}: {err}') from err

Expand Down
35 changes: 35 additions & 0 deletions unittests/resources/checks_unlisted/indexed_refs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright 2016-2026 Swiss National Supercomputing Centre (CSCS/ETH Zurich)
# ReFrame Project Developers. See the top-level LICENSE file for details.
#
# SPDX-License-Identifier: BSD-3-Clause

import reframe as rfm
import reframe.utility.sanity as sn
from reframe.core.builtins import (parameter,
sanity_function,
performance_function)


@rfm.simple_test
class IndexedRefsTest(rfm.RunOnlyRegressionTest):
valid_systems = ['*']
valid_prog_environs = ['*']
p = parameter(['foo', 'bar'])
executable = 'echo "throughput: 100"'
reference = {
'$index': ('p',),
'foo': {
'throughput': (100, None, None, 'MB/s')
},
'bar': {
'throughput': (200, None, None, 'MB/s')
}
}

@sanity_function
def validate(self):
return sn.assert_found(r'throughput', self.stdout)

@performance_function('MB/s')
def throughput(self):
return sn.extractsingle(r'throughput: (\S+)', self.stdout, 1, float)
8 changes: 8 additions & 0 deletions unittests/test_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,14 @@ def test_load_fixtures(loader):
assert 5 == len(tests)


def test_load_indexed_refs(loader):
# Assert that tests with indexed references are loaded without errors
tests = loader.load_from_file(
'unittests/resources/checks_unlisted/indexed_refs.py'
)
assert 2 == len(tests)


def test_existing_module_name(loader, tmp_path):
test_file = tmp_path / 'os.py'
shutil.copyfile('unittests/resources/checks/emptycheck.py', test_file)
Expand Down
24 changes: 14 additions & 10 deletions unittests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2067,13 +2067,17 @@ def test_reference_external_custom_prefix(make_perftest, make_path,
def test_regressiondict_custom_protocol(dummy_gpu_exec_ctx):
class _MyTest(rfm.RunOnlyRegressionTest):
x = variable(int, value=1)
foo = variable(rfm.RegressionTestDictType(protocol='foo'), value={
'$index': ('$dev.gpu.model', 'x'),
'v100': {
2: {'value1': (1.4, -0.1, 0.1, None)},
4: {'value1': (2.8, -0.1, 0.1, None)},
}
}, allow_implicit=True)
foo = variable(
rfm.RegressionTestDictType(protocol='foo', value_type=str),
value={
'$index': ('$dev.gpu.model', 'x'),
'v100': {
2: 'value1',
4: 'value2',
}
},
allow_implicit=True
)

def __foo_missing_dev_gpu_model__(self, data, key):
# Map p100 to v100 reference values
Expand All @@ -2089,13 +2093,13 @@ def __foo_missing_x__(self, data, key):
test = _MyTest()
test.x = 2
test.setup(*dummy_gpu_exec_ctx)
assert test.foo[test] == {'value1': (1.4, -0.1, 0.1, None)}
assert test.foo[test] == 'value1'

test.x = 4
assert test.foo[test] == {'value1': (2.8, -0.1, 0.1, None)}
assert test.foo[test] == 'value2'

test.x = 6
assert test.foo[test] == {'value1': (2.8, -0.1, 0.1, None)}
assert test.foo[test] == 'value2'

test.x = 1
with pytest.raises(KeyError):
Expand Down
Loading