ModelGraph class is the central data structure in hls4ml that represents the entire neural network model. It manages the graph of layers, handles transformations, and coordinates the compilation and synthesis process.
ModelGraph Overview
The ModelGraph class is defined in hls4ml/model/graph.py and provides:
- Graph management - Adding, removing, and replacing nodes
- Configuration - Model-wide settings and precision configuration
- Flow application - Running optimization passes
- Code generation - Writing and compiling HLS code
- Simulation - Running predictions on compiled models
Class Definition
class ModelGraph(Serializable):
    """The ModelGraph represents the network that is being processed by hls4ml.

    Args:
        config (dict): The configuration dictionary
        inputs (list, optional): The inputs to the model
        outputs (list, optional): The outputs to the model
        initial_index (int): Starting index for layer numbering
    """

    def __init__(self, config, inputs=None, outputs=None, initial_index=0):
        self.config = config
        # Names of the model-level input and output layers.
        self.inputs = inputs
        self.outputs = outputs
        self.graph = OrderedDict()  # Ordered dict of layer name -> Layer
        self._applied_flows = []  # Track applied optimization flows
        # Counter used when numbering newly created layers.
        self.index = initial_index
        self.output_vars = {}  # Map of output names -> variables
        # Handle to the compiled shared library; presumably populated by
        # _compile() — confirm against the rest of the class.
        self._top_function_lib = None
hls4ml/model/graph.py:401-419
Creating a ModelGraph
From Layer List
The most common way to create a model graph is from a list of layer dictionaries:

@classmethod
def from_layer_list(cls, config_dict, layer_list, inputs=None, outputs=None, initial_index=0):
    """Create a ModelGraph from a list of layer configurations.

    Args:
        config_dict (dict): Raw configuration dictionary, parsed into an HLSConfig.
        layer_list (list): Layer configuration dicts, in topological order.
        inputs (list, optional): Names of the input layers. Defaults to the
            first layer in ``layer_list``.
        outputs (list, optional): Names of the output layers. Defaults to the
            last layer in ``layer_list``.
        initial_index (int, optional): Starting index for layer numbering.
            Defaults to 0.

    Returns:
        ModelGraph: The constructed model graph with default flows applied.
    """
    config = HLSConfig(config_dict)

    # Determine inputs/outputs if not provided
    input_layers = inputs if inputs else [layer_list[0]['name']]
    output_layers = outputs if outputs else [layer_list[-1]['name']]

    # BUG FIX: 'initial_index' was referenced here but never defined
    # (NameError at call time); it is now an explicit keyword argument
    # with a backward-compatible default of 0.
    model = cls(config, input_layers, output_layers, initial_index)
    model._make_graph(layer_list)

    # Apply default optimization flows
    for flow in model.config.flows:
        model.apply_flow(flow)

    return model
hls4ml/model/graph.py:422-460
Building the Graph
def _make_graph(self, layer_list):
"""Construct the graph from layer configurations."""
for layer in layer_list:
kind = layer['class_name']
name = layer['name']
inputs = layer.get('inputs', [])
outputs = layer.get('outputs', [])
# Handle input layers
if kind in ['InputLayer', 'Input']:
inputs = ['input']
# Default to previous layer output
elif len(inputs) == 0:
inputs = [next(reversed(self.graph), 'input')]
if len(outputs) == 0:
outputs = [name]
# Create and add node to graph
self.graph[name] = self.make_node(kind, name, layer, inputs, outputs)
hls4ml/model/graph.py:470-483
Graph Operations
Creating Nodes
def make_node(self, kind, name, attributes, inputs, outputs=None, initialize=True):
    """Make a new node not connected to the model graph.

    Args:
        kind (type or str): Type of node to add
        name (str): Name of the node
        attributes (dict): Initial attributes
        inputs (list): List of inputs
        outputs (list, optional): Named outputs
        initialize (bool, optional): Call initialize() method

    Returns:
        Layer: The created node
    """
    # Resolve the layer class: look it up by name, or verify the given
    # class is actually registered.
    if isinstance(kind, str):
        if kind not in layer_map:
            raise Exception(f'Layer {kind} not found in registry.')
        layer_cls = layer_map[kind]
    else:
        if kind not in layer_map.values():
            raise Exception(f'Layer {kind} not found in registry.')
        layer_cls = kind

    # Give the backend a chance to substitute its own specialization.
    backend = self.config.backend
    if backend is not None:
        layer_cls = backend.create_layer_class(layer_cls)

    node = layer_cls(self, name, attributes, inputs, outputs, initialize)

    # Register every named output; when the model has exactly one output
    # and this node produces it, force the conventional 'result_t' name.
    for out_name in node.outputs:
        out_var = node.get_output_variable(output_name=out_name)
        if len(self.outputs) == 1 and out_name in self.outputs:
            out_var.type.name = 'result_t'
        self.output_vars[out_name] = out_var

    return node
