How to Run a Keras Model in the Browser with Keras.js

By Andre Perunicic | September 12, 2017

Introduction

This article explains how to export weights from a Keras model and import those weights in Keras.js, a JavaScript framework for running pre-trained neural networks in the browser. While the Keras.js GitHub page sketches out how to import weights, this post makes all of the steps explicit and deals with a few gotchas along the way. In particular, it covers:

  1. Installing the right version of Keras, training your network, and exporting the weights.
  2. Setting up Keras.js and importing the weights.
  3. Testing the network by performing a prediction in the browser.

Exporting the Weights

To make things concrete, I will use an example Convolutional Neural Network (CNN) for classifying handwritten digits, sketched just below. If you're writing your own network, make sure to only use layers explicitly supported by Keras.js. You can see a list of supported layers on the Keras.js README.
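For concreteness, here is a condensed sketch of the kind of training script I have in mind. It closely follows the mnist_cnn.py example that ships with Keras, so treat the exact hyperparameters as an assumption rather than a requirement; what matters for our purposes is that every layer it uses (Conv2D, MaxPooling2D, Dropout, Flatten, and Dense) is on the supported list.

# A condensed training script modeled on Keras' examples/mnist_cnn.py.
# The hyperparameters are illustrative; only the choice of layers
# matters for Keras.js compatibility.
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Dropout, Flatten, Dense

num_classes = 10
img_rows, img_cols = 28, 28
batch_size = 128
epochs = 12

# Load MNIST and scale pixel values to [0, 1]. With the TensorFlow
# backend installed in the next step, images use the channels-last format.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1).astype('float32') / 255
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1).astype('float32') / 255

# One-hot encode the digit labels.
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

# A small CNN built only from layers that Keras.js supports.
model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3), activation='relu',
                 input_shape=(img_rows, img_cols, 1)))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes, activation='softmax'))

model.compile(loss='categorical_crossentropy',
              optimizer='adadelta',
              metrics=['accuracy'])

model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs,
          validation_data=(x_test, y_test))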

The only thing you need to be particularly careful about before training your network is to install supported versions of the required packages. Start by creating a virtual environment

mkdir keras-weight-transfer
cd keras-weight-transfer

virtualenv env
. env/bin/activate

and installing the necessary Python packages with

pip install tensorflow==1.1.0 keras==2.0.4 h5py

Make sure that the installation went well by verifying that

python -c "from keras import backend"

outputs Using TensorFlow backend.
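
You can also confirm the exact package versions and backend from within Python; this is optional, but it catches accidental upgrades early:

# Optional: print the installed versions and the active Keras backend.
import keras
import tensorflow as tf
from keras import backend as K

print(keras.__version__)  # expecting 2.0.4
print(tf.__version__)     # expecting 1.1.0
print(K.backend())        # expecting 'tensorflow'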

You can then go ahead and train your network. If you're playing along using the reference CNN and don't want to wait through 12 epochs over the full 60,000 training images, you can get a less useful but still functional network by restricting the training and test datasets to just 1280 and 512 images, respectively:

epochs = 1  # Overwrites previous value.

x_train = x_train[:1280]
y_train = y_train[:1280]

x_test = x_test[:512]
y_test = y_test[:512]

After the model has finished training, export the weights and the model architecture with

model.save_weights('model.hdf5')
with open('model.json', 'w') as f:
    f.write(model.to_json())
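
As an optional sanity check, you can rebuild the model from these two files and evaluate it again before leaving Python; a minimal sketch, assuming the loss and optimizer from the training script above:

# Optional sanity check: rebuild the model from model.json and model.hdf5
# and confirm that it still evaluates sensibly on the test set.
from keras.models import model_from_json

with open('model.json') as f:
    reloaded = model_from_json(f.read())
reloaded.load_weights('model.hdf5')
reloaded.compile(loss='categorical_crossentropy',
                 optimizer='adadelta',
                 metrics=['accuracy'])

loss, accuracy = reloaded.evaluate(x_test, y_test, verbose=0)
print('Reloaded model accuracy: {:.3f}'.format(accuracy))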

Before these are usable with Keras.js, you have to separate the weights from the other data in the exported HDF5 file. The Keras.js developers provide a script for doing exactly this. Download the script with

wget https://raw.githubusercontent.com/transcranial/keras-js/master/encoder.py

and run it with

python encoder.py ./model.hdf5

This produces model_metadata.json and model_weights.buf files which, together with the model.json you exported earlier, are what you'll load from JavaScript.


Running a Neural Network in the Browser

We will now create a tiny JavaScript application that loads the previously saved model and weights. Create the frontend code and distribution folders with

mkdir -p frontend/dist

You’ll want to copy the extracted model data files into the frontend/dist directory:

cp model.json frontend/dist/
cp model_metadata.json frontend/dist/
cp model_weights.buf frontend/dist/

The Webpack Setup

We will write the code in ES6 and prepare it for the browser using webpack, a JavaScript code and asset bundler. So, install webpack and the webpack development server globally via npm by running

npm install webpack -g
npm install webpack-dev-server -g

You will also want to start a Node project inside the frontend/ directory and install the required packages:

npm init
npm install --save keras-js url-loader

I will only describe the most basic webpack configuration necessary to accomplish the task at hand. In a text editor, create a file webpack.config.js inside the frontend/ directory with the following contents:

var path = require('path')

module.exports = {
  // The script from which dependencies are to be loaded.
  entry: './entry.js',
  // Name and location of the output bundle.
  output: {
    path: path.join(__dirname, 'dist'),
    filename: 'bundle.js'
  },
  // Add source maps in order to have more meaningful errors.
  devtool: 'source-map',
  // Let the development server know where to serve the files from.
  devServer: {
    contentBase: 'dist'
  },
  // Fix for the 'fs' module issue described below.
  node: {
    fs: 'empty'
  },
  // Fix for the GLSL issue described below.
  module: {
    loaders: [
      { test: /\.(glsl)$/, loader: 'url-loader'}
    ]
  }
};

If you remove the two fixes above, you will get a bunch of errors when compiling the bundle with webpack. The first looks like this:

ERROR in ./node_modules/keras-js/lib/Model.js
Module not found: Error: Can't resolve 'fs' in '~/keras-weight-transfer/frontend/node_modules/keras-js/lib'
 @ ./node_modules/keras-js/lib/Model.js 211:50-63
 @ ./node_modules/keras-js/lib/index.js
 @ ./entry.js

The fs module is part of Node.js and is presumably used by Keras.js to load data files locally. Since we are interested in running the code as a web app, we don't actually need this module. Webpack has a built-in way to deal with Node.js-specific modules, which means that all we have to do is add the following section to webpack's config:

node: {
  fs: 'empty'
}

The second error might look similar to

ERROR in ./node_modules/keras-js/lib/ext/convolutional/conv2d.glsl
Module parse failed: ~/keras-weight-transfer/frontend/node_modules/keras-js/lib/ext/convolutional/conv2d.glsl Unexpected token (33:10)
You may need an appropriate loader to handle this file type.
| // SOFTWARE.
|
| precision highp float;
|
| varying vec2 outTex;
 @ ./node_modules/keras-js/lib/ext/convolutional/WebGLConv2D.js 30:50-74
 @ ./node_modules/keras-js/lib/layers/convolutional/Conv2D.js
 @ ./node_modules/keras-js/lib/layers/convolutional/index.js
 @ ./node_modules/keras-js/lib/layers/index.js
 @ ./node_modules/keras-js/lib/index.js
 @ ./entry.js

The solution this time is to add url-loader to the webpack config, as in the module section above. You may also find this related bug report helpful, as well as the webpack.config.js file in the Keras.js demos folder.

The HTML

I will use a very basic HTML file placed directly in the distribution directory as frontend/dist/index.html, which is generally not a good idea but works for this demonstration.

<html>
  <head>
    <meta charset="utf-8">
    <script type="text/javascript" src="bundle.js" charset="utf-8"></script>
  </head>
  <body>
    Predicting...
  </body>
</html>

The JavaScript

The actual model loading code will live inside frontend/entry.js:

import { Model } from 'keras-js'

const model = new Model({
  filepaths: {
    model:  'model.json',
    weights: 'model_weights.buf',
    metadata: 'model_metadata.json'
  },
  gpu: true
})

With the sources in place, compile the bundle with

cd frontend
webpack --color --progress --watch

To view the app, run

webpack-dev-server

and open the indicated address (likely http://localhost:8080) in a web browser.

Using the Neural Network in the Browser

The loading now works, but it doesn't actually do anything yet. Here I describe how to export a single data sample and run a prediction on it from the browser.

Back in Python, take a random MNIST sample (in this case x_train[135]) and visualize it:

from matplotlib import pyplot as plt

plt.figure()
plt.imshow(x_train[135].reshape((28, 28)), cmap=plt.cm.binary)
plt.show()

This should display the following image.

Sample from MNIST

Plotting from a Virtual Environment

If the image isn’t showing up, you may need to customize your matplotlib backend to work within a virtualenv. On macOS, for example, you could do

echo "backend: TkAgg" >> ~/.matplotlib/matplotlibrc

but results may vary on your platform. I’ve had luck with the Agg backend on Linux, for instance.
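
Another option is to skip the GUI entirely and save the image to a file with the non-interactive Agg backend; a minimal sketch:

# Alternative: render the sample to a PNG file instead of opening a window.
import matplotlib
matplotlib.use('Agg')  # Select the non-interactive backend before importing pyplot.
from matplotlib import pyplot as plt

plt.figure()
plt.imshow(x_train[135].reshape((28, 28)), cmap=plt.cm.binary)
plt.savefig('sample.png')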

The easiest way to transfer data to JavaScript is to just print the sample as a list of floating point numbers. That is, save or copy the output of the following code

sample = list(x_train[135].reshape((28*28,)))
print(sample)

and paste this array into frontend/data.js

export const sample = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
  /* ... 784 floating point values in total, exactly as printed by the Python snippet above ... */
  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
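
If copying a 784-element array around by hand feels error-prone, you can also generate frontend/data.js directly from Python; this is just a convenience sketch, and it assumes you run it from the project root so that frontend/ is reachable:

# Optional: write frontend/data.js from Python instead of copy-pasting.
import json

sample = [float(value) for value in x_train[135].reshape((28 * 28,))]
with open('frontend/data.js', 'w') as f:
    f.write('export const sample = ' + json.dumps(sample) + ';\n')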

In a real app, you’d load the data dynamically, but this is fine for a small demo. Now, at the top of entry.js add

import { sample } from './data'

and below the model loading code already there, add

model.ready()
  .then(() => {
    return model.predict({
      'input': new Float32Array(sample)
    })
  })
  .then(outputData => {
    const predictions = outputData['output']
    let max = -1;
    let digit = null;
    for (let i in predictions) {
      let probability = predictions[i];
      if (probability > max) {
        max = probability;
        digit = i;
      }
    }
    document.write(
      "Predicted digit " + digit + " with probability " + max.toFixed(3) + "."
    )
    console.log(outputData)

  })
  .catch(err => {
    console.log(err)
  })

This is a promise chain that replaces the text inside the document with the predicted class and logs the per-class probabilities to the console. Run webpack --progress --color --watch to keep updating the bundle and webpack-dev-server to serve the dist directory. Opening http://localhost:8080 should then display something like the following.

Predicted digit and prediction probability

As you can see, this Keras.js network correctly classified the above digit as 3. I hope you found this tutorial on getting started with Keras.js useful. Please don’t hesitate to get in touch with us if you need help training and deploying machine learning models or sourcing data from the web.