By Andre Perunicic | September 12, 2017
This article explains how to export a pre-trained Keras model written in Python and use it in the browser with Keras.js. The main difficulty lies in choosing compatible versions of the packages involved and preparing the data, so I’ve prepared a fully worked-out example that goes from training the model to performing a prediction in the browser. You can find the working end result in Intoli’s article materials repository, but do read on if you’d like just the highlights.
Exporting the Weights
To make things concrete I will use a simple Convolutional Neural Network (CNN) which classifies handwritten digits. If you’re writing your own network make sure to only use layers explicitly supported by Keras.js.
You need to be particularly careful about using compatible versions of the relevant packages. These are easily managed from a virtualenv, so create and activate one:
mkdir -p keras-weight-transfer/neural-net
cd keras-weight-transfer/neural-net
virtualenv env
. env/bin/activate
Then, install the necessary Python packages with
pip install tensorflow==1.6.0 Keras==2.1.2 h5py==2.7.1
Make sure that the installation went well by verifying that
python -c "from keras import backend"
outputs Using TensorFlow backend.
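Since mismatched versions are the most common source of trouble, it can help to check the pins programmatically. The following is a minimal sketch, not part of the companion repository: the helper names `installed_version` and `find_mismatches` are my own, and it assumes Python 3.6 or newer.

```python
# Pinned versions taken from the pip command above.
PINNED = {"tensorflow": "1.6.0", "Keras": "2.1.2", "h5py": "2.7.1"}


def installed_version(package):
    """Return the installed version of `package`, or None if it is missing."""
    try:
        from importlib import metadata  # available on Python 3.8+
    except ImportError:
        import pkg_resources  # fallback for older interpreters (setuptools)
        try:
            return pkg_resources.get_distribution(package).version
        except pkg_resources.DistributionNotFound:
            return None
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def find_mismatches(pinned, lookup=installed_version):
    """Map each package whose installed version differs to (expected, found)."""
    mismatches = {}
    for package, expected in pinned.items():
        found = lookup(package)
        if found != expected:
            mismatches[package] = (expected, found)
    return mismatches


if __name__ == "__main__":
    # Prints nothing when every pinned package matches.
    for package, (expected, found) in find_mismatches(PINNED).items():
        print(f"{package}: expected {expected}, found {found or 'nothing'}")
```

Run it inside the activated virtualenv; an empty printout means all three packages match the pinned versions.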
For the purposes of this article we’ll use a modified Keras example CNN, which you can download to ./mnist-cnn.py by running
curl https://raw.githubusercontent.com/intoli/intoli-article-materials/master/articles/keras-weight-transfer/neural-net/mnist-cnn.py -o mnist-cnn.py
This version differs from the original example provided by Keras in two respects. First, the data and training time are restricted in order to speed things along for this tutorial:
epochs = 1
x_train = x_train[:1280]
y_train = y_train[:1280]
x_test = x_test[:512]
y_test = y_test[:512]
Second, after the model is trained, the weights and model data are saved to a file named model.h5 with:
model.save('model.h5')
Run the script to train the network and save the model with:
python mnist-cnn.py
Before the model contained in the model.h5 HDF5 file is usable with Keras.js, it has to be passed through a compatible version of Keras.js’s encoder.py script.
Run
curl https://raw.githubusercontent.com/transcranial/keras-js/a5e6d2cc330ec8d979310bd17a47f07882fac778/python/encoder.py -o encoder.py
curl https://raw.githubusercontent.com/transcranial/keras-js/a5e6d2cc330ec8d979310bd17a47f07882fac778/python/model_pb2.py -o model_pb2.py
to download encoder.py and model_pb2.py to your working directory.
Finally, run this script with
python encoder.py -q ./model.h5
to produce a model file (model.bin) understandable by Keras.js.
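The encoded file eventually has to live somewhere your web server can reach. As a rough sketch of that step, the hypothetical helper below (`publish_model` is my own name, not something from Keras.js) refuses to copy a missing or empty file and otherwise places it in a target directory of your choosing.

```python
import shutil
from pathlib import Path


def publish_model(model_path, public_dir):
    """Copy the encoded model into the directory served by the web server."""
    model_path = Path(model_path)
    # A missing or zero-byte file usually means the encoder step failed.
    if not model_path.is_file() or model_path.stat().st_size == 0:
        raise FileNotFoundError(
            f"{model_path} is missing or empty; rerun encoder.py"
        )
    public_dir = Path(public_dir)
    public_dir.mkdir(parents=True, exist_ok=True)
    destination = public_dir / model_path.name
    return Path(shutil.copy2(str(model_path), str(destination)))


# Example call; the target path is an assumed layout, adjust it to your setup:
# publish_model("model.bin", "../frontend/public")
```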
Running a Neural Network in the Browser
The neural network model should now be ready for Keras.js, so let’s take a look at how it would be used to predict the class of a single data sample.
As I already mentioned, you can find a fully worked-out example in the frontend/ folder of the companion repository, but this section will highlight the most relevant bits that you can use in your own setup.
The data sample we’ll work with is one of the training set images, x_train[135], which just happens to represent the digit 3.
To get this data to JavaScript, we only have to flatten the image array. For this demo, I’ve simply copy-pasted the printout of
sample = list(x_train[135].reshape((28*28,)))
print(sample)
into ./sample.js, which just exports a JavaScript array:
export default [
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
// ...
0.57254905, 0.9882353, 0.99215686, 0.9882353, 0.9882353, 0.9882353, 0.7254902,
// ...
];
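If you’d rather not copy and paste, the same file can be generated from Python. Here is a minimal sketch; the function name `to_sample_js` and the 14-values-per-line layout are my own choices rather than anything the companion repository prescribes.

```python
def to_sample_js(values, per_line=14):
    """Format a flat sequence of numbers as a JavaScript default export."""
    # float() also accepts NumPy scalars, so a flattened image works directly.
    numbers = [repr(float(v)) for v in values]
    rows = [
        "  " + ", ".join(numbers[i:i + per_line]) + ","
        for i in range(0, len(numbers), per_line)
    ]
    return "export default [\n" + "\n".join(rows) + "\n];\n"


# Intended usage, with x_train prepared as in mnist-cnn.py:
# values = x_train[135].reshape((28 * 28,)).tolist()
# with open("sample.js", "w") as f:
#     f.write(to_sample_js(values))
```

The printed digits may be longer than in the printout above, because each value is widened to a Python float first. Wrapping the array in a Float32Array on the JavaScript side rounds it back to the original 32-bit value.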
With the sample in place, you have to copy model.bin over to somewhere accessible to your web server, and pass it as the filepath option to an instance of the Model class from Keras.js:
import { Model } from 'keras-js';
import sample from './sample';

document.addEventListener('DOMContentLoaded', () => {
  document.write('Loading...');
});

// Make sure to copy model.bin to the public directory.
const model = new Model({
  filepath: 'model.bin',
});

// Perform a prediction and write the result to the page.
model.ready()
  .then(() => model.predict({
    input: new Float32Array(sample),
  }))
  .then(({ output }) => {
    // Pick the digit with the highest predicted probability.
    let predictionProbability = -1;
    let predictedDigit = null;
    Object.entries(output).forEach(([digit, probability]) => {
      if (probability > predictionProbability) {
        predictionProbability = probability;
        predictedDigit = digit;
      }
    });
    document.write(
      `Predicted ${predictedDigit} with probability ${predictionProbability.toFixed(3)}.`,
    );
  })
  .catch((error) => {
    console.log(error);
  });
When you visit the page using that script, you should see the document body transition from Loading... to a message like
Predicted 3 with probability 0.297.
Of course, you will end up with a better prediction probability if you spend any real time training your network!
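To confirm that the browser agrees with the original model, you can compute the same "largest probability" on the Python side. This is only a sketch: `top_class` is a hypothetical helper that mirrors the loop in the JavaScript above, and the commented-out usage assumes the sample is preprocessed exactly as in mnist-cnn.py.

```python
def top_class(probabilities):
    """Return (index, probability) of the largest entry, like the browser code."""
    best_index, best_probability = None, -1.0
    for index, probability in enumerate(probabilities):
        if probability > best_probability:
            best_index, best_probability = index, float(probability)
    return best_index, best_probability


# Intended usage inside the training virtualenv:
# from keras.models import load_model
# model = load_model("model.h5")
# probabilities = model.predict(x_train[135:136])[0]
# print(top_class(probabilities))
```

If the digit or the probability differs noticeably from what the page reports, the usual suspects are a stale model.bin or a sample that was scaled or shaped differently in the two places.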
Project Configuration
Since this article was originally published, Keras.js has gone through a few minor revisions and changed its interface a bit.
Various problems arise when using incompatible package versions, so take special care when installing the packages.
The code in this article works with keras-js@1.0.3 on NPM, which you can install via Yarn:
yarn add keras-js@1.0.3
In addition, if you get an error message like the following when you run webpack,
ERROR in ./node_modules/keras-js/lib/Model.js
Module not found: Error: Can't resolve 'fs' in '~/keras-weight-transfer/fronterend/node_modules/keras-js/lib'
@ ./node_modules/keras-js/lib/Model.js 211:50-63
@ ./node_modules/keras-js/lib/index.js
@ ./entry.js
then you may have to add
node: {
fs: 'empty',
}
to your webpack config. Check out a working webpack config file over in the companion repo, and please let us know if you encounter any other issues in the comments below.