hls4ml/model/graph.py:536-578
Inserting Nodes
def insert_node(self, node, before=None, input_idx=0):
    """Insert a new node into the model graph.

    The node must have exactly one input. It is spliced in immediately
    after its input node; the chosen successor (or the model outputs,
    when inserting at the end) is rewired to read from the new node.

    Args:
        node (Layer): Node to insert (created with make_node)
        before (Layer, optional): The next node in sequence
        input_idx (int, optional): Input index for multi-input nodes
    """
    if len(node.inputs) > 1:
        raise Exception('Cannot insert a node with more than one input.')

    prev_node = node.get_input_node(node.inputs[0])

    # All nodes currently consuming any of prev_node's outputs.
    consumers = [
        candidate
        for candidate in self.graph.values()
        if any(inp in prev_node.outputs for inp in candidate.inputs)
    ]

    # Pick the successor to rewire: the explicit 'before' node, or the
    # first node whose primary input comes from prev_node.
    if before is None:
        next_node = next(
            (c for c in self.graph.values() if c.inputs and c.inputs[0] in prev_node.outputs),
            None,
        )
    else:
        if before not in consumers:
            raise Exception(f'Cannot insert {node.name} before {before.name}')
        next_node = before

    if next_node is not None:
        # Redirect the successor's chosen input slot to the new node.
        next_node.inputs[input_idx] = node.outputs[0]
    else:
        # Inserting at the very end: the new node's output replaces
        # prev_node's output in the model-level output list.
        self.outputs = [
            node.outputs[0] if name == prev_node.outputs[0] else name for name in self.outputs
        ]

    # Rebuild the ordered graph with the new node right after prev_node.
    rebuilt = OrderedDict()
    for key, value in self.graph.items():
        rebuilt[key] = value
        if key == prev_node.name:
            rebuilt[node.name] = node
    self.graph = rebuilt
hls4ml/model/graph.py:580-628
Removing Nodes
def remove_node(self, node):
    """Remove a node from the graph.

    The removed node's input is connected directly to every node that
    consumed its output. Only single-input/single-output nodes can be
    removed, and input/output element counts must match.

    Args:
        node (Layer): The node to remove
    """
    live_inputs = [inp for inp in node.inputs if inp]
    live_outputs = [outp for outp in node.outputs if outp]

    if len(live_inputs) > 1 or len(live_outputs) > 1:
        raise Exception('Cannot delete a node with multiple inputs/outputs')

    if len(live_outputs) == 1 and len(live_inputs) == 1:
        # If the node produced a model-level output, promote its input
        # to take that place.
        if node.outputs[0] in self.outputs:
            self.outputs = [
                live_inputs[0] if name == node.outputs[0] else name for name in self.outputs
            ]

        # Removal must not change the number of elements flowing through.
        inp_var = node.get_input_variable()
        out_var = node.get_output_variable()
        assert np.prod(inp_var.shape) == np.prod(out_var.shape), \
            f'Shape mismatch: {inp_var.shape} -> {out_var.shape}'

        # Rewire every consumer of this node's output to its input.
        for consumer in self.graph.values():
            if node.outputs[0] not in consumer.inputs:
                continue
            for idx, consumed in enumerate(consumer.inputs):
                if consumed == live_outputs[0]:
                    consumer.inputs[idx] = live_inputs[0]

    # Drop the node's registered output variable and the node itself.
    # NOTE(review): assumes node.outputs[0] exists and is registered in
    # output_vars — confirm behavior for nodes with no live outputs.
    del self.output_vars[node.outputs[0]]
    del self.graph[node.name]
hls4ml/model/graph.py:630-674
Replacing Nodes
def replace_node(self, old_node, new_node):
    """Replace an existing node with a new one.

    Every reference to the old node's input/output names — in other
    nodes and in the model-level output list — is renamed to the new
    node's names, and the graph entry is swapped in place.

    Args:
        old_node (Layer): The node to replace
        new_node (Layer): The new node
    """
    assert len(new_node.inputs) == len(old_node.inputs), \
        f'Input count mismatch: {new_node.name} vs {old_node.name}'
    assert len(new_node.outputs) == len(old_node.outputs), \
        f'Output count mismatch: {new_node.name} vs {old_node.name}'

    # Map every old tensor name to its replacement (outputs, then inputs).
    rename = dict(zip(old_node.outputs, new_node.outputs))
    rename.update(zip(old_node.inputs, new_node.inputs))

    # Keep the model-level output list in sync with renamed outputs.
    for prior_output in old_node.outputs:
        if prior_output in self.outputs:
            replacement = rename[prior_output]
            self.outputs = [
                replacement if name == prior_output else name for name in self.outputs
            ]

    # Rename references held by every node in the graph (in place).
    for member in self.graph.values():
        member.inputs[:] = [rename.get(n, n) for n in member.inputs]
        member.outputs[:] = [rename.get(n, n) for n in member.outputs]

    # Swap the node while preserving the graph's ordering.
    self.graph = OrderedDict(
        (new_node.name, new_node) if key == old_node.name else (key, value)
        for key, value in self.graph.items()
    )
