Skip to main content
The LuaSyntaxRewriter is a powerful tool for transforming syntax trees. It allows you to replace or remove nodes while automatically handling the immutable tree structure.

Understanding Immutability

Syntax trees in Loretta are immutable: you cannot modify nodes in place. Instead, you create modified copies:
// This doesn't work - nodes are immutable
// (SomeProperty stands for any property of a syntax node)
node.SomeProperty = newValue; // Compile error

// Instead, use With* methods, which return a new node with the change applied
var newNode = node.WithSomeProperty(newValue);

// Original node is unchanged
var original = node.SomeProperty;   // still the old value
var updated = newNode.SomeProperty; // the new value

How LuaSyntaxRewriter Works

LuaSyntaxRewriter descends through the tree and gives you the chance to replace each node:
using Loretta.CodeAnalysis.Lua;

public class MyRewriter : LuaSyntaxRewriter
{
    /// <summary>
    /// Replaces every identifier named <c>old</c> with an identifier named <c>new</c>.
    /// </summary>
    /// <param name="node">The identifier being visited.</param>
    /// <returns>
    /// A new identifier carrying the original's trivia, or the result of the default
    /// visit for any other identifier.
    /// </returns>
    public override SyntaxNode? VisitIdentifierName(
        IdentifierNameSyntax node)
    {
        // Return a modified node, or the original
        if (node.Name == "old")
        {
            // Copy the replaced node's trivia; a bare SyntaxFactory node has none, so any
            // whitespace, comment or line break attached to "old" would otherwise be lost.
            return SyntaxFactory.IdentifierName("new").WithTriviaFrom(node);
        }
        return base.VisitIdentifierName(node);
    }
}

// Usage (root: a SyntaxNode, e.g. LuaSyntaxTree.ParseText(code).GetRoot())
var rewriter = new MyRewriter();
var newRoot = rewriter.Visit(root);
The rewriter automatically:
  • Updates all parent nodes
  • Maintains tree structure
  • Handles node removal (return null from a Visit method to drop a node from the list that contains it)

Creating Modified Trees

Rewriters create new tree versions by replacing nodes:
using Loretta.CodeAnalysis;
using Loretta.CodeAnalysis.Lua;
using Loretta.CodeAnalysis.Lua.Syntax;

public class ConstantFolder : LuaSyntaxRewriter
{
    /// <summary>
    /// Folds an arithmetic binary expression whose operands are both numeric literals into a
    /// single numeric literal. Children are folded first, so <c>1 + 2 + 3</c> becomes <c>6</c>.
    /// </summary>
    /// <param name="node">The binary expression being visited.</param>
    /// <returns>
    /// A numeric literal carrying the original expression's trivia when folding succeeds;
    /// otherwise the expression with its rewritten children (or null if the base visit removed it).
    /// </returns>
    public override SyntaxNode? VisitBinaryExpression(
        BinaryExpressionSyntax node)
    {
        // First, rewrite children
        var rewritten = (BinaryExpressionSyntax?)base.VisitBinaryExpression(node);
        if (rewritten == null) return null;

        // Try to fold constants
        if (rewritten.Left is LiteralExpressionSyntax left &&
            rewritten.Right is LiteralExpressionSyntax right &&
            left.Token.Value is double leftVal &&
            right.Token.Value is double rightVal)
        {
            double? result = rewritten.Kind() switch
            {
                SyntaxKind.AddExpression => leftVal + rightVal,
                SyntaxKind.SubtractExpression => leftVal - rightVal,
                SyntaxKind.MultiplyExpression => leftVal * rightVal,
                // Division by zero is left for Lua to evaluate at runtime.
                SyntaxKind.DivideExpression => rightVal != 0 ? leftVal / rightVal : null,
                _ => null
            };

            // Only fold values an ordinary numeric literal can represent: overflow (e.g.
            // 1e308 * 10) yields infinity, and infinite operands can yield NaN.
            // NOTE(review): negative results (e.g. 2 - 3) become a literal with a leading
            // minus; confirm SyntaxFactory.Literal formats such values as intended.
            if (result is double value && !double.IsInfinity(value) && !double.IsNaN(value))
            {
                return SyntaxFactory.LiteralExpression(
                        SyntaxKind.NumericLiteralExpression,
                        SyntaxFactory.Literal(value)
                    )
                    // Keep the comments/whitespace/line breaks that surrounded the folded
                    // expression; a fresh literal has none.
                    .WithTriviaFrom(rewritten);
            }
        }

        return rewritten;
    }
}

// Usage: Transform "2 + 3" into "5"
var folder = new ConstantFolder();
var optimized = folder.Visit(root);

Overriding Visit Methods

Override specific Visit methods for the node types you want to transform:
public class FunctionInliner : LuaSyntaxRewriter
{
    /// <summary>
    /// Rewrites single-argument calls <c>tostring(x)</c> into <c>string.format("%s", x)</c>.
    /// </summary>
    /// <param name="node">The call being visited.</param>
    /// <returns>
    /// The replacement call (carrying the original's trivia), the call with rewritten children
    /// when it does not match, or null if the base visit removed it.
    /// </returns>
    /// <remarks>
    /// NOTE: the two forms are only equivalent on Lua 5.2+, where <c>%s</c> converts its
    /// argument with tostring; Lua 5.1's string.format rejects non-string/number values.
    /// </remarks>
    public override SyntaxNode? VisitFunctionCallExpression(
        FunctionCallExpressionSyntax node)
    {
        // Rewrite children first
        var rewritten = (FunctionCallExpressionSyntax?)base.VisitFunctionCallExpression(node);
        if (rewritten == null) return null;

        // Only the parenthesized form with exactly one argument, tostring(x), is handled;
        // string/table call arguments use other FunctionArgumentSyntax kinds.
        if (rewritten.Expression is IdentifierNameSyntax id &&
            id.Name == "tostring" &&
            rewritten.Argument is ExpressionListFunctionArgumentSyntax arguments &&
            arguments.Expressions.Count == 1)
        {
            // string.format  (the member name is an identifier token, not a node)
            var stringFormat = SyntaxFactory.MemberAccessExpression(
                SyntaxFactory.IdentifierName("string"),
                SyntaxFactory.Token(SyntaxKind.DotToken),
                SyntaxFactory.Identifier("format")
            );

            // ("%s", x)
            var formatArguments = SyntaxFactory.ExpressionListFunctionArgument(
                SyntaxFactory.SeparatedList<ExpressionSyntax>(new ExpressionSyntax[]
                {
                    SyntaxFactory.LiteralExpression(
                        SyntaxKind.StringLiteralExpression,
                        SyntaxFactory.Literal("%s")
                    ),
                    arguments.Expressions[0],
                })
            );

            return SyntaxFactory
                .FunctionCallExpression(stringFormat, formatArguments)
                .WithTriviaFrom(rewritten);
        }

        return rewritten;
    }
}
Call base.VisitXxx(node) first so that child nodes are rewritten before you process the parent. If you build and return a replacement without calling base, the original node's children are never visited.

Preserving Trivia with WithTriviaFrom

When replacing nodes, preserve comments and formatting using WithTriviaFrom:
public class IdentifierRenamer : LuaSyntaxRewriter
{
    private readonly string _oldName;
    private readonly string _newName;

    /// <summary>Creates a rewriter that renames identifiers spelled <paramref name="oldName"/>.</summary>
    /// <param name="oldName">Identifier text to look for.</param>
    /// <param name="newName">Identifier text to substitute.</param>
    public IdentifierRenamer(string oldName, string newName)
    {
        (_oldName, _newName) = (oldName, newName);
    }

    /// <summary>
    /// Swaps a matching identifier for a freshly built one that keeps the original's leading
    /// and trailing trivia; every other identifier takes the default visit.
    /// </summary>
    public override SyntaxNode? VisitIdentifierName(
        IdentifierNameSyntax node)
    {
        if (node.Name != _oldName)
        {
            return base.VisitIdentifierName(node);
        }

        // Same trivia as the node being replaced, so comments and spacing survive the rename.
        return SyntaxFactory.IdentifierName(_newName).WithTriviaFrom(node);
    }
}

// Usage: Rename 'x' to 'y'
var renamed = new IdentifierRenamer("x", "y").Visit(root);
WithTriviaFrom copies both leading and trailing trivia, preserving:
  • Comments
  • Whitespace
  • Line breaks

Example: Variable Renamer

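Here’s a complete rewriter that renames all occurrences of a variable. It is purely syntactic: it does not track scopes, so any other variable with the same name (including shadowing locals) is renamed as well: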
using Loretta.CodeAnalysis;
using Loretta.CodeAnalysis.Lua;
using Loretta.CodeAnalysis.Lua.Syntax;

/// <summary>
/// Renames a variable by rewriting every identifier, local declaration name and named
/// parameter whose text equals the old name. Matching is textual only: scopes are not
/// tracked, so unrelated variables sharing the name are renamed too.
/// </summary>
public class VariableRenamer : LuaSyntaxRewriter
{
    private readonly string _oldName;
    private readonly string _newName;

    /// <param name="oldName">Name to replace.</param>
    /// <param name="newName">Replacement name.</param>
    public VariableRenamer(string oldName, string newName)
    {
        (_oldName, _newName) = (oldName, newName);
    }

    /// <summary>One-shot helper: renames <paramref name="oldName"/> to <paramref name="newName"/> under <paramref name="node"/>.</summary>
    public static SyntaxNode Rename(
        SyntaxNode node,
        string oldName,
        string newName) =>
        new VariableRenamer(oldName, newName).Visit(node)!;

    /// <summary>Renames identifier references, keeping their trivia.</summary>
    public override SyntaxNode? VisitIdentifierName(
        IdentifierNameSyntax node) =>
        node.Name == _oldName
            ? SyntaxFactory.IdentifierName(_newName).WithTriviaFrom(node)
            : base.VisitIdentifierName(node);

    /// <summary>Renames the name introduced by a <c>local</c> declaration, keeping its trivia.</summary>
    public override SyntaxNode? VisitLocalDeclarationName(
        LocalDeclarationNameSyntax node)
    {
        if (node.Name.Name != _oldName)
        {
            return base.VisitLocalDeclarationName(node);
        }

        var renamedIdentifier = SyntaxFactory.IdentifierName(_newName).WithTriviaFrom(node.Name);
        return node.WithName(renamedIdentifier);
    }

    /// <summary>Renames a function parameter token, keeping its trivia.</summary>
    public override SyntaxNode? VisitNamedParameter(
        NamedParameterSyntax node)
    {
        if (node.Name.Text != _oldName)
        {
            return base.VisitNamedParameter(node);
        }

        var renamedToken = SyntaxFactory.Identifier(_newName).WithTriviaFrom(node.Name);
        return node.WithName(renamedToken);
    }
}

// Usage
var code = @"
local x = 10
print(x)
x = x + 1
";

var root = LuaSyntaxTree.ParseText(code).GetRoot();

var newRoot = VariableRenamer.Rename(root, "x", "myValue");

Console.WriteLine(newRoot.ToFullString());
// Output (starts with a blank line, because the verbatim string begins with a line break):
//
// local myValue = 10
// print(myValue)
// myValue = myValue + 1

Example: Dead Code Remover

Remove unreachable statements that follow a return statement in the same block:
public class DeadCodeRemover : LuaSyntaxRewriter
{
    /// <summary>
    /// Drops every statement that follows the first <c>return</c> statement of a statement list.
    /// </summary>
    /// <param name="node">The statement list being visited.</param>
    /// <returns>
    /// The rewritten list truncated after its first return statement. When there is no return,
    /// or nothing follows it, the rewritten list is returned as-is instead of being rebuilt.
    /// Returns null if the base visit removed the list.
    /// </returns>
    public override SyntaxNode? VisitStatementList(
        StatementListSyntax node)
    {
        // Rewrite children first, so nested statement lists are processed before this one.
        var rewritten = (StatementListSyntax?)base.VisitStatementList(node);
        if (rewritten == null) return null;

        var statements = rewritten.Statements;

        // Locate the first return; everything after it can never execute.
        var returnIndex = -1;
        for (var i = 0; i < statements.Count; i++)
        {
            if (statements[i] is ReturnStatementSyntax)
            {
                returnIndex = i;
                break;
            }
        }

        // Nothing to remove: keep the existing node so unchanged subtrees are not rebuilt.
        if (returnIndex < 0 || returnIndex == statements.Count - 1)
        {
            return rewritten;
        }

        var reachable = new List<StatementSyntax>(returnIndex + 1);
        for (var i = 0; i <= returnIndex; i++)
        {
            reachable.Add(statements[i]);
        }

        // Trivia attached to the removed statements (including comments) is discarded with them.
        return rewritten.WithStatements(
            SyntaxFactory.List(reachable)
        );
    }
}

Example: Adding Logging

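Insert a logging call at the start of each `function name(...)` declaration (`local function` declarations and anonymous functions would need their own overrides):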
public class LoggingInjector : LuaSyntaxRewriter
{
    /// <summary>
    /// Inserts <c>print("Entering: &lt;function name&gt;")</c> as the first statement of the
    /// body of every function declaration statement.
    /// </summary>
    /// <param name="node">The function declaration being visited.</param>
    /// <returns>
    /// The declaration (with already-rewritten children) whose body starts with the logging
    /// call, or null if the base visit removed the declaration.
    /// </returns>
    public override SyntaxNode? VisitFunctionDeclarationStatement(
        FunctionDeclarationStatementSyntax node)
    {
        var rewritten = (FunctionDeclarationStatementSyntax?)
            base.VisitFunctionDeclarationStatement(node);
        if (rewritten == null) return null;

        // ToString() gives the name's source text without trivia, so every name form
        // ("foo", "M.foo", "obj:method") works without assuming one FunctionNameSyntax shape.
        var functionName = rewritten.Name.ToString();

        // Create logging call: print("Entering: functionName")
        var logCall = SyntaxFactory.ExpressionStatement(
            SyntaxFactory.FunctionCallExpression(
                SyntaxFactory.IdentifierName("print"),
                SyntaxFactory.ExpressionListFunctionArgument(
                    SyntaxFactory.SeparatedList<ExpressionSyntax>(new[] {
                        SyntaxFactory.LiteralExpression(
                            SyntaxKind.StringLiteralExpression,
                            SyntaxFactory.Literal(
                                $"Entering: {functionName}"
                            )
                        )
                    })
                )
            )
        )
        .NormalizeWhitespace()
        // End the inserted statement with a line break; otherwise the body's original first
        // statement is printed on the same line as the call.
        .WithTrailingTrivia(SyntaxFactory.EndOfLine(Environment.NewLine));

        // Insert at the beginning of function body
        var newStatements = new List<StatementSyntax> { logCall };
        newStatements.AddRange(rewritten.Body.Statements);

        return rewritten.WithBody(
            rewritten.Body.WithStatements(
                SyntaxFactory.List(newStatements)
            )
        );
    }
}

Example: Converting Globals to Locals

From the tutorial: hoist globally called functions into a single `local` declaration at the top of the file, and rewrite each call to use the local:
using System.Collections.Immutable;
using Loretta.CodeAnalysis;
using Loretta.CodeAnalysis.Lua;
using Loretta.CodeAnalysis.Lua.Syntax;

public class GlobalToLocalRewriter : LuaSyntaxRewriter
{
    // One shared identifier node per local name, used by the declaration and every call.
    private readonly ImmutableDictionary<string, IdentifierNameSyntax> _localNames;

    // Materialized once: callers may pass a deferred query (e.g. GroupBy), which would
    // otherwise be re-evaluated on every enumeration below.
    private readonly IReadOnlyList<IGrouping<string, FunctionCallExpressionSyntax>> _functionCalls;

    // Reverse index (call node -> local name) so each visited call is an O(1) lookup.
    private readonly Dictionary<FunctionCallExpressionSyntax, string> _localNameByCall;

    /// <summary>
    /// Declares one local per group at the start of the file and makes every listed call
    /// go through that local.
    /// </summary>
    /// <param name="functionCalls">Calls to rewrite, grouped by the local name to introduce.</param>
    /// <param name="node">
    /// Root to rewrite; expected to be the compilation unit, since the declaration is
    /// inserted by <see cref="VisitCompilationUnit"/>.
    /// </param>
    /// <returns>The rewritten root.</returns>
    public static SyntaxNode Rewrite(
        IEnumerable<IGrouping<string, FunctionCallExpressionSyntax>> functionCalls,
        SyntaxNode node)
    {
        var rewriter = new GlobalToLocalRewriter(functionCalls);
        return rewriter.Visit(node)!;
    }

    private GlobalToLocalRewriter(
        IEnumerable<IGrouping<string, FunctionCallExpressionSyntax>> functionCalls)
    {
        _functionCalls = functionCalls.ToList();

        // Create deduplicated identifier name nodes
        _localNames = _functionCalls.ToImmutableDictionary(
            g => g.Key,
            g => SyntaxFactory.IdentifierName(g.Key)
        );

        _localNameByCall = new Dictionary<FunctionCallExpressionSyntax, string>();
        foreach (var group in _functionCalls)
        {
            foreach (var call in group)
            {
                // First group wins, matching a group-by-group search.
                _localNameByCall.TryAdd(call, group.Key);
            }
        }
    }

    /// <summary>
    /// Rewrites the top-level statements and prepends the <c>local name1, name2 = f1, f2</c>
    /// declaration.
    /// </summary>
    public override SyntaxNode? VisitCompilationUnit(
        CompilationUnitSyntax node)
    {
        // Nothing to hoist: an empty "local" declaration would be invalid Lua.
        if (_functionCalls.Count == 0)
        {
            return base.VisitCompilationUnit(node);
        }

        var statements = VisitList(node.Statements.Statements);

        // Create the list of names and values
        var names = _functionCalls.Select(g =>
            SyntaxFactory.LocalDeclarationName(_localNames[g.Key])
        );
        // Strip the first call's trivia (e.g. a comment or line breaks before it) so it is not
        // copied into the declaration; NormalizeWhitespace below supplies the spacing.
        var values = _functionCalls.Select(g => g.First().Expression.WithoutTrivia());

        // Create the local variable declaration
        var localDeclaration = SyntaxFactory
            .LocalVariableDeclarationStatement(
                SyntaxFactory.SeparatedList(names),
                SyntaxFactory.SeparatedList<ExpressionSyntax>(values)
            )
            .NormalizeWhitespace()
            .WithTrailingTrivia(
                SyntaxFactory.EndOfLine(Environment.NewLine)
            );

        // Insert at the start
        statements = statements.Insert(0, localDeclaration);

        var statementList = node.Statements.WithStatements(statements);
        return node.WithStatements(statementList);
    }

    /// <summary>
    /// Replaces the callee of every listed call with its local name; other calls take the
    /// default visit.
    /// </summary>
    public override SyntaxNode? VisitFunctionCallExpression(
        FunctionCallExpressionSyntax node)
    {
        if (!_localNameByCall.TryGetValue(node, out var localName))
        {
            return base.VisitFunctionCallExpression(node);
        }

        // Shared identifier node, with the trivia of the expression it replaces
        var nameNode = _localNames[localName].WithTriviaFrom(node.Expression);

        // Update the function argument(s) so nested calls are rewritten too
        var argument = (FunctionArgumentSyntax)Visit(node.Argument)!;

        // Return the function call with updated expression
        return node.Update(nameNode, argument);
    }
}

Removing Nodes

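To remove a node, return null from its Visit method. Comments are trivia rather than nodes, so they are removed by rebuilding each token's trivia lists without them: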
public class CommentRemover : LuaSyntaxRewriter
{
    /// <summary>
    /// Creates the rewriter and asks the base rewriter to descend into structured trivia too.
    /// </summary>
    public CommentRemover() : base(visitIntoStructuredTrivia: true)
    {
    }

    /// <summary>
    /// Removes single-line and multi-line comment trivia from both sides of every token.
    /// </summary>
    /// <param name="token">The token being visited.</param>
    /// <returns>The token with its comment trivia filtered out.</returns>
    public override SyntaxToken VisitToken(SyntaxToken token)
    {
        // Let the base rewriter visit the token's trivia first. Skipping it makes the
        // visitIntoStructuredTrivia flag passed in the constructor ineffective.
        var visited = base.VisitToken(token);

        return visited
            .WithLeadingTrivia(visited.LeadingTrivia.Where(t => !IsComment(t)))
            .WithTrailingTrivia(visited.TrailingTrivia.Where(t => !IsComment(t)));
    }

    // True for the two comment trivia kinds this rewriter strips.
    private static bool IsComment(SyntaxTrivia trivia) =>
        trivia.IsKind(SyntaxKind.SingleLineCommentTrivia) ||
        trivia.IsKind(SyntaxKind.MultiLineCommentTrivia);
}

Working with Lists

Use the VisitList helper to rewrite syntax lists and separated syntax lists:
public class ParameterRewriter : LuaSyntaxRewriter
{
    /// <summary>
    /// Passes each parameter through <c>VisitList</c> and rebuilds the parameter list only
    /// when at least one parameter changed.
    /// </summary>
    /// <remarks>
    /// Only <c>node.Parameters</c> is visited here; any other children of the parameter list
    /// (presumably its parenthesis tokens) are not handed to the visitor by this override.
    /// </remarks>
    public override SyntaxNode? VisitParameterList(
        ParameterListSyntax node)
    {
        var visitedParameters = VisitList(node.Parameters);

        // Reuse the original node when nothing changed.
        return visitedParameters == node.Parameters
            ? node
            : node.WithParameters(visitedParameters);
    }
}

Visiting Structured Trivia

To rewrite the syntax inside structured trivia, pass visitIntoStructuredTrivia: true to the base constructor:
public class TriviaRewriter : LuaSyntaxRewriter
{
    /// <summary>Creates a rewriter that descends into structured trivia.</summary>
    public TriviaRewriter() : base(visitIntoStructuredTrivia: true)
    {
    }

    /// <summary>
    /// Visits the syntax carried by structured trivia and rebuilds the trivia when that
    /// syntax changes.
    /// </summary>
    /// <param name="trivia">The trivia being visited.</param>
    /// <returns>
    /// Rebuilt trivia when the structure was replaced, a default (empty) trivia when the
    /// structure was removed, the original trivia when it was left unchanged, and the base
    /// visit's result for unstructured trivia.
    /// </returns>
    /// <remarks>
    /// NOTE(review): with visitIntoStructuredTrivia: true the base VisitTrivia presumably does
    /// the same (Roslyn's does); this override mainly serves as an extension point.
    /// </remarks>
    public override SyntaxTrivia VisitTrivia(SyntaxTrivia trivia)
    {
        if (trivia.HasStructure)
        {
            // Process structured trivia
            var structure = trivia.GetStructure();
            var rewritten = Visit(structure);

            if (rewritten == structure)
            {
                // Unchanged: avoid letting the base visit the same structure a second time.
                return trivia;
            }

            if (rewritten is null)
            {
                // The structure was removed; SyntaxFactory.Trivia cannot wrap null, so drop
                // the trivia instead.
                return default;
            }

            return SyntaxFactory.Trivia((StructuredTriviaSyntax)rewritten);
        }
        return base.VisitTrivia(trivia);
    }
}

Best Practices

Always Call Base

public override SyntaxNode? VisitBinaryExpression(
    BinaryExpressionSyntax node)
{
    // ALWAYS let the base rewriter process the children first...
    var withRewrittenChildren = (BinaryExpressionSyntax?)base.VisitBinaryExpression(node);

    // ...a null result means the base visit removed the node...
    if (withRewrittenChildren is null)
    {
        return null;
    }

    // ...and only then transform this node.
    return TransformNode(withRewrittenChildren);
}

Preserve Trivia

// Good - preserves formatting
// (the new identifier copies oldNode's leading and trailing trivia: comments,
// whitespace, line breaks)
var newNode = SyntaxFactory.IdentifierName("new")
    .WithTriviaFrom(oldNode);

// Bad - loses formatting
// (the new identifier has no trivia, so whatever surrounded oldNode is dropped;
// the two snippets are alternatives and redeclare newNode, so they do not compile together)
var newNode = SyntaxFactory.IdentifierName("new");

Use NormalizeWhitespace

// Give newly created nodes normalized whitespace before inserting them into a tree
var statement = SyntaxFactory.LocalVariableDeclarationStatement(names, values)
                             .NormalizeWhitespace();

Next Steps

Build docs developers (and LLMs) love