From 44ab862c9d9f216639208c76c4f5c1124081f595 Mon Sep 17 00:00:00 2001 From: Marvin Poul Date: Wed, 14 Feb 2024 18:01:32 +0100 Subject: [PATCH 1/9] Recreate structure fresh during get_primitive_cell This avoids a problem, when you pass standardize=True and the resulting cell would be larger than the original. Remove array copies and add warning The trouble is that when spglib returns the symmetrized cell, it may permute, add or remove atoms so that we cannot tell anymore which values from the original arrays we would need to copy or remove. --- src/structuretoolkit/analyse/symmetry.py | 33 ++++++++++++++++++------ tests/test_symmetry.py | 6 +++-- 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/src/structuretoolkit/analyse/symmetry.py b/src/structuretoolkit/analyse/symmetry.py index 7af40bc40..f6fbb0e03 100644 --- a/src/structuretoolkit/analyse/symmetry.py +++ b/src/structuretoolkit/analyse/symmetry.py @@ -4,6 +4,7 @@ import ast import dataclasses import string +from logging import warning import numpy as np import spglib @@ -400,26 +401,42 @@ def get_primitive_cell( >>> symmetry = Symmetry(structure) >>> len(symmetry.get_primitive_cell()) == len(basis) True + + .. warning:: + Custom arrays defined in the base structures + :attr:`ase.atoms.Atoms.arrays` are not copied to the new structure! """ + if not all(self._structure.pbc): + raise ValueError("Can only symmetrize periodic structures.") ret = spglib.standardize_cell( self._get_spglib_cell(use_elements=use_elements, use_magmoms=use_magmoms), to_primitive=not standardize, ) if ret is None: - raise SymmetryError(spglib.error.get_error_message()) - cell, positions, indices = ret - positions = (cell.T @ positions.T).T - new_structure = self._structure.copy() - new_structure.cell = cell - new_structure = new_structure[: len(indices)] + raise SymmetryError(spglib.spglib.spglib_error.message) + cell, scaled_positions, indices = ret indices_dict = { v: k for k, v in structuretoolkit.common.helper.get_species_indices_dict( structure=self._structure ).items() } - new_structure.symbols = [indices_dict[i] for i in indices] - new_structure.positions = positions + symbols = [indices_dict[i] for i in indices] + arrays = { + k: self._structure.arrays[k] + for k in self._structure.arrays + if k not in ("numbers", "positions") + } + new_structure = type(self._structure)( + symbols=symbols, + scaled_positions=scaled_positions, + cell=cell, + pbc=[True, True, True], + ) + keys = set(arrays) - {"numbers", "positions"} + if len(keys) > 0: + warning(f"Custom arrays {keys} do not carry over to new structure!") + return new_structure def get_ir_reciprocal_mesh( diff --git a/tests/test_symmetry.py b/tests/test_symmetry.py index e32ae6aff..3fcca286b 100644 --- a/tests/test_symmetry.py +++ b/tests/test_symmetry.py @@ -180,7 +180,9 @@ def test_get_ir_reciprocal_mesh(self): def test_get_primitive_cell(self): cell = 2.2 * np.identity(3) - basis = Atoms("AlFe", scaled_positions=[(0, 0, 0), (0.5, 0.5, 0.5)], cell=cell) + basis = Atoms( + "AlFe", scaled_positions=[(0, 0, 0), (0.5, 0.5, 0.5)], cell=cell, pbc=True + ) structure = basis.repeat([2, 2, 2]) sym = stk.analyse.get_symmetry(structure=structure) self.assertEqual(len(basis), len(sym.get_primitive_cell(standardize=True))) @@ -206,7 +208,7 @@ def test_get_primitive_cell_hex(self): [0.77, 1.57, 5.74], ] cell = [[2.519, 1.454, 4.590], [-2.519, 1.454, 4.590], [0.0, -2.909, 4.590]] - structure = Atoms(symbols=elements, positions=positions, cell=cell) + structure = Atoms(symbols=elements, positions=positions, cell=cell, pbc=True) structure_repeat = structure.repeat([2, 2, 2]) sym = stk.analyse.get_symmetry(structure=structure_repeat) structure_prim_base = sym.get_primitive_cell() From f3bb4b1363e37c1dc8340874e982051c975627d2 Mon Sep 17 00:00:00 2001 From: Marvin Poul Date: Tue, 7 Apr 2026 12:08:26 -0400 Subject: [PATCH 2/9] Test case where primittive_cell returns a larger cell due to standardization --- tests/test_symmetry.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/test_symmetry.py b/tests/test_symmetry.py index 3fcca286b..893998592 100644 --- a/tests/test_symmetry.py +++ b/tests/test_symmetry.py @@ -193,6 +193,20 @@ def test_get_primitive_cell(self): 221, ) + def test_get_primitive_cell_standardize_fcc(self): + # primitive FCC cell has 1 atom; standardize=True should return the + # conventional cubic cell with 4 atoms + a_0 = 4.05 + structure = bulk("Al", crystalstructure="fcc", a=a_0) + self.assertEqual(len(structure), 1) + sym = stk.analyse.get_symmetry(structure=structure) + std = sym.get_primitive_cell(standardize=True) + self.assertEqual(len(std), 4) + # conventional cell should be approximately cubic + cell = std.get_cell() + lengths = np.linalg.norm(cell, axis=1) + self.assertTrue(np.allclose(lengths, a_0, atol=1e-3)) + def test_get_primitive_cell_hex(self): elements = ["Fe", "Fe", "Fe", "Fe", "O", "O", "O", "O", "O", "O"] positions = [ From ed840c83914a54b1538194fc0baf9d2035f75913 Mon Sep 17 00:00:00 2001 From: Marvin Poul Date: Tue, 7 Apr 2026 12:30:45 -0400 Subject: [PATCH 3/9] merge symmetry tests --- tests/test_analyse_symmetry.py | 311 +++++++++++++++++++++++++++++- tests/test_symmetry.py | 340 --------------------------------- 2 files changed, 306 insertions(+), 345 deletions(-) delete mode 100644 tests/test_symmetry.py diff --git a/tests/test_analyse_symmetry.py b/tests/test_analyse_symmetry.py index 7f956e91d..c084735e9 100644 --- a/tests/test_analyse_symmetry.py +++ b/tests/test_analyse_symmetry.py @@ -20,6 +20,7 @@ try: import spglib + from spglib.error import SpglibError skip_spglib_test = False except ImportError: @@ -31,6 +32,39 @@ ) class TestSymmetry(unittest.TestCase): def test_get_arg_equivalent_sites(self): + a_0 = 4.0 + structure = bulk("Al", cubic=True, a=a_0).repeat(2) + sites = stk.common.get_wrapped_coordinates( + structure=structure, + positions=structure.positions + np.array([0, 0, 0.5 * a_0]), + ) + v_position = structure.positions[0] + del structure[0] + pairs = np.stack( + ( + stk.analyse.get_symmetry(structure=structure).get_arg_equivalent_sites( + sites + ), + np.unique( + np.round( + stk.analyse.get_distances_array( + structure=structure, p1=v_position, p2=sites + ), + decimals=2, + ), + return_inverse=True, + )[1], + ), + axis=-1, + ) + unique_pairs = np.unique(pairs, axis=0) + self.assertEqual(len(unique_pairs), len(np.unique(unique_pairs[:, 0]))) + with self.assertRaises(ValueError): + stk.analyse.get_symmetry(structure=structure).get_arg_equivalent_sites( + [0, 0, 0] + ) + + def test_group_points_by_symmetry(self): a_0 = 4.0 structure = bulk("Al", cubic=True, a=a_0).repeat(2) sites = stk.common.get_wrapped_coordinates( @@ -65,15 +99,115 @@ def test_get_arg_equivalent_sites(self): points=[0, 0, 0], ) + def test_generate_equivalent_points(self): + a_0 = 4 + structure = bulk("Al", cubic=True, a=a_0) + sym = stk.analyse.get_symmetry(structure) + self.assertEqual( + len(structure), len(sym.generate_equivalent_points([0, 0, 0.5 * a_0])) + ) + x = np.array([[0, 0, 0.5 * a_0], 3 * [0.25 * a_0]]) + y = np.random.randn(2) + sym_x = sym.generate_equivalent_points(x, return_unique=False) + y = np.tile(y, len(sym_x)) + sym_x = sym_x.reshape(-1, 3) + xy = np.round( + [ + stk.analyse.get_neighborhood( + structure, sym_x, num_neighbors=1 + ).distances.flatten(), + y, + ], + decimals=8, + ) + self.assertEqual( + np.unique(xy, axis=1).shape, + (2, 2), + msg="order of generated points does not match the original order", + ) + + def test_get_symmetry(self): + cell = 2.2 * np.identity(3) + Al = Atoms( + "AlAl", positions=[(0, 0, 0), (0.5, 0.5, 0.5)], cell=cell, pbc=True + ).repeat(2) + self.assertEqual( + len(set(stk.analyse.get_symmetry(structure=Al)["equivalent_atoms"])), 1 + ) + self.assertEqual( + len(stk.analyse.get_symmetry(structure=Al)["translations"]), 96 + ) + self.assertEqual( + len(stk.analyse.get_symmetry(structure=Al)["translations"]), + len(stk.analyse.get_symmetry(structure=Al)["rotations"]), + ) + cell = 2.2 * np.identity(3) + Al = Atoms( + "AlAl", scaled_positions=[(0, 0, 0), (0.5, 0.5, 0.5)], cell=cell, pbc=True + ) + v = np.random.rand(6).reshape(-1, 3) + sym = stk.analyse.get_symmetry(structure=Al) + self.assertAlmostEqual( + np.linalg.norm(sym.symmetrize_vectors(v)), + 0, + ) + vv = np.random.rand(12).reshape(2, 2, 3) + for vvv in sym.symmetrize_vectors(vv): + self.assertAlmostEqual(np.linalg.norm(vvv), 0) + Al.positions[0, 0] += 0.01 + w = sym.symmetrize_vectors(v) + self.assertAlmostEqual( + np.absolute(w[:, 0]).sum(), np.linalg.norm(w, axis=-1).sum() + ) + self.assertAlmostEqual( + np.linalg.norm(sym.symmetrize_vectors(v) - sym.symmetrize_tensor(v)), 0 + ) + + def test_symmetrize_tensor(self): + structure = Atoms( + "AlAlAlAl", + positions=[(0, 0, 0), (0, 0.5, 0.5), (0.5, 0, 0.5), (0.5, 0.5, 0)], + cell=np.identity(3), + pbc=True, + ).repeat(2) + structure.symbols[0] = "Ni" + symmetry = stk.analyse.get_symmetry(structure=structure) + self.assertLess(np.ptp(symmetry.symmetrize_tensor(np.random.randn(3))), 1.0e-8) + sym_tensor = symmetry.symmetrize_tensor(np.random.randn(3, 3)) + self.assertLess(np.ptp(sym_tensor.diagonal()), 1.0e-8) + self.assertLess(np.ptp(sym_tensor[np.triu_indices(3, k=1)]), 1.0e-8) + i = np.all(structure.positions == [0.5, 0, 0.5], axis=-1) + j = np.all(structure.positions == [0, 0.5, 0.5], axis=-1) + s_tensor = symmetry.symmetrize_tensor(np.random.randn(len(structure))) + self.assertAlmostEqual(s_tensor[i][0], s_tensor[j][0]) + s_tensor = symmetry.symmetrize_tensor( + np.random.randn(4, len(structure), 3, len(structure), 3) + ) + self.assertEqual(s_tensor.shape, (4, len(structure), 3, len(structure), 3)) + s_tensor = symmetry.symmetrize_tensor( + np.random.randn(4, len(structure), 3, 3, len(structure)) + ) + self.assertEqual(s_tensor.shape, (4, len(structure), 3, 3, len(structure))) + structure_displaced = structure.copy() + structure_displaced.positions[0, 0] += 0.01 + sym = stk.analyse.get_symmetry(structure=structure_displaced) + tensor = np.zeros((len(structure_displaced), 3, len(structure_displaced), 3)) + tensor[0, 0, 0, 0] = 1 + self.assertAlmostEqual(sym.symmetrize_tensor(tensor)[0, 0, 0, 0], 1) + def test_get_symmetry_dataset(self): cell = 2.2 * np.identity(3) - Al_sc = Atoms("AlAl", scaled_positions=[(0, 0, 0), (0.5, 0.5, 0.5)], cell=cell) + Al_sc = Atoms( + "AlAl", scaled_positions=[(0, 0, 0), (0.5, 0.5, 0.5)], cell=cell, pbc=True + ) Al_sc = Al_sc.repeat([2, 2, 2]) self.assertEqual(stk.analyse.get_symmetry(structure=Al_sc).info["number"], 229) def test_get_ir_reciprocal_mesh(self): cell = 2.2 * np.identity(3) - Al_sc = Atoms("AlAl", scaled_positions=[(0, 0, 0), (0.5, 0.5, 0.5)], cell=cell) + Al_sc = Atoms( + "AlAl", scaled_positions=[(0, 0, 0), (0.5, 0.5, 0.5)], cell=cell, pbc=True + ) self.assertEqual( len( stk.analyse.get_symmetry(structure=Al_sc).get_ir_reciprocal_mesh( @@ -85,7 +219,24 @@ def test_get_ir_reciprocal_mesh(self): def test_get_primitive_cell(self): cell = 2.2 * np.identity(3) - basis = Atoms("AlFe", scaled_positions=[(0, 0, 0), (0.5, 0.5, 0.5)], cell=cell) + basis = Atoms( + "AlFe", scaled_positions=[(0, 0, 0), (0.5, 0.5, 0.5)], cell=cell, pbc=True + ) + structure = basis.repeat([2, 2, 2]) + sym = stk.analyse.get_symmetry(structure=structure) + self.assertEqual(len(basis), len(sym.get_primitive_cell(standardize=True))) + self.assertEqual( + stk.analyse.get_symmetry(structure=sym.get_primitive_cell()).spacegroup[ + "Number" + ], + 221, + ) + + def test_get_primitive_cell_functional(self): + cell = 2.2 * np.identity(3) + basis = Atoms( + "AlFe", scaled_positions=[(0, 0, 0), (0.5, 0.5, 0.5)], cell=cell, pbc=True + ) structure = basis.repeat([2, 2, 2]) self.assertEqual( len(basis), len(stk.analyse.get_primitive_cell(structure=structure)) @@ -97,6 +248,20 @@ def test_get_primitive_cell(self): 221, ) + def test_get_primitive_cell_standardize_fcc(self): + # primitive FCC cell has 1 atom; standardize=True should return the + # conventional cubic cell with 4 atoms + a_0 = 4.05 + structure = bulk("Al", crystalstructure="fcc", a=a_0) + self.assertEqual(len(structure), 1) + sym = stk.analyse.get_symmetry(structure=structure) + std = sym.get_primitive_cell(standardize=True) + self.assertEqual(len(std), 4) + # conventional cell should be approximately cubic + cell = std.get_cell() + lengths = np.linalg.norm(cell, axis=1) + self.assertTrue(np.allclose(lengths, a_0, atol=1e-3)) + def test_get_primitive_cell_hex(self): elements = ["Fe", "Fe", "Fe", "Fe", "O", "O", "O", "O", "O", "O"] positions = [ @@ -112,7 +277,30 @@ def test_get_primitive_cell_hex(self): [0.77, 1.57, 5.74], ] cell = [[2.519, 1.454, 4.590], [-2.519, 1.454, 4.590], [0.0, -2.909, 4.590]] - structure = Atoms(symbols=elements, positions=positions, cell=cell) + structure = Atoms(symbols=elements, positions=positions, cell=cell, pbc=True) + structure_repeat = structure.repeat([2, 2, 2]) + sym = stk.analyse.get_symmetry(structure=structure_repeat) + structure_prim_base = sym.get_primitive_cell() + self.assertEqual( + structure_prim_base.get_chemical_symbols(), structure.get_chemical_symbols() + ) + + def test_get_primitive_cell_hex_functional(self): + elements = ["Fe", "Fe", "Fe", "Fe", "O", "O", "O", "O", "O", "O"] + positions = [ + [0.0, 0.0, 4.89], + [0.0, 0.0, 11.78], + [0.0, 0.0, 1.99], + [0.0, 0.0, 8.87], + [-0.98, 1.45, 8.0], + [-1.74, -0.1, 5.74], + [-0.77, -1.57, 8.0], + [0.98, -1.45, 5.74], + [1.74, 0.12, 8.0], + [0.77, 1.57, 5.74], + ] + cell = [[2.519, 1.454, 4.590], [-2.519, 1.454, 4.590], [0.0, -2.909, 4.590]] + structure = Atoms(symbols=elements, positions=positions, cell=cell, pbc=True) structure_repeat = structure.repeat([2, 2, 2]) structure_prim_base = stk.analyse.get_primitive_cell(structure=structure_repeat) self.assertEqual( @@ -121,7 +309,24 @@ def test_get_primitive_cell_hex(self): def test_get_equivalent_points(self): basis = Atoms( - "FeFe", positions=[[0.01, 0, 0], [0.5, 0.5, 0.5]], cell=np.identity(3) + "FeFe", + positions=[[0.01, 0, 0], [0.5, 0.5, 0.5]], + cell=np.identity(3), + pbc=True, + ) + arr = stk.analyse.get_symmetry(structure=basis).generate_equivalent_points( + [0, 0, 0.5] + ) + self.assertAlmostEqual( + np.linalg.norm(arr - np.array([0.51, 0.5, 0]), axis=-1).min(), 0 + ) + + def test_get_equivalent_points_functional(self): + basis = Atoms( + "FeFe", + positions=[[0.01, 0, 0], [0.5, 0.5, 0.5]], + cell=np.identity(3), + pbc=True, ) arr = stk.analyse.get_equivalent_points( structure=basis, @@ -132,6 +337,102 @@ def test_get_equivalent_points(self): 0.7142128534267638, ) + def test_get_space_group(self): + cell = 2.2 * np.identity(3) + Al_sc = Atoms( + "AlAl", scaled_positions=[(0, 0, 0), (0.5, 0.5, 0.5)], cell=cell, pbc=True + ) + self.assertEqual( + stk.analyse.get_symmetry(structure=Al_sc).spacegroup[ + "InternationalTableSymbol" + ], + "Im-3m", + ) + self.assertEqual( + stk.analyse.get_symmetry(structure=Al_sc).spacegroup["Number"], 229 + ) + cell = 4.2 * (0.5 * np.ones((3, 3)) - 0.5 * np.eye(3)) + Al_fcc = Atoms("Al", scaled_positions=[(0, 0, 0)], cell=cell, pbc=True) + self.assertEqual( + stk.analyse.get_symmetry(structure=Al_fcc).spacegroup[ + "InternationalTableSymbol" + ], + "Fm-3m", + ) + self.assertEqual( + stk.analyse.get_symmetry(structure=Al_fcc).spacegroup["Number"], 225 + ) + a = 3.18 + c = 1.623 * a + cell = np.eye(3) + cell[0, 0] = a + cell[2, 2] = c + cell[1, 0] = -a / 2.0 + cell[1, 1] = np.sqrt(3) * a / 2.0 + pos = np.array([[0.0, 0.0, 0.0], [1.0 / 3.0, 2.0 / 3.0, 1.0 / 2.0]]) + Mg_hcp = Atoms("Mg2", scaled_positions=pos, cell=cell, pbc=True) + self.assertEqual( + stk.analyse.get_symmetry(structure=Mg_hcp).spacegroup["Number"], 194 + ) + cell = np.eye(3) + cell[0, 0] = a + cell[2, 2] = c + cell[1, 1] = np.sqrt(3) * a + pos = np.array( + [ + [0.0, 0.0, 0.0], + [0.5, 0.5, 0.0], + [0.5, 1 / 6, 0.5], + [0.0, 2 / 3, 0.5], + ] + ) + Mg_hcp = Atoms("Mg4", scaled_positions=pos, cell=cell, pbc=True) + self.assertEqual( + stk.analyse.get_symmetry(structure=Mg_hcp).spacegroup["Number"], 194 + ) + + def test_permutations(self): + structure = bulk("Al", cubic=True).repeat(2) + x_vacancy = structure.positions[0] + del structure[0] + neigh = stk.analyse.get_neighborhood(structure=structure, positions=x_vacancy) + vec = np.zeros_like(structure.positions) + vec[neigh.indices[0]] = neigh.vecs[0] + sym = stk.analyse.get_symmetry(structure=structure) + all_vectors = np.einsum("ijk,ink->inj", sym.rotations, vec[sym.permutations]) + for i, v in zip(neigh.indices, neigh.vecs, strict=True): + vec = np.zeros_like(structure.positions) + vec[i] = v + self.assertAlmostEqual( + np.linalg.norm(all_vectors - vec, axis=(-1, -2)).min(), + 0, + ) + + def test_arg_equivalent_vectors(self): + structure = bulk("Al", cubic=True).repeat(2) + self.assertEqual( + np.unique( + stk.analyse.get_symmetry(structure=structure).arg_equivalent_vectors + ).squeeze(), + 0, + ) + x_v = structure.positions[0] + del structure[0] + arg_v = stk.analyse.get_symmetry(structure=structure).arg_equivalent_vectors + dx = stk.analyse.get_distances_array( + structure=structure, p1=structure.positions, p2=x_v, vectors=True + ) + dx_round = np.round(np.absolute(dx), decimals=3) + self.assertEqual(len(np.unique(dx_round + arg_v)), len(np.unique(arg_v))) + + def test_error(self): + """spglib errors should be wrapped in a SymmetryError.""" + + structure = bulk("Al") + structure += structure[-1] + with self.assertRaises((stk.common.SymmetryError, SpglibError)): + stk.analyse.get_symmetry(structure=structure) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_symmetry.py b/tests/test_symmetry.py deleted file mode 100644 index 893998592..000000000 --- a/tests/test_symmetry.py +++ /dev/null @@ -1,340 +0,0 @@ -# coding: utf-8 -# Copyright (c) Max-Planck-Institut für Eisenforschung GmbH - Computational Materials Design (CM) Department -# Distributed under the terms of "New BSD License", see the LICENSE file. - -import unittest - -import numpy as np -from ase.atoms import Atoms -from ase.build import bulk - -import structuretoolkit as stk - -try: - import pyscal - - skip_pyscal_test = False -except ImportError: - skip_pyscal_test = True - - -try: - import spglib - from spglib.error import SpglibError - - skip_spglib_test = False -except ImportError: - skip_spglib_test = True - - -@unittest.skipIf( - skip_spglib_test, "spglib is not installed, so the spglib tests are skipped." -) -class TestSymmetry(unittest.TestCase): - def test_get_arg_equivalent_sites(self): - a_0 = 4.0 - structure = bulk("Al", cubic=True, a=a_0).repeat(2) - sites = stk.common.get_wrapped_coordinates( - structure=structure, - positions=structure.positions + np.array([0, 0, 0.5 * a_0]), - ) - v_position = structure.positions[0] - del structure[0] - pairs = np.stack( - ( - stk.analyse.get_symmetry(structure=structure).get_arg_equivalent_sites( - sites - ), - np.unique( - np.round( - stk.analyse.get_distances_array( - structure=structure, p1=v_position, p2=sites - ), - decimals=2, - ), - return_inverse=True, - )[1], - ), - axis=-1, - ) - unique_pairs = np.unique(pairs, axis=0) - self.assertEqual(len(unique_pairs), len(np.unique(unique_pairs[:, 0]))) - with self.assertRaises(ValueError): - stk.analyse.get_symmetry(structure=structure).get_arg_equivalent_sites( - [0, 0, 0] - ) - - def test_generate_equivalent_points(self): - a_0 = 4 - structure = bulk("Al", cubic=True, a=a_0) - sym = stk.analyse.get_symmetry(structure) - self.assertEqual( - len(structure), len(sym.generate_equivalent_points([0, 0, 0.5 * a_0])) - ) - x = np.array([[0, 0, 0.5 * a_0], 3 * [0.25 * a_0]]) - y = np.random.randn(2) - sym_x = sym.generate_equivalent_points(x, return_unique=False) - y = np.tile(y, len(sym_x)) - sym_x = sym_x.reshape(-1, 3) - xy = np.round( - [ - stk.analyse.get_neighborhood( - structure, sym_x, num_neighbors=1 - ).distances.flatten(), - y, - ], - decimals=8, - ) - self.assertEqual( - np.unique(xy, axis=1).shape, - (2, 2), - msg="order of generated points does not match the original order", - ) - - def test_get_symmetry(self): - cell = 2.2 * np.identity(3) - Al = Atoms( - "AlAl", positions=[(0, 0, 0), (0.5, 0.5, 0.5)], cell=cell, pbc=True - ).repeat(2) - self.assertEqual( - len(set(stk.analyse.get_symmetry(structure=Al)["equivalent_atoms"])), 1 - ) - self.assertEqual( - len(stk.analyse.get_symmetry(structure=Al)["translations"]), 96 - ) - self.assertEqual( - len(stk.analyse.get_symmetry(structure=Al)["translations"]), - len(stk.analyse.get_symmetry(structure=Al)["rotations"]), - ) - cell = 2.2 * np.identity(3) - Al = Atoms( - "AlAl", scaled_positions=[(0, 0, 0), (0.5, 0.5, 0.5)], cell=cell, pbc=True - ) - v = np.random.rand(6).reshape(-1, 3) - sym = stk.analyse.get_symmetry(structure=Al) - self.assertAlmostEqual( - np.linalg.norm(sym.symmetrize_vectors(v)), - 0, - ) - vv = np.random.rand(12).reshape(2, 2, 3) - for vvv in sym.symmetrize_vectors(vv): - self.assertAlmostEqual(np.linalg.norm(vvv), 0) - Al.positions[0, 0] += 0.01 - w = sym.symmetrize_vectors(v) - self.assertAlmostEqual( - np.absolute(w[:, 0]).sum(), np.linalg.norm(w, axis=-1).sum() - ) - self.assertAlmostEqual( - np.linalg.norm(sym.symmetrize_vectors(v) - sym.symmetrize_tensor(v)), 0 - ) - - def test_symmetrize_tensor(self): - structure = Atoms( - "AlAlAlAl", - positions=[(0, 0, 0), (0, 0.5, 0.5), (0.5, 0, 0.5), (0.5, 0.5, 0)], - cell=np.identity(3), - pbc=True, - ).repeat(2) - structure.symbols[0] = "Ni" - symmetry = stk.analyse.get_symmetry(structure=structure) - self.assertLess(np.ptp(symmetry.symmetrize_tensor(np.random.randn(3))), 1.0e-8) - sym_tensor = symmetry.symmetrize_tensor(np.random.randn(3, 3)) - self.assertLess(np.ptp(sym_tensor.diagonal()), 1.0e-8) - self.assertLess(np.ptp(sym_tensor[np.triu_indices(3, k=1)]), 1.0e-8) - i = np.all(structure.positions == [0.5, 0, 0.5], axis=-1) - j = np.all(structure.positions == [0, 0.5, 0.5], axis=-1) - s_tensor = symmetry.symmetrize_tensor(np.random.randn(len(structure))) - self.assertAlmostEqual(s_tensor[i][0], s_tensor[j][0]) - s_tensor = symmetry.symmetrize_tensor( - np.random.randn(4, len(structure), 3, len(structure), 3) - ) - self.assertEqual(s_tensor.shape, (4, len(structure), 3, len(structure), 3)) - s_tensor = symmetry.symmetrize_tensor( - np.random.randn(4, len(structure), 3, 3, len(structure)) - ) - self.assertEqual(s_tensor.shape, (4, len(structure), 3, 3, len(structure))) - structure_displaced = structure.copy() - structure_displaced.positions[0, 0] += 0.01 - sym = stk.analyse.get_symmetry(structure=structure_displaced) - tensor = np.zeros((len(structure_displaced), 3, len(structure_displaced), 3)) - tensor[0, 0, 0, 0] = 1 - self.assertAlmostEqual(sym.symmetrize_tensor(tensor)[0, 0, 0, 0], 1) - - def test_get_symmetry_dataset(self): - cell = 2.2 * np.identity(3) - Al_sc = Atoms("AlAl", scaled_positions=[(0, 0, 0), (0.5, 0.5, 0.5)], cell=cell) - Al_sc = Al_sc.repeat([2, 2, 2]) - self.assertEqual(stk.analyse.get_symmetry(structure=Al_sc).info["number"], 229) - - def test_get_ir_reciprocal_mesh(self): - cell = 2.2 * np.identity(3) - Al_sc = Atoms("AlAl", scaled_positions=[(0, 0, 0), (0.5, 0.5, 0.5)], cell=cell) - self.assertEqual( - len( - stk.analyse.get_symmetry(structure=Al_sc).get_ir_reciprocal_mesh( - [3, 3, 3] - )[0] - ), - 27, - ) - - def test_get_primitive_cell(self): - cell = 2.2 * np.identity(3) - basis = Atoms( - "AlFe", scaled_positions=[(0, 0, 0), (0.5, 0.5, 0.5)], cell=cell, pbc=True - ) - structure = basis.repeat([2, 2, 2]) - sym = stk.analyse.get_symmetry(structure=structure) - self.assertEqual(len(basis), len(sym.get_primitive_cell(standardize=True))) - self.assertEqual( - stk.analyse.get_symmetry(structure=sym.get_primitive_cell()).spacegroup[ - "Number" - ], - 221, - ) - - def test_get_primitive_cell_standardize_fcc(self): - # primitive FCC cell has 1 atom; standardize=True should return the - # conventional cubic cell with 4 atoms - a_0 = 4.05 - structure = bulk("Al", crystalstructure="fcc", a=a_0) - self.assertEqual(len(structure), 1) - sym = stk.analyse.get_symmetry(structure=structure) - std = sym.get_primitive_cell(standardize=True) - self.assertEqual(len(std), 4) - # conventional cell should be approximately cubic - cell = std.get_cell() - lengths = np.linalg.norm(cell, axis=1) - self.assertTrue(np.allclose(lengths, a_0, atol=1e-3)) - - def test_get_primitive_cell_hex(self): - elements = ["Fe", "Fe", "Fe", "Fe", "O", "O", "O", "O", "O", "O"] - positions = [ - [0.0, 0.0, 4.89], - [0.0, 0.0, 11.78], - [0.0, 0.0, 1.99], - [0.0, 0.0, 8.87], - [-0.98, 1.45, 8.0], - [-1.74, -0.1, 5.74], - [-0.77, -1.57, 8.0], - [0.98, -1.45, 5.74], - [1.74, 0.12, 8.0], - [0.77, 1.57, 5.74], - ] - cell = [[2.519, 1.454, 4.590], [-2.519, 1.454, 4.590], [0.0, -2.909, 4.590]] - structure = Atoms(symbols=elements, positions=positions, cell=cell, pbc=True) - structure_repeat = structure.repeat([2, 2, 2]) - sym = stk.analyse.get_symmetry(structure=structure_repeat) - structure_prim_base = sym.get_primitive_cell() - self.assertEqual( - structure_prim_base.get_chemical_symbols(), structure.get_chemical_symbols() - ) - - def test_get_equivalent_points(self): - basis = Atoms( - "FeFe", positions=[[0.01, 0, 0], [0.5, 0.5, 0.5]], cell=np.identity(3) - ) - arr = stk.analyse.get_symmetry(structure=basis).generate_equivalent_points( - [0, 0, 0.5] - ) - self.assertAlmostEqual( - np.linalg.norm(arr - np.array([0.51, 0.5, 0]), axis=-1).min(), 0 - ) - - def test_get_space_group(self): - cell = 2.2 * np.identity(3) - Al_sc = Atoms("AlAl", scaled_positions=[(0, 0, 0), (0.5, 0.5, 0.5)], cell=cell) - self.assertEqual( - stk.analyse.get_symmetry(structure=Al_sc).spacegroup[ - "InternationalTableSymbol" - ], - "Im-3m", - ) - self.assertEqual( - stk.analyse.get_symmetry(structure=Al_sc).spacegroup["Number"], 229 - ) - cell = 4.2 * (0.5 * np.ones((3, 3)) - 0.5 * np.eye(3)) - Al_fcc = Atoms("Al", scaled_positions=[(0, 0, 0)], cell=cell) - self.assertEqual( - stk.analyse.get_symmetry(structure=Al_fcc).spacegroup[ - "InternationalTableSymbol" - ], - "Fm-3m", - ) - self.assertEqual( - stk.analyse.get_symmetry(structure=Al_fcc).spacegroup["Number"], 225 - ) - a = 3.18 - c = 1.623 * a - cell = np.eye(3) - cell[0, 0] = a - cell[2, 2] = c - cell[1, 0] = -a / 2.0 - cell[1, 1] = np.sqrt(3) * a / 2.0 - pos = np.array([[0.0, 0.0, 0.0], [1.0 / 3.0, 2.0 / 3.0, 1.0 / 2.0]]) - Mg_hcp = Atoms("Mg2", scaled_positions=pos, cell=cell) - self.assertEqual( - stk.analyse.get_symmetry(structure=Mg_hcp).spacegroup["Number"], 194 - ) - cell = np.eye(3) - cell[0, 0] = a - cell[2, 2] = c - cell[1, 1] = np.sqrt(3) * a - pos = np.array( - [ - [0.0, 0.0, 0.0], - [0.5, 0.5, 0.0], - [0.5, 1 / 6, 0.5], - [0.0, 2 / 3, 0.5], - ] - ) - Mg_hcp = Atoms("Mg4", scaled_positions=pos, cell=cell) - self.assertEqual( - stk.analyse.get_symmetry(structure=Mg_hcp).spacegroup["Number"], 194 - ) - - def test_permutations(self): - structure = bulk("Al", cubic=True).repeat(2) - x_vacancy = structure.positions[0] - del structure[0] - neigh = stk.analyse.get_neighborhood(structure=structure, positions=x_vacancy) - vec = np.zeros_like(structure.positions) - vec[neigh.indices[0]] = neigh.vecs[0] - sym = stk.analyse.get_symmetry(structure=structure) - all_vectors = np.einsum("ijk,ink->inj", sym.rotations, vec[sym.permutations]) - for i, v in zip(neigh.indices, neigh.vecs, strict=True): - vec = np.zeros_like(structure.positions) - vec[i] = v - self.assertAlmostEqual( - np.linalg.norm(all_vectors - vec, axis=(-1, -2)).min(), - 0, - ) - - def test_arg_equivalent_vectors(self): - structure = bulk("Al", cubic=True).repeat(2) - self.assertEqual( - np.unique( - stk.analyse.get_symmetry(structure=structure).arg_equivalent_vectors - ).squeeze(), - 0, - ) - x_v = structure.positions[0] - del structure[0] - arg_v = stk.analyse.get_symmetry(structure=structure).arg_equivalent_vectors - dx = stk.analyse.get_distances_array( - structure=structure, p1=structure.positions, p2=x_v, vectors=True - ) - dx_round = np.round(np.absolute(dx), decimals=3) - self.assertEqual(len(np.unique(dx_round + arg_v)), len(np.unique(arg_v))) - - def test_error(self): - """spglib errors should be wrapped in a SymmetryError.""" - - structure = bulk("Al") - structure += structure[-1] - with self.assertRaises((stk.common.SymmetryError, SpglibError)): - stk.analyse.get_symmetry(structure=structure) - - -if __name__ == "__main__": - unittest.main() From d663bbd8e6bf43bf149f17eeaccf4c3f395628f6 Mon Sep 17 00:00:00 2001 From: Marvin Poul Date: Tue, 7 Apr 2026 16:49:05 +0000 Subject: [PATCH 4/9] Update src/structuretoolkit/analyse/symmetry.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/structuretoolkit/analyse/symmetry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/structuretoolkit/analyse/symmetry.py b/src/structuretoolkit/analyse/symmetry.py index f6fbb0e03..1f767623c 100644 --- a/src/structuretoolkit/analyse/symmetry.py +++ b/src/structuretoolkit/analyse/symmetry.py @@ -413,7 +413,7 @@ def get_primitive_cell( to_primitive=not standardize, ) if ret is None: - raise SymmetryError(spglib.spglib.spglib_error.message) + raise SymmetryError(spglib.error.get_error_message()) cell, scaled_positions, indices = ret indices_dict = { v: k From 8bcd71fa097517f1cd8e5b1cde1a076cc7ce37ab Mon Sep 17 00:00:00 2001 From: Marvin Poul Date: Tue, 7 Apr 2026 16:49:53 +0000 Subject: [PATCH 5/9] Update src/structuretoolkit/analyse/symmetry.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/structuretoolkit/analyse/symmetry.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/structuretoolkit/analyse/symmetry.py b/src/structuretoolkit/analyse/symmetry.py index 1f767623c..f72491cff 100644 --- a/src/structuretoolkit/analyse/symmetry.py +++ b/src/structuretoolkit/analyse/symmetry.py @@ -422,18 +422,13 @@ def get_primitive_cell( ).items() } symbols = [indices_dict[i] for i in indices] - arrays = { - k: self._structure.arrays[k] - for k in self._structure.arrays - if k not in ("numbers", "positions") - } new_structure = type(self._structure)( symbols=symbols, scaled_positions=scaled_positions, cell=cell, pbc=[True, True, True], ) - keys = set(arrays) - {"numbers", "positions"} + keys = set(self._structure.arrays) - {"numbers", "positions"} if len(keys) > 0: warning(f"Custom arrays {keys} do not carry over to new structure!") From 15726ec80a27a25a1e295069fe74154cb326dcb7 Mon Sep 17 00:00:00 2001 From: Marvin Poul Date: Tue, 7 Apr 2026 16:51:18 +0000 Subject: [PATCH 6/9] Apply suggestion from @pmrv --- src/structuretoolkit/analyse/symmetry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/structuretoolkit/analyse/symmetry.py b/src/structuretoolkit/analyse/symmetry.py index f72491cff..2f114848f 100644 --- a/src/structuretoolkit/analyse/symmetry.py +++ b/src/structuretoolkit/analyse/symmetry.py @@ -404,7 +404,7 @@ def get_primitive_cell( .. warning:: Custom arrays defined in the base structures - :attr:`ase.atoms.Atoms.arrays` are not copied to the new structure! + :attr:`ase.atoms.Atoms.arrays` and other state (.info, .calc, etc.) are not copied to the new structure! """ if not all(self._structure.pbc): raise ValueError("Can only symmetrize periodic structures.") From c0d2d3677e9028c668bf1940de243fa9a5e52e4f Mon Sep 17 00:00:00 2001 From: Jan Janssen Date: Fri, 10 Apr 2026 22:11:12 +0200 Subject: [PATCH 7/9] Apply suggestions from code review Co-authored-by: Jan Janssen --- src/structuretoolkit/analyse/symmetry.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/structuretoolkit/analyse/symmetry.py b/src/structuretoolkit/analyse/symmetry.py index 2f114848f..3f1aff3af 100644 --- a/src/structuretoolkit/analyse/symmetry.py +++ b/src/structuretoolkit/analyse/symmetry.py @@ -4,7 +4,7 @@ import ast import dataclasses import string -from logging import warning +import warnings import numpy as np import spglib @@ -430,7 +430,7 @@ def get_primitive_cell( ) keys = set(self._structure.arrays) - {"numbers", "positions"} if len(keys) > 0: - warning(f"Custom arrays {keys} do not carry over to new structure!") + warnings.warn(f"Custom arrays {keys} do not carry over to new structure!", stacklevel=2) return new_structure From 012dadd6eaf32bec6ad164e6601f662480abd991 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 10 Apr 2026 20:11:17 +0000 Subject: [PATCH 8/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/structuretoolkit/analyse/symmetry.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/structuretoolkit/analyse/symmetry.py b/src/structuretoolkit/analyse/symmetry.py index 3f1aff3af..83f89f2ce 100644 --- a/src/structuretoolkit/analyse/symmetry.py +++ b/src/structuretoolkit/analyse/symmetry.py @@ -430,7 +430,10 @@ def get_primitive_cell( ) keys = set(self._structure.arrays) - {"numbers", "positions"} if len(keys) > 0: - warnings.warn(f"Custom arrays {keys} do not carry over to new structure!", stacklevel=2) + warnings.warn( + f"Custom arrays {keys} do not carry over to new structure!", + stacklevel=2, + ) return new_structure From 092cd63139818e433abb8c3666bdb7c512c53f0e Mon Sep 17 00:00:00 2001 From: Jan Janssen Date: Fri, 10 Apr 2026 22:33:28 +0200 Subject: [PATCH 9/9] Extend unit tests for get_primitive_cell() (#485) * Extend unit tests for get_primitive_cell() Added test cases to cover: - Periodic boundary condition validation: ensuring ValueError is raised if structure is not periodic. - Custom arrays warning: ensuring a warning is logged when custom arrays are present in the source structure. Co-authored-by: jan-janssen <3854739+jan-janssen@users.noreply.github.com> * fixes * another fix --------- Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com> Co-authored-by: jan-janssen <3854739+jan-janssen@users.noreply.github.com> --- tests/test_analyse_symmetry.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/test_analyse_symmetry.py b/tests/test_analyse_symmetry.py index c084735e9..aa09acd59 100644 --- a/tests/test_analyse_symmetry.py +++ b/tests/test_analyse_symmetry.py @@ -3,6 +3,7 @@ # Distributed under the terms of "New BSD License", see the LICENSE file. import unittest +import warnings import numpy as np from ase.atoms import Atoms @@ -425,6 +426,28 @@ def test_arg_equivalent_vectors(self): dx_round = np.round(np.absolute(dx), decimals=3) self.assertEqual(len(np.unique(dx_round + arg_v)), len(np.unique(arg_v))) + def test_get_primitive_cell_pbc_error(self): + structure = bulk("Al", cubic=True) + structure.pbc = [True, True, False] + sym = stk.analyse.get_symmetry(structure=structure) + with self.assertRaisesRegex(ValueError, "Can only symmetrize periodic structures."): + sym.get_primitive_cell() + + def test_get_primitive_cell_arrays_warning(self): + structure = bulk("Al", cubic=True) + structure.set_array("test_array", np.zeros(len(structure))) + sym = stk.analyse.get_symmetry(structure=structure) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + sym.get_primitive_cell() + self.assertTrue( + any( + "Custom arrays {'test_array'} do not carry over to new structure!" + in str(warning.message) + for warning in w + ) + ) + def test_error(self): """spglib errors should be wrapped in a SymmetryError."""