hls4ml/model/graph.py:676-708
Splitting Nodes
def split_node(self, old_node, new_node1, new_node2):
    """Replace a node with two nodes in sequence.

    ``new_node1`` takes over the old node's inputs and ``new_node2`` its
    outputs; both are inserted at the old node's position in order.

    Args:
        old_node (Layer): The node to replace
        new_node1 (Layer): The first new node
        new_node2 (Layer): The second new node
    """
    # The pair must jointly match the old node's external interface.
    assert len(new_node1.inputs) == len(old_node.inputs)
    assert len(new_node2.outputs) == len(old_node.outputs)

    # Build replacement mapping: old outputs -> new_node2's outputs,
    # old inputs -> new_node1's inputs.
    repl = {old_name: new_name
            for old_name, new_name in zip(old_node.outputs, new_node2.outputs)}
    repl.update({old_name: new_name
                 for old_name, new_name in zip(old_node.inputs, new_node1.inputs)})

    # Update references (similar to replace_node)
    # ...
    # NOTE: the reference-renaming step is elided in this excerpt.

    # Insert both nodes in sequence, preserving the graph's ordering.
    new_graph = OrderedDict()
    for key, value in self.graph.items():
        if key == old_node.name:
            new_graph[new_node1.name] = new_node1
            new_graph[new_node2.name] = new_node2
        else:
            new_graph[key] = value
    self.graph = new_graph
hls4ml/model/graph.py:710-750
Querying the Graph
Accessing Layers
# Get all layers in order
layers = model.get_layers()

# Access a specific layer by name (graph maps layer name -> Layer)
layer = model.graph['layer_name']

# Get the variables feeding the model inputs
input_vars = model.get_input_variables()

# Get the variables produced at the model outputs
output_vars = model.get_output_variables()

# Get all weight variables of the model
weight_vars = model.get_weight_variables()
hls4ml/model/graph.py:752-785
Getting Variables
def get_input_variables(self):
    """Return the output variable of each model input layer."""
    return [self.graph[name].get_output_variable() for name in self.inputs]

def get_output_variables(self):
    """Return the registered variable for each model output."""
    return [self.output_vars[name] for name in self.outputs]

def get_layer_output_variable(self, output_name):
    """Look up a single output variable by name (None when absent)."""
    return self.output_vars.get(output_name, None)
hls4ml/model/graph.py:759-777
Configuration Management
HLSConfig Class
The HLSConfig class manages model-wide configuration:
class HLSConfig(Serializable):
    """Configuration class for the ModelGraph.

    Holds the backend handle plus settings for precision, reuse factor
    and strategy at three levels of specificity: model-wide, per layer
    type, and per layer name.

    Args:
        config (dict): The configuration dictionary
    """

    def __init__(self, config):
        self.config = config
        # Backend defaults to 'Vivado' when not specified in the config dict.
        self.backend = get_backend(self.config.get('Backend', 'Vivado'))

        # Precision configuration at different levels
        # (model-wide / per layer type / per layer name).
        self.model_precision = {}
        self.layer_type_precision = {}
        self.layer_name_precision = {}

        # Reuse factor configuration, same three levels.
        self.model_rf = None
        self.layer_type_rf = {}
        self.layer_name_rf = {}

        # Strategy configuration ('latency' is the model-wide default).
        self.model_strategy = 'latency'
        self.layer_type_strategy = {}
        self.layer_name_strategy = {}

        # Other configurations...
        # Populate the tables above from the raw config dict.
        self._parse_hls_config()
hls4ml/model/graph.py:23-73
Getting Precision
Configuration lookup follows a hierarchy: layer name > layer type > model default:

