diff --git a/platform/jewel/foundation/src/main/kotlin/org/jetbrains/jewel/foundation/lazy/tree/BasicLazyTree.kt b/platform/jewel/foundation/src/main/kotlin/org/jetbrains/jewel/foundation/lazy/tree/BasicLazyTree.kt index 3a5f9f7b1fe4d..9be1392c1bd52 100644 --- a/platform/jewel/foundation/src/main/kotlin/org/jetbrains/jewel/foundation/lazy/tree/BasicLazyTree.kt +++ b/platform/jewel/foundation/src/main/kotlin/org/jetbrains/jewel/foundation/lazy/tree/BasicLazyTree.kt @@ -29,6 +29,7 @@ import androidx.compose.ui.semantics.semantics import androidx.compose.ui.semantics.stateDescription import androidx.compose.ui.unit.Dp import androidx.compose.ui.unit.dp +import kotlin.collections.forEach import kotlin.time.Duration import kotlin.time.Duration.Companion.milliseconds import org.jetbrains.jewel.foundation.lazy.SelectableLazyColumn @@ -179,7 +180,19 @@ public fun BasicLazyTree( val scope = rememberCoroutineScope() val flattenedTree = - remember(tree, treeState.openNodes, treeState.allNodes) { tree.roots.flatMap { it.flattenTree(treeState) } } + remember(tree, treeState.openNodes) { + // flattenTree will safely repopulate allNodes with only the nodes that still + // exist and are currently visible. + treeState.allNodes.clear() + + val result = tree.roots.flatMap { it.flattenTree(treeState) } + + // Purging stale IDs from openNodes by keeping only the ones that survived the rebuild + val survivingIds = result.map { it.id }.toSet() + treeState.openNodes = treeState.openNodes intersect survivingIds + + result + } remember(tree) { // if tree changes we need to update selection changes onSelectionChange( @@ -202,9 +215,11 @@ public fun BasicLazyTree( pointerEventActions = pointerEventScopedActions, interactionSource = interactionSource, ) { - itemsIndexed(items = flattenedTree, key = { _, item -> item.id }, contentType = { _, item -> item.data }) { - index, - element -> + itemsIndexed( + items = flattenedTree, + key = { _, item -> item.id }, + contentType = { _, item -> item.data?.let { it::class } }, + ) { index, element -> val elementState = TreeElementState.of( active = isActive, @@ -402,7 +417,7 @@ private fun Tree.Element<*>.flattenTree(state: TreeState): MutableList>() when (this) { is Tree.Element.Node<*> -> { - if (id !in state.allNodes.map { it.first }) state.allNodes.add(id to depth) + if (state.allNodes.none { it.first == id }) state.allNodes.add(id to depth) orderedChildren.add(this) if (id !in state.openNodes) { return orderedChildren.also { @@ -422,8 +437,10 @@ private fun Tree.Element<*>.flattenTree(state: TreeState): MutableList.getAllSubNodes(node: Tree.Element.Node<*>) { - node.children?.filterIsInstance>()?.forEach { - add(it.id) - this@getAllSubNodes getAllSubNodes (it) + node.children?.forEach { child -> + if (child is Tree.Element.Node<*>) { + add(child.id) + this@getAllSubNodes getAllSubNodes child + } } }