Apache Wayang’s query optimizer selects an execution plan by comparing cost estimates across candidate plans. By default it uses a formula-based cost model, but the EstimatableCost interface lets you plug in any cost function, including one backed by a trained machine learning model. This page explains the interface, shows a complete ML cost implementation using ONNX Runtime, and describes how the encoding and model inference pipeline works.
The ML cost model feature is intended for researchers and advanced users exploring custom optimizer strategies. It requires a pre-trained ONNX model, access to the wayang-ml module, and familiarity with the Wayang optimizer internals. Most production users should rely on the default cost model.
The EstimatableCost interface
org.apache.wayang.core.optimizer.costs.EstimatableCost is the contract the optimizer uses to rank candidate PlanImplementation objects. An implementation must provide all seven methods:
```java
public interface EstimatableCost {

    // Factory so the optimizer can create fresh instances per optimization pass
    EstimatableCostFactory getFactory();

    // Choose the best plan from a set of candidates
    PlanImplementation pickBestExecutionPlan(
            Collection<PlanImplementation> executionPlans,
            ExecutionPlan existingPlan,
            Set<Channel> openChannels,
            Set<ExecutionStage> executedStages);

    // Probabilistic cost interval (lower, upper, confidence)
    ProbabilisticDoubleInterval getEstimate(
            PlanImplementation plan, boolean isOverheadIncluded);

    ProbabilisticDoubleInterval getParallelEstimate(
            PlanImplementation plan, boolean isOverheadIncluded);

    // Squashed (single scalar) cost, used for direct comparisons
    double getSquashedEstimate(
            PlanImplementation plan, boolean isOverheadIncluded);

    double getSquashedParallelEstimate(
            PlanImplementation plan, boolean isOverheadIncluded);

    // Per-operator/junction breakdown for detailed profiling
    Tuple<List<ProbabilisticDoubleInterval>, List<Double>>
            getParallelOperatorJunctionAllCostEstimate(
                    PlanImplementation plan, Operator operator);
}
```
The simplest custom cost model uses a constant or formula. The ML variant described here calls an ONNX model inside each method.
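To make the split between the interval methods and the squashed methods concrete, here is a self-contained sketch of a formula-based cost that does not depend on Wayang. Everything in it is hypothetical: Interval stands in for ProbabilisticDoubleInterval, the plan is reduced to a list of operator input cardinalities, and the per-tuple weight, fixed overhead, ±20% band, and midpoint squash are illustrative choices, not Wayang's own defaults or squashing rule.

```java
import java.util.List;

/** Hypothetical stand-in for ProbabilisticDoubleInterval: (lower, upper, confidence). */
final class Interval {
    final double lower;
    final double upper;
    final double confidence;

    Interval(double lower, double upper, double confidence) {
        this.lower = lower;
        this.upper = upper;
        this.confidence = confidence;
    }
}

/** Hypothetical formula-based cost: weighted sum of operator input cardinalities. */
final class FormulaCost {
    static final double COST_PER_TUPLE = 0.002; // assumed cost units per input tuple
    static final double FIXED_OVERHEAD = 50.0;  // assumed per-job start-up overhead
    static final double SPREAD = 0.2;           // assumed +/-20% uncertainty band
    static final double CONFIDENCE = 0.9;       // assumed confidence of the band

    /** Interval estimate; mirrors the isOverheadIncluded flag of the real interface. */
    static Interval estimate(List<Long> inputCardinalities, boolean isOverheadIncluded) {
        double base = 0;
        for (long card : inputCardinalities) {
            base += card * COST_PER_TUPLE;
        }
        if (isOverheadIncluded) {
            base += FIXED_OVERHEAD;
        }
        return new Interval(base * (1 - SPREAD), base * (1 + SPREAD), CONFIDENCE);
    }

    /** Collapses the interval to one scalar; using the midpoint is an assumption. */
    static double squashedEstimate(List<Long> inputCardinalities, boolean isOverheadIncluded) {
        Interval i = estimate(inputCardinalities, isOverheadIncluded);
        return (i.lower + i.upper) / 2;
    }
}
```

For two operators reading 1,000 and 4,000 tuples, the sketch yields the interval [48, 72] with overhead included and a squashed value of 10 without it.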
Enabling a custom cost model
Instantiate your implementation and pass it to Configuration before constructing WayangContext:
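```java
import org.apache.wayang.core.api.Configuration;
import org.apache.wayang.core.api.WayangContext;
import org.apache.wayang.java.Java;
import org.apache.wayang.spark.Spark;

Configuration config = new Configuration();
// CustomEstimatableCost stands for your own EstimatableCost implementation
config.setCostModel(new CustomEstimatableCost());

WayangContext wayangContext = new WayangContext(config)
        .withPlugin(Java.basicPlugin())
        .withPlugin(Spark.basicPlugin());
```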
The wayang-ml module and MachineLearning plugin
The wayang-plugins/wayang-ml module provides:
- MLContext: a WayangContext subclass that executes plans and optionally logs experience data (encoded plan + actual runtime) for offline model training.
- OrtMLModel: a singleton that loads and runs an ONNX model using the ONNX Runtime Java binding.
- OneHotEncoder: encodes a PlanImplementation into a fixed-length long[] vector.
- TreeEncoder: encodes the full plan tree for the experience-logging pipeline.
Add the plugin to your WayangContext the same way as any other plugin:
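```java
import org.apache.wayang.core.api.WayangContext;
import org.apache.wayang.java.Java;
import org.apache.wayang.ml.MachineLearning;
import org.apache.wayang.spark.Spark;

// `config` is the Configuration from the previous example
WayangContext wayang = new WayangContext(config)
        .withPlugin(Java.basicPlugin())
        .withPlugin(Spark.basicPlugin())
        .withPlugin(MachineLearning.plugin());
```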
ML configuration properties
| Property | Default | Description |
|---|---|---|
| wayang.ml.model.file | /wayang-plugins/wayang-ml/src/main/resources/linear_model.onnx | Absolute path to the pre-trained ONNX model file. |
| wayang.ml.experience.enabled | false | When true, MLContext logs encoded plans and execution times after each job. |
| wayang.ml.experience.file | /var/www/html/data/experience/experience-vae.txt | File path for experience log entries. |
| wayang.ml.optimizations.file | /var/www/html/data/optmizations.txt | File path for per-inference timing data. |
```properties
# wayang.properties
wayang.ml.model.file = /opt/models/wayang_cost_model.onnx
wayang.ml.experience.enabled = true
wayang.ml.experience.file = /var/log/wayang/experience.txt
```
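The following sketch shows how a file like the one above combines with the defaults from the table: keys you set override the defaults, and keys you omit fall back to them. It uses plain java.util.Properties purely for illustration; Wayang itself reads these keys through its Configuration class, and MlPropertiesDemo is a hypothetical name.

```java
import java.io.IOException;
import java.io.StringReader;
import java.io.UncheckedIOException;
import java.util.Properties;

/** Illustration only: overlays user settings on the documented wayang.ml.* defaults. */
final class MlPropertiesDemo {
    static Properties defaults() {
        Properties d = new Properties();
        d.setProperty("wayang.ml.model.file",
                "/wayang-plugins/wayang-ml/src/main/resources/linear_model.onnx");
        d.setProperty("wayang.ml.experience.enabled", "false");
        d.setProperty("wayang.ml.experience.file",
                "/var/www/html/data/experience/experience-vae.txt");
        d.setProperty("wayang.ml.optimizations.file",
                "/var/www/html/data/optmizations.txt");
        return d;
    }

    /** Parses properties text; getProperty() falls back to the defaults for missing keys. */
    static Properties load(String text) {
        Properties props = new Properties(defaults());
        try {
            props.load(new StringReader(text)); // '#' lines are treated as comments
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return props;
    }
}
```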
Complete MLCost implementation
The following example is the full ML-backed EstimatableCost implementation from the upstream source guide. It shows the interface structure using OneHotEncoder.encode(plan), which returns a long[], together with the runModel(long[]) stub overload. Because that stub always returns 0, this class as written assigns a cost of 0 to every plan, so treat it as a template. For a production deployment, replace OneHotEncoder + runModel(long[]) with the TreeEncoder → OrtTensorEncoder → runModel(Tuple) pipeline shown in the OrtMLModel section below.
```java
import org.apache.wayang.core.optimizer.costs.EstimatableCost;
import org.apache.wayang.core.optimizer.costs.EstimatableCostFactory;
import org.apache.wayang.core.optimizer.ProbabilisticDoubleInterval;
import org.apache.wayang.core.optimizer.enumeration.PlanImplementation;
import org.apache.wayang.core.plan.executionplan.ExecutionPlan;
import org.apache.wayang.core.plan.executionplan.ExecutionStage;
import org.apache.wayang.core.platform.Channel;
import org.apache.wayang.core.plan.wayangplan.Operator;
import org.apache.wayang.core.util.Tuple;
import org.apache.wayang.core.api.Configuration;
import org.apache.wayang.core.api.exception.WayangException;
import org.apache.wayang.ml.encoding.OneHotEncoder;
import org.apache.wayang.ml.encoding.OrtMLModel;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Set;

public class MLCost implements EstimatableCost {

    @Override
    public EstimatableCostFactory getFactory() {
        return new Factory();
    }

    /** Lets the optimizer create a fresh MLCost per optimization pass. */
    public static class Factory implements EstimatableCostFactory {
        @Override
        public EstimatableCost makeCost() {
            return new MLCost();
        }
    }

    @Override
    public ProbabilisticDoubleInterval getEstimate(
            PlanImplementation plan, boolean isOverheadIncluded) {
        try {
            Configuration config = plan
                    .getOptimizationContext()
                    .getConfiguration();
            OrtMLModel model = OrtMLModel.getInstance(config);
            // Wrap the model's scalar prediction in an exact (zero-width) interval
            return ProbabilisticDoubleInterval.ofExactly(
                    model.runModel(OneHotEncoder.encode(plan))
            );
        } catch (Exception e) {
            // Any failure (missing model, encoding error) falls back to a zero cost
            return ProbabilisticDoubleInterval.zero;
        }
    }

    @Override
    public ProbabilisticDoubleInterval getParallelEstimate(
            PlanImplementation plan, boolean isOverheadIncluded) {
        try {
            Configuration config = plan
                    .getOptimizationContext()
                    .getConfiguration();
            OrtMLModel model = OrtMLModel.getInstance(config);
            return ProbabilisticDoubleInterval.ofExactly(
                    model.runModel(OneHotEncoder.encode(plan))
            );
        } catch (Exception e) {
            return ProbabilisticDoubleInterval.zero;
        }
    }

    /** Returns a squashed cost estimate. */
    @Override
    public double getSquashedEstimate(
            PlanImplementation plan, boolean isOverheadIncluded) {
        try {
            Configuration config = plan
                    .getOptimizationContext()
                    .getConfiguration();
            OrtMLModel model = OrtMLModel.getInstance(config);
            return model.runModel(OneHotEncoder.encode(plan));
        } catch (Exception e) {
            return 0;
        }
    }

    @Override
    public double getSquashedParallelEstimate(
            PlanImplementation plan, boolean isOverheadIncluded) {
        try {
            Configuration config = plan
                    .getOptimizationContext()
                    .getConfiguration();
            OrtMLModel model = OrtMLModel.getInstance(config);
            return model.runModel(OneHotEncoder.encode(plan));
        } catch (Exception e) {
            return 0;
        }
    }

    /** No real per-operator breakdown: reports the whole-plan estimate as a single entry. */
    @Override
    public Tuple<List<ProbabilisticDoubleInterval>, List<Double>>
            getParallelOperatorJunctionAllCostEstimate(
                    PlanImplementation plan, Operator operator) {
        List<ProbabilisticDoubleInterval> intervalList = new ArrayList<>();
        List<Double> doubleList = new ArrayList<>();
        intervalList.add(this.getEstimate(plan, true));
        doubleList.add(this.getSquashedEstimate(plan, true));
        return new Tuple<>(intervalList, doubleList);
    }

    /** Picks the candidate with the lowest squashed cost estimate. */
    @Override
    public PlanImplementation pickBestExecutionPlan(
            Collection<PlanImplementation> executionPlans,
            ExecutionPlan existingPlan,
            Set<Channel> openChannels,
            Set<ExecutionStage> executedStages) {
        return executionPlans.stream()
                .reduce((p1, p2) -> {
                    final double t1 = p1.getSquashedCostEstimate();
                    final double t2 = p2.getSquashedCostEstimate();
                    return t1 < t2 ? p1 : p2;
                })
                .orElseThrow(() ->
                        new WayangException("Could not find an execution plan."));
    }
}
```
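The selection rule in pickBestExecutionPlan above is a minimum-by-reduce. The sketch below isolates that rule with the Wayang types replaced by a generic T (PlanPicker is a hypothetical name) so its tie behavior can be checked in isolation: because the comparison is a strict `<`, the later candidate wins whenever two squashed costs are equal.

```java
import java.util.Collection;
import java.util.function.ToDoubleFunction;

/** Same reduce rule as pickBestExecutionPlan, generalized over the plan type. */
final class PlanPicker {
    static <T> T pickCheapest(Collection<T> plans, ToDoubleFunction<T> squashedCost) {
        return plans.stream()
                .reduce((p1, p2) -> {
                    double t1 = squashedCost.applyAsDouble(p1);
                    double t2 = squashedCost.applyAsDouble(p2);
                    // Strict '<': on a tie, the later candidate (p2) is kept
                    return t1 < t2 ? p1 : p2;
                })
                .orElseThrow(() -> new IllegalStateException("Could not find an execution plan."));
    }
}
```

This matters for the template above: whenever all squashed costs tie (for example, when every estimate is 0), the optimizer simply receives the last enumerated candidate. Likewise, because the estimate methods map any exception to 0, a plan whose inference failed looks free rather than expensive; a production cost model may prefer a large sentinel cost instead.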
OrtMLModel: ONNX runtime inference
OrtMLModel is a singleton that manages an ONNX Runtime OrtSession. It loads the model from the path specified by wayang.ml.model.file on the first call to getInstance() and reuses the session for all subsequent calls. The session is configured with 16 inter-op threads, 16 intra-op threads, and deterministic compute enabled.
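A minimal sketch of the lazy-singleton pattern described above, with no ONNX dependency: FakeSession stands in for the native OrtSession, and LazyModel, loadCount(), and the path argument are hypothetical. It shows that the model is loaded exactly once and that later getInstance() calls reuse it.

```java
/** Hypothetical stand-in for the native ONNX Runtime session. */
final class FakeSession {
    final String modelPath;

    FakeSession(String modelPath) {
        this.modelPath = modelPath;
    }
}

/** Process-wide, lazily initialized model holder in the style described for OrtMLModel. */
final class LazyModel {
    private static LazyModel instance;
    private static int loads = 0; // how many times the "model file" was loaded

    private final FakeSession session;

    private LazyModel(String modelPath) {
        this.session = new FakeSession(modelPath); // expensive load happens here, once
        loads++;
    }

    /** The first call loads the model; later calls return the same instance. */
    static synchronized LazyModel getInstance(String modelPath) {
        if (instance == null) {
            instance = new LazyModel(modelPath);
        }
        return instance;
    }

    static synchronized int loadCount() {
        return loads;
    }

    String modelPath() {
        return session.modelPath;
    }
}
```

One consequence of this pattern, at least in the sketch: a second call with a different path still returns the first model, so changing the model path after the first getInstance() has no effect within the same process. Check whether the real OrtMLModel behaves the same before relying on reloading.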
OrtMLModel exposes two runModel overloads:
- runModel(long[] encoded): a stub overload used in simple examples; it returns 0 and does not run ONNX inference.
- runModel(Tuple<ArrayList<long[][]>, ArrayList<long[][]>> input): the production overload. It takes a tree-structured input (values and index arrays) produced by OrtTensorEncoder.encode(TreeNode) and returns the model’s double cost prediction.
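To illustrate what "values and index arrays" can mean for a tree, here is a sketch of one common layout used by tree-convolution models: row 0 of the values matrix is an all-zero padding row, rows 1..n hold node features in preorder, and each node contributes a {node, left, right} triple of 1-based preorder positions, with 0 meaning "no child". Whether OrtTensorEncoder uses exactly this layout is an assumption; PlanNode and TreeFlattener are hypothetical names.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical binary plan node carrying a feature vector (e.g. a one-hot operator code). */
final class PlanNode {
    final long[] features;
    final PlanNode left;
    final PlanNode right;

    PlanNode(long[] features, PlanNode left, PlanNode right) {
        this.features = features;
        this.left = left;
        this.right = right;
    }
}

/** Flattens a plan tree into a values matrix plus preorder index triples. */
final class TreeFlattener {
    final int width;
    final List<long[]> values = new ArrayList<>();
    final List<long[]> indexes = new ArrayList<>();

    TreeFlattener(PlanNode root, int featureWidth) {
        this.width = featureWidth;
        values.add(new long[featureWidth]); // row 0: zero padding for missing children
        visit(root);
    }

    /** Assigns n its 1-based preorder position and records its child positions. */
    private int visit(PlanNode n) {
        if (n.features.length != width) {
            throw new IllegalArgumentException("feature width mismatch");
        }
        int id = values.size();
        values.add(n.features);
        long[] triple = {id, 0, 0};
        indexes.add(triple);
        if (n.left != null) triple[1] = visit(n.left);
        if (n.right != null) triple[2] = visit(n.right);
        return id;
    }

    long[][] valuesMatrix() {
        return values.toArray(new long[0][]);
    }

    long[][] indexMatrix() {
        return indexes.toArray(new long[0][]);
    }
}
```

Under this layout, one tree becomes one long[][] of values and one long[][] of indexes; a batch of trees would then naturally fit an ArrayList of each, which is consistent with (but not confirmed as) the shape of the production overload's input.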
The DefaultPointwiseCost class in wayang-ml shows the full production path:
```java
import org.apache.wayang.ml.encoding.OrtMLModel;
import org.apache.wayang.ml.encoding.OrtTensorEncoder;
import org.apache.wayang.ml.encoding.TreeEncoder;
import org.apache.wayang.ml.encoding.TreeNode;
import org.apache.wayang.ml.encoding.OneHotMappings;
import org.apache.wayang.core.util.Tuple;

import java.util.ArrayList;

// Assumes `plan` (a PlanImplementation) and `config` (a Configuration) are in scope.

// Set up the encoder (once per application)
TreeEncoder encoder = new TreeEncoder(new OneHotMappings());

// Encode the plan into a tree, then into tensor format
TreeNode encodedTree = encoder.encode(plan);
Tuple<ArrayList<long[][]>, ArrayList<long[][]>> tensor =
        OrtTensorEncoder.encode(encodedTree);

// Run inference with the production overload
OrtMLModel model = OrtMLModel.getInstance(config);
double predictedCost = model.runModel(tensor);

// Only at application shutdown (not after each inference): release native resources
model.closeSession();
```
OrtMLModel.closeSession() releases the underlying native ONNX Runtime resources. Call it once when your application shuts down, not after each inference. If it is never called, the native session memory is not released explicitly, which amounts to a small one-time leak per process.
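One way to guarantee "close once, at shutdown" is a JVM shutdown hook around an idempotent close. The sketch below is self-contained and hypothetical (SessionHolder, infer(), and the release counter simulate native cleanup); with Wayang, the hook body would instead call closeSession() on the OrtMLModel instance.

```java
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

/** Hypothetical holder illustrating "close once at shutdown, not per inference". */
final class SessionHolder implements AutoCloseable {
    private final AtomicBoolean closed = new AtomicBoolean(false);
    private final AtomicInteger nativeReleases = new AtomicInteger(0);

    /** Placeholder for a model call; refuses to run on a closed session. */
    double infer(long[] features) {
        if (closed.get()) {
            throw new IllegalStateException("session already closed");
        }
        return 0.0;
    }

    /** Idempotent: only the first call releases the (simulated) native resources. */
    @Override
    public void close() {
        if (closed.compareAndSet(false, true)) {
            nativeReleases.incrementAndGet();
        }
    }

    int nativeReleases() {
        return nativeReleases.get();
    }

    /** Creates a holder whose close() also runs when the JVM shuts down. */
    static SessionHolder createWithShutdownHook() {
        SessionHolder holder = new SessionHolder();
        Runtime.getRuntime().addShutdownHook(new Thread(holder::close));
        return holder;
    }
}
```

Making close() idempotent means an explicit close during orderly shutdown and the hook can coexist without releasing native memory twice.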
OneHotEncoder: plan encoding
OneHotEncoder.encode(PlanImplementation) produces a long[] vector that represents the full execution plan. The encoding captures:
| Segment | Contents |
|---|---|
| Topologies | Replicator count, pipeline count, junction count, loop count |
| Operators | Per-operator-class: instance counts per platform, UDF complexity sum, input/output cardinality sums |
| Data movement | Per-conversion-operator: instance counts per platform, input/output cardinality sums |
| Dataset | A fixed dataset marker (100L) |
The platform and operator position mappings come from OneHotMappings, which must be initialized before encoding. This is handled automatically when you use MLContext or call MachineLearning.plugin().
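The sketch below illustrates the idea behind the Operators segment and the dataset marker: each operator class owns a fixed slot, each slot holds one instance count per platform plus input and output cardinality sums, and the vector ends with the marker 100. The position maps play the role of OneHotMappings, which is why an unmapped operator or platform cannot be encoded. OpInstance and CountEncoder are hypothetical, and the topology, UDF-complexity, and data-movement segments are omitted for brevity; this is not OneHotEncoder's actual layout.

```java
import java.util.List;
import java.util.Map;

/** Hypothetical operator instance as seen by the encoder. */
final class OpInstance {
    final String operatorClass;
    final String platform;
    final long inputCardinality;
    final long outputCardinality;

    OpInstance(String operatorClass, String platform, long in, long out) {
        this.operatorClass = operatorClass;
        this.platform = platform;
        this.inputCardinality = in;
        this.outputCardinality = out;
    }
}

/** Fixed-length count encoding: one slot per operator class, then a dataset marker. */
final class CountEncoder {
    static final long DATASET_MARKER = 100L;

    final Map<String, Integer> operatorPositions; // operator class -> slot index
    final Map<String, Integer> platformPositions; // platform -> offset inside a slot
    final int slotWidth;                          // per-platform counts + in sum + out sum

    CountEncoder(Map<String, Integer> operatorPositions, Map<String, Integer> platformPositions) {
        this.operatorPositions = operatorPositions;
        this.platformPositions = platformPositions;
        this.slotWidth = platformPositions.size() + 2;
    }

    long[] encode(List<OpInstance> ops) {
        long[] vec = new long[operatorPositions.size() * slotWidth + 1];
        int platforms = platformPositions.size();
        for (OpInstance op : ops) {
            Integer slot = operatorPositions.get(op.operatorClass);
            Integer offset = platformPositions.get(op.platform);
            if (slot == null || offset == null) {
                // Without initialized mappings there is no position to write to
                throw new IllegalArgumentException("unmapped: " + op.operatorClass + "@" + op.platform);
            }
            int base = slot * slotWidth;
            vec[base + offset] += 1;                              // instance count on this platform
            vec[base + platforms] += op.inputCardinality;         // input cardinality sum
            vec[base + platforms + 1] += op.outputCardinality;    // output cardinality sum
        }
        vec[vec.length - 1] = DATASET_MARKER;
        return vec;
    }
}
```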
Training your own model
To collect training data, enable experience logging via MLContext:
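```java
import org.apache.wayang.core.api.Configuration;
import org.apache.wayang.core.util.ReflectionUtils;
import org.apache.wayang.java.Java;
import org.apache.wayang.spark.Spark;
// MLContext is provided by the wayang-ml module

Configuration config = new Configuration();
config.setProperty("wayang.ml.experience.enabled", "true");
config.setProperty("wayang.ml.experience.file", "/data/experience.txt");
config.setProperty("wayang.ml.model.file", "/models/my_model.onnx");

MLContext mlContext = new MLContext(config);
mlContext.withPlugin(Java.basicPlugin())
        .withPlugin(Spark.basicPlugin());

// Run jobs normally; each execution appends an experience record.
// `plan` is your WayangPlan and MyJob is a placeholder for your job class.
mlContext.execute(plan, ReflectionUtils.getDeclaringJar(MyJob.class));
```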
Each experience record written to the file has the format:
```text
<encoded_original_plan>:<encoded_plan_with_choices>:<execution_time_ms>
```
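Before training, each line has to be split back into its three fields. A minimal parsing sketch, assuming that neither encoded plan contains a ':' character (an assumption about the encoding, not something this page confirms); ExperienceRecord is a hypothetical name and the sample line in the test is made up.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical parsed form of one experience log line. */
final class ExperienceRecord {
    final String originalPlan;
    final String planWithChoices;
    final long executionTimeMs;

    ExperienceRecord(String originalPlan, String planWithChoices, long executionTimeMs) {
        this.originalPlan = originalPlan;
        this.planWithChoices = planWithChoices;
        this.executionTimeMs = executionTimeMs;
    }

    /** Splits "<original>:<with_choices>:<time_ms>"; assumes no ':' inside the encodings. */
    static ExperienceRecord parse(String line) {
        String[] parts = line.trim().split(":", -1);
        if (parts.length != 3) {
            throw new IllegalArgumentException("expected 3 fields, got " + parts.length);
        }
        return new ExperienceRecord(parts[0], parts[1], Long.parseLong(parts[2].trim()));
    }

    /** Parses every non-blank line, e.g. from Files.readAllLines(path). */
    static List<ExperienceRecord> parseAll(List<String> lines) {
        List<ExperienceRecord> records = new ArrayList<>();
        for (String line : lines) {
            if (!line.isBlank()) {
                records.add(parse(line));
            }
        }
        return records;
    }
}
```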
Train a regression model on these records, using execution_time_ms as the prediction target, and export it to ONNX format using your ML framework of choice (scikit-learn, PyTorch, etc.). Then point wayang.ml.model.file at the exported .onnx file.