def get_precision(self, layer, var='default'):
    """Get precision for a layer's variable.

    Lookup order:
    1. Layer name + variable name
    2. Layer name + 'default'
    3. Layer type + variable name
    4. Layer type + 'default'
    5. Model + variable name
    6. Model + 'default'

    Returns:
        tuple: (backend precision object, generated type name)
    """
    # Most specific first: '<layer name>_<var>'.
    precision = self.layer_name_precision.get(layer.name.lower() + '_' + var)
    type_name = layer.name.lower() + '_' + var + '_t'
    if precision is None:
        precision = self.layer_name_precision.get(layer.name.lower() + '_default')
    if precision is None:
        precision = self.layer_type_precision.get(layer.class_name.lower() + '_' + var)
        # NOTE(review): the lookup key is lower-cased but this type name
        # is not — confirm whether the casing difference is intentional.
        type_name = layer.class_name + '_' + var + '_t'
    if precision is None:
        precision = self.layer_type_precision.get(layer.class_name.lower() + '_default')
        type_name = layer.class_name + '_default_t'
    if precision is None:
        precision = self.model_precision.get(var)
        type_name = var + '_default_t'
    if precision is None:
        precision = self.model_precision.get('default')
        type_name = 'model_default_t'
    if precision is None:
        raise Exception(f'No precision for {layer.name}->{var} found')

    # Let the backend parse the precision string into a type object.
    precision = self.backend.convert_precision_string(precision)
    return (precision, type_name)
hls4ml/model/graph.py:127-154
Compilation and Execution
Writing HLS Code
def write(self):
    """Write the generated project to disk.

    Converts the model to C++ and writes files to the output directory.
    Delegates entirely to the configured backend's writer.
    """
    self.config.backend.write(self)
hls4ml/model/graph.py:787-794
Compiling
def compile(self):
    """Compile the generated project and link the library.

    Call this to use predict() functionality for simulation.
    Writes the project to disk first, then compiles it.
    """
    self.write()
    self._compile()

def _compile(self):
    # The backend compiles the project and returns the path of the
    # resulting shared library.
    lib_name = self.config.backend.compile(self)
    # Load the compiled library so predictions can call into it via ctypes.
    self._top_function_lib = ctypes.cdll.LoadLibrary(lib_name)
hls4ml/model/graph.py:796-821
Running Predictions
def predict(self, x, *args, **kwargs):
    """Run prediction on input data.

    A backend-provided predict() implementation takes precedence;
    otherwise the model's built-in simulation path is used.

    Args:
        x: Input data (numpy array or list of arrays)

    Returns:
        Predictions as numpy array(s)
    """
    backend_predict = getattr(self.config.backend, 'predict', None)
    if callable(backend_predict):
        # Backend-specific prediction path.
        return backend_predict(self, x, *args, **kwargs)
    # Default path: simulate via the compiled model.
    return self._predict(x)
hls4ml/model/graph.py:909-915
Building with HLS
def build(self, **kwargs):
    """Build the project using HLS compiler.

    Backend-specific arguments can be passed via kwargs.

    Returns:
        The backend's build result (backend-specific).
    """
    # Write the project first if the output directory does not exist yet.
    # NOTE(review): an existing directory is assumed to hold an up-to-date
    # project — confirm that stale outputs are acceptable here.
    if not os.path.exists(self.config.get_output_dir()):
        self.write()
    return self.config.backend.build(self, **kwargs)
hls4ml/model/graph.py:1001-1010
Serialization
Saving and Loading
def save(self, file_path):
    """Save the ModelGraph to a file.

    Args:
        file_path (str): Path to save the model
    """
    # Imported lazily to avoid a circular dependency at module load time.
    from hls4ml.utils.serialization import serialize_model

    serialize_model(self, file_path)

def serialize(self):
    """Serialize the model graph state into a plain dictionary."""
    # Snapshot the applied flows, converting each optimizer collection
    # to a list so the result is serialization-friendly.
    flow_snapshot = [
        {flow_name: list(opt_set) for flow_name, opt_set in flow_group.items()}
        for flow_group in self._applied_flows
    ]
    return {
        'inputs': self.inputs.copy(),
        'outputs': self.outputs.copy(),
        'index': self.index,
        'applied_flows': flow_snapshot,
    }
hls4ml/model/graph.py:1012-1026
MultiModelGraph
For advanced use cases, hls4ml supports splitting a model into multiple subgraphs:

class MultiModelGraph:
    """Stitched model from pre-optimized subgraphs."""

    def __init__(self, graphs: list[ModelGraph]):
        # Ordered list of ModelGraph pieces making up the full model.
        self.graphs = graphs
        # Initialize configuration, bind methods, etc.
        # (implementation elided in this excerpt)

    @classmethod
    def from_model_graph(cls, base_model: ModelGraph,
                         split_before_layers: list[str]):
        """Split a ModelGraph at specified layers."""
        # Validate split points
        # Create subgraphs
        # Apply flows to each subgraph
        # NOTE: 'subgraphs' is produced by the elided steps above.
        return cls(subgraphs)
hls4ml/model/graph.py:1047-1113
Related Documentation
Intermediate Representation
Learn about Layer classes and attributes
Optimization Flows
Understand optimization passes and flows
