diff --git a/ast/ast.go b/ast/ast.go index 36ba606f..6f06b932 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -507,10 +507,12 @@ func walkHelper(n Node, walker Walker) (WalkStatus, error) { return status, err } if status != WalkSkipChildren { - for c := n.FirstChild(); c != nil; c = c.NextSibling() { + for c := n.FirstChild(); c != nil; { + next := c.NextSibling() if st, err := walkHelper(c, walker); err != nil || st == WalkStop { return WalkStop, err } + c = next } } status, err = walker(n, false) diff --git a/ast/ast_test.go b/ast/ast_test.go index 191fffd6..93c2d73f 100644 --- a/ast/ast_test.go +++ b/ast/ast_test.go @@ -52,6 +52,39 @@ func TestWalk(t *testing.T) { } } +func TestWalkAndReplace(t *testing.T) { + doc := node(NewDocument(), node(NewHeading(1), NewLink(), NewLink())) + want := []NodeKind{KindDocument, KindHeading, KindHeading, KindHeading} + var got []NodeKind + walkerReplace := func(n Node, entering bool) (WalkStatus, error) { + // We replace any link by an heading + if entering { + n, ok := n.(*Link) + if !ok { + return WalkContinue, nil + } + parent := n.Parent() + parent.ReplaceChild(parent, n, NewHeading(2)) + } + return WalkContinue, nil + } + walkerCollect := func(n Node, entering bool) (WalkStatus, error) { + if entering { + got = append(got, n.Kind()) + } + return WalkContinue, nil + } + if err := Walk(doc, walkerReplace); err != nil { + t.Fatalf("Walk() error = %v", err) + } + if err := Walk(doc, walkerCollect); err != nil { + t.Fatalf("Walk() error = %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Errorf("Walk() expected = %v, got = %v", want, got) + } +} + func node(n Node, children ...Node) Node { for _, c := range children { n.AppendChild(n, c)