diff --git a/pyNN/parameters.py b/pyNN/parameters.py index f6d652d4..d7085f78 100644 --- a/pyNN/parameters.py +++ b/pyNN/parameters.py @@ -350,14 +350,14 @@ def _setitem_value(self, name, value): valid_parameter_names=self.schema.keys()) if issubclass(expected_dtype, ArrayParameter) and isinstance(value, Sized): if len(value) == 0: - value = ArrayParameter([]) + value = expected_dtype([]) elif not isinstance(value[0], ArrayParameter): # may be a more generic way to do it, but for now this special-casing # seems like the most robust approach if isinstance(value[0], Sized): # e.g. list of tuples - value = type(value)([ArrayParameter(x) for x in value]) + value = type(value)([expected_dtype(x) for x in value]) else: - value = ArrayParameter(value) + value = expected_dtype(value) try: self._parameters[name] = LazyArray(value, shape=self._shape, dtype=expected_dtype) diff --git a/test/unittests/test_parameters.py b/test/unittests/test_parameters.py index 18b9b3aa..135a7af1 100644 --- a/test/unittests/test_parameters.py +++ b/test/unittests/test_parameters.py @@ -482,6 +482,38 @@ def test_create_with_list_of_lists(self): assert_array_equal(ps['a'], np.array( [Sequence([1, 2, 3]), Sequence([4, 5, 6])], dtype=Sequence)) + def test_create_with_plain_list_produces_sequence(self): + schema = {'a': Sequence} + ps = ParameterSpace({'a': [1, 2, 3]}, schema, shape=(2,)) + ps.evaluate() + result = ps['a'] + assert type(result[0]) == Sequence + assert_array_equal(result, np.array([Sequence([1, 2, 3]), Sequence([1, 2, 3])], dtype=Sequence)) + + def test_create_with_numpy_array_produces_sequence(self): + schema = {'a': Sequence} + ps = ParameterSpace({'a': np.array([1, 2, 3])}, schema, shape=(2,)) + ps.evaluate() + result = ps['a'] + assert type(result[0]) == Sequence + assert_array_equal(result, np.array([Sequence([1, 2, 3]), Sequence([1, 2, 3])], dtype=Sequence)) + + def test_create_with_list_of_numpy_arrays_produces_sequences(self): + schema = {'a': Sequence} + ps = ParameterSpace({'a': [np.array([1, 2]), np.array([3, 4])]}, schema, shape=(2,)) + ps.evaluate() + result = ps['a'] + assert type(result[0]) == Sequence + assert type(result[1]) == Sequence + assert_array_equal(result, np.array([Sequence([1, 2]), Sequence([3, 4])], dtype=Sequence)) + + def test_create_with_empty_list_produces_sequence(self): + schema = {'a': Sequence} + ps = ParameterSpace({'a': []}, schema, shape=(2,)) + ps.evaluate() + result = ps['a'] + assert type(result[0]) == Sequence + def test_keys(self): ps = ParameterSpace({'a': [2, 3, 5, 8, 13], 'b': 7, 'c': lambda i: 3 * i + 2}, shape=(5,)) assert list(ps.keys()) == ["a", "b", "c"]