From cca3fc53854bfdabe7a4b60c0d47357cf2a62884 Mon Sep 17 00:00:00 2001 From: Vincent Bernat Date: Sat, 26 Jul 2025 15:30:57 +0200 Subject: [PATCH] Make ast.Walk() robust against the replacement of the current node When the current node is replaced (with ReplaceChild()), the walk is interrupted as the previous node lost its sibling. This changes make it possible to continue walking. --- ast/ast.go | 4 +++- ast/ast_test.go | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/ast/ast.go b/ast/ast.go index 36ba606f..6f06b932 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -507,10 +507,12 @@ func walkHelper(n Node, walker Walker) (WalkStatus, error) { return status, err } if status != WalkSkipChildren { - for c := n.FirstChild(); c != nil; c = c.NextSibling() { + for c := n.FirstChild(); c != nil; { + next := c.NextSibling() if st, err := walkHelper(c, walker); err != nil || st == WalkStop { return WalkStop, err } + c = next } } status, err = walker(n, false) diff --git a/ast/ast_test.go b/ast/ast_test.go index 191fffd6..93c2d73f 100644 --- a/ast/ast_test.go +++ b/ast/ast_test.go @@ -52,6 +52,39 @@ func TestWalk(t *testing.T) { } } +func TestWalkAndReplace(t *testing.T) { + doc := node(NewDocument(), node(NewHeading(1), NewLink(), NewLink())) + want := []NodeKind{KindDocument, KindHeading, KindHeading, KindHeading} + var got []NodeKind + walkerReplace := func(n Node, entering bool) (WalkStatus, error) { + // We replace any link by an heading + if entering { + n, ok := n.(*Link) + if !ok { + return WalkContinue, nil + } + parent := n.Parent() + parent.ReplaceChild(parent, n, NewHeading(2)) + } + return WalkContinue, nil + } + walkerCollect := func(n Node, entering bool) (WalkStatus, error) { + if entering { + got = append(got, n.Kind()) + } + return WalkContinue, nil + } + if err := Walk(doc, walkerReplace); err != nil { + t.Fatalf("Walk() error = %v", err) + } + if err := Walk(doc, walkerCollect); err != nil { + t.Fatalf("Walk() error = %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Errorf("Walk() expected = %v, got = %v", want, got) + } +} + func node(n Node, children ...Node) Node { for _, c := range children { n.AppendChild(n, c)