The predict() function performs skin cancer classification on the selected image using the loaded TensorFlow.js model.

Function signature

async function predict()
This is an asynchronous function. Ensure the model is loaded via loadModel() and an image is selected via onFileSelected() before calling this function.

Parameters

This function takes no parameters. It operates on the global imgtag element that contains the selected image.

Return value

prediction
Float32Array
An internal value holding one confidence score for each of the 7 classes, in the same order as the classes array. The function doesn’t return this value. It uses it to update the UI, and the promise returned by predict() resolves to undefined.

Classification classes

The model classifies images into 7 categories:
  1. Actinic Keratoses
  2. Basal Cell Carcinoma
  3. Benign Keratoses
  4. Dermatofibroma
  5. Melanoma
  6. Melanocytic Nevus
  7. Vascular Lesion
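The scores in the prediction array line up with this list by index: the score at position i belongs to classes[i]. The sketch below pairs each score with its class name and sorts from most to least confident, which can help when inspecting the full output. The helper name rankClasses is hypothetical (it is not part of the source), and the scores are made up for illustration.

```javascript
// Hypothetical helper (not in the original source): pair each score with its
// class name by index and sort from most to least confident.
function rankClasses(scores, classes) {
  if (scores.length !== classes.length) {
    throw new Error(`Expected ${classes.length} scores, got ${scores.length}`);
  }
  // Array.from works for plain arrays and typed arrays such as Float32Array.
  return Array.from(scores, (score, i) => ({ label: classes[i], score }))
    .sort((a, b) => b.score - a.score);
}

const classes = ['Actinic Keratoses', 'Basal Cell Carcinoma', 'Benign Keratoses',
                 'Dermatofibroma', 'Melanoma', 'Melanocytic Nevus', 'Vascular Lesion'];

// Made-up scores for illustration only.
const scores = new Float32Array([0.01, 0.02, 0.05, 0.01, 0.80, 0.10, 0.01]);
const ranked = rankClasses(scores, classes);
console.log(ranked[0].label); // Melanoma
```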

How it works

The prediction process follows these steps:
  1. Validates that an image has been selected
  2. Converts the image to a TensorFlow.js tensor
  3. Preprocesses the image (resizes it to 75 pixels high by 100 pixels wide, converts it to float, and adds a batch dimension)
  4. Runs the model prediction
  5. Identifies the class with the highest confidence
  6. Updates the UI with the prediction and confidence percentage
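Step 5 is an argmax: the source computes it as prediction.indexOf(Math.max(...prediction)). The loop below is an equivalent way to write it for ordinary (non-NaN) scores, sketched under the assumption that the scores arrive as a Float32Array from .data(). Like indexOf, it returns the first index when two scores tie. The function name argmax is illustrative, not from the source.

```javascript
// Illustrative argmax: index of the largest score, or -1 for an empty input.
// For non-NaN scores this matches scores.indexOf(Math.max(...scores)),
// including returning the first index when several scores tie.
function argmax(scores) {
  if (scores.length === 0) return -1;
  let best = 0;
  for (let i = 1; i < scores.length; i++) {
    if (scores[i] > scores[best]) best = i;
  }
  return best;
}

// Made-up scores; a float32 tensor's data() resolves to a Float32Array like this one.
const scores = new Float32Array([0.05, 0.10, 0.60, 0.05, 0.10, 0.05, 0.05]);
console.log(argmax(scores)); // 2
```

The loop form also avoids spreading the array into function arguments, which only matters for much larger outputs than these 7 scores.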

Image preprocessing

The function performs the following preprocessing steps:
  • Converts the HTML image to pixels using tf.browser.fromPixels()
  • Resizes to [75, 100] (height, width) using nearest neighbor interpolation
  • Converts to a float tensor
  • Adds a leading batch dimension, so the tensor passed to the model has shape [1, 75, 100, 3]
let tensorImg = tf.browser.fromPixels(imgtag)
                .resizeNearestNeighbor([75, 100])
                .toFloat().expandDims();
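tf.browser.fromPixels() produces an RGB tensor of shape [height, width, 3] by default, resizeNearestNeighbor([75, 100]) turns that into [75, 100, 3], and expandDims() adds the batch axis. The plain-JavaScript sketch below illustrates what nearest-neighbor resizing does on a flat RGB buffer. It is not TensorFlow.js code: the function name is hypothetical, and it follows the common rule of copying the source pixel at floor(outputIndex × inputSize / outputSize). tf.image.resizeNearestNeighbor also accepts alignCorners and halfPixelCenters flags that change this mapping.

```javascript
// Illustrative nearest-neighbor resize on an interleaved RGB buffer
// (row-major, 3 values per pixel). A sketch, not TensorFlow.js code.
function resizeNearestNeighborRGB(pixels, inH, inW, outH, outW) {
  const channels = 3;
  const out = new Float32Array(outH * outW * channels); // float values, like .toFloat()
  const rowRatio = inH / outH;
  const colRatio = inW / outW;
  for (let r = 0; r < outH; r++) {
    // Source row: scale the output row index, round down, clamp to the last row.
    const srcR = Math.min(inH - 1, Math.floor(r * rowRatio));
    for (let c = 0; c < outW; c++) {
      const srcC = Math.min(inW - 1, Math.floor(c * colRatio));
      for (let k = 0; k < channels; k++) {
        out[(r * outW + c) * channels + k] = pixels[(srcR * inW + srcC) * channels + k];
      }
    }
  }
  return { data: out, shape: [outH, outW, channels] };
}

// A 150 x 200 (height x width) RGB image with made-up pixel values.
const inH = 150, inW = 200;
const src = new Uint8ClampedArray(inH * inW * 3).map((_, i) => i % 256);
const resized = resizeNearestNeighborRGB(src, inH, inW, 75, 100);

// expandDims() would prepend a batch axis of size 1.
const batchShape = [1, ...resized.shape];
console.log(JSON.stringify(batchShape)); // [1,75,100,3]
```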

Code example

<a onclick="predict()" class="waves-effect waves-light btn">
  <i class="left material-icons">fiber_smart_record</i>
  Classify Image
</a>

Implementation

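Here’s the complete implementation:
async function predict(){

    try {
        // Stop early if no image has been selected yet
        if(imgtag.src == ""){
            alert("Select an Image to Classify")
            return
        }

        // Build a [1, 75, 100, 3] float tensor; tf.tidy() frees the intermediate tensors
        const tensorImg = tf.tidy(() => tf.browser.fromPixels(imgtag)
                        .resizeNearestNeighbor([75, 100])
                        .toFloat().expandDims());

        // Await the scores so that a failure while reading them also reaches the catch block
        const output = model.predict(tensorImg)
        const prediction = await output.data()

        // Release the input and output tensors once the scores have been copied out
        tensorImg.dispose()
        output.dispose()

        const predicted_class = prediction.indexOf(Math.max(...prediction))

        console.log(classes[predicted_class])
        console.log(prediction)

        prediction_text.innerHTML = classes[predicted_class]
        probability_text.innerHTML = Math.round(prediction[predicted_class] * 100) + "% Confidence"

    }catch(error){
        alert("Error Classifying Image")
    }
}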
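Because predict() reads globals and writes to the DOM, it is hard to exercise outside the page. One option, sketched here as an assumption rather than the project's code, is to move the core into a function that receives the model and class list as arguments and returns the result. classify and fakeModel are hypothetical names; the fake only mimics the model.predict(x).data() call chain, and real code would still dispose its tensors.

```javascript
// Hypothetical, testable core of predict(): the model and class names are passed in
// instead of read from globals, and the result is returned instead of written to the DOM.
async function classify(model, classes, input) {
  const output = model.predict(input);
  // Awaiting here means a rejected data() promise propagates to whoever awaits classify().
  const prediction = await output.data();
  const predictedClass = prediction.indexOf(Math.max(...prediction));
  return {
    label: classes[predictedClass],
    confidence: Math.round(prediction[predictedClass] * 100) + "% Confidence",
  };
}

// Fake model that mimics the predict(x).data() chain with fixed, made-up scores.
const fakeModel = {
  predict: () => ({
    data: async () => new Float32Array([0.02, 0.03, 0.05, 0.01, 0.87, 0.01, 0.01]),
  }),
};

const classes = ['Actinic Keratoses', 'Basal Cell Carcinoma', 'Benign Keratoses',
                 'Dermatofibroma', 'Melanoma', 'Melanocytic Nevus', 'Vascular Lesion'];

classify(fakeModel, classes, null).then((result) => {
  console.log(result.label, result.confidence); // Melanoma 87% Confidence
});
```

In the page, predict() could then wrap await classify(model, classes, tensorImg) in its try/catch and write the returned label and confidence to the two elements.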

Error handling

The function handles two failure cases:

No image selected

if(imgtag.src == ""){
    alert("Select an Image to Classify")
    return
}

Classification errors

catch(error){
    alert("Error Classifying Image")
}
Always ensure the model is loaded before calling predict(). Attempting to predict without a loaded model will trigger the error handler.

Required global variables

The function depends on these global variables:
var model = null  // The loaded TensorFlow.js model
var classes = ['Actinic Keratoses', 'Basal Cell Carcinoma', 'Benign Keratoses', 
               'Dermatofibroma', 'Melanoma', 'Melanocytic Nevus', 'Vascular Lesion']
var imgtag  // HTML image element
var prediction_text  // HTML element for displaying class name
var probability_text  // HTML element for displaying confidence
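Since predict() depends on these globals implicitly, a guard can report which prerequisite is missing before classification starts. The sketch below is hypothetical (missingPrerequisites is not in the source) and takes the values as arguments so it can run outside a browser. It reuses the same empty-src check that predict() performs.

```javascript
// Hypothetical guard: list the prerequisites predict() relies on that are not ready yet.
function missingPrerequisites({ model, classes, imgtag, predictionText, probabilityText }) {
  const missing = [];
  if (model == null) missing.push("model (call loadModel() first)");
  if (!Array.isArray(classes) || classes.length === 0) missing.push("classes");
  if (imgtag == null || imgtag.src === "") missing.push("image (call onFileSelected() first)");
  if (predictionText == null) missing.push("prediction_text element");
  if (probabilityText == null) missing.push("probability_text element");
  return missing;
}

// Plain objects stand in for the model and DOM elements in this example.
const status = missingPrerequisites({
  model: null,
  classes: ["Melanoma"],
  imgtag: { src: "" },
  predictionText: {},
  probabilityText: {},
});
console.log(status.length); // 2
```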

Output format

The function updates two HTML elements with the results:
  • prediction_text - Displays the predicted class name (e.g., “Melanoma”)
  • probability_text - Displays the confidence percentage (e.g., “87% Confidence”)

Console output

The function also logs prediction details to the console:
console.log(classes[predicted_class])  // Class name
console.log(prediction)  // Full prediction array
This is useful for debugging and understanding the model’s confidence across all classes.
The confidence percentage is rounded to the nearest whole number using Math.round().
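Because Math.round() rounds to the nearest integer, a score such as 0.996 is displayed as "100% Confidence" even though the model is not fully certain. The sketch below reproduces the formatting expression from the source inside a helper; the name formatConfidence is hypothetical.

```javascript
// Hypothetical helper reproducing the source's confidence formatting:
// scale a 0-1 score to a percentage and round to the nearest whole number.
function formatConfidence(score) {
  return Math.round(score * 100) + "% Confidence";
}

console.log(formatConfidence(0.874)); // 87% Confidence
console.log(formatConfidence(0.996)); // 100% Confidence
console.log(formatConfidence(0.004)); // 0% Confidence
```

If showing 100% for a score below 1 would mislead users, one alternative is (score * 100).toFixed(1), which keeps one decimal place.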
