Jacklu0831 / Classify-Anything

A train-on-the-fly browser app that lets users classify their own webcam captures (TF.js, HTML, CSS, JS)

Home Page: https://jacklu0831.github.io/Classify-Anything/

Classify Anything!

A train-on-the-fly, real-time webcam classifier that runs in the web browser, which I made after learning TensorFlow.js for model deployment. Built with TF.js, HTML, JS, and CSS. I used transfer learning with MobileNet as the base model and a KNN classifier on top of the extracted features, so the model makes accurate predictions with minimal training data.

Since the algorithm continuously makes predictions on the live video stream and displays them above the buttons, MobileNet was chosen for being lightweight (thanks to depthwise separable convolution). However, it is not the most accurate model for object detection/classification. This issue is mitigated by limiting the scope of the problem to a few classes/types of images and training the KNN classifier on them. The math is explained in the Background section.


Demo

Watch the 3am Video Demo

In the video demo, 5 classes/types of images were included. I first clicked the 5 class buttons to capture training images of each class for the KNN classifier, then evaluated the trained model by watching its predictions as I switched between the different classes. Focus on how the model's prediction changes based on the class it detects. Note that the classes of images can be different poses, expressions, objects, environments, hand gestures, written letters... sky ain't the limit.

CLASS 1: My head
CLASS 2: My head with blue ray glasses
CLASS 3: My head with blue ray glasses and a hat
CLASS 4: My head with blue ray glasses and a hat and headphones
CLASS 5: My head with blue ray glasses and a hat and headphones and a big bottle of water


Background

Why is MobileNet "lightweight"?

This was actually a question I had when I first read about why MobileNet is so much faster than other object detection networks. The answer is depthwise separable convolution, which breaks a standard convolution into two parts: a depthwise convolution (filtering stage) followed by a pointwise convolution (combining stage). Let's compare the computational expense of standard convolution and depthwise separable convolution.

Based on the parameters in the picture above (note that Dg is simply the side length of the output), in a standard convolution each of the N kernels convolves over all M channels of the Df by Df input tensor. Therefore, applying 32 kernels of size 3x3 to a 128x128 color image (3 channels; with stride 1 and no padding, the output side length is Dg = 126) costs 32x(126x126)x(3x3)x3 = 13716864 multiplications.
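To make the cost formula concrete, here is a minimal, deliberately naive sketch in plain JavaScript (no TF.js) of a stride-1, no-padding standard convolution that counts every multiplication. The nested-array tensor layout and the names `standardConv`, `standardConvCost`, and `filled` are my own illustrative choices, not code from this project.

```javascript
// Naive standard convolution: stride 1, no padding ("valid").
// input: M channels, each a Df x Df array.
// kernels: N kernels, each of shape M x Dk x Dk.
// Returns the N x Dg x Dg output and the number of multiplications performed.
function standardConv(input, kernels) {
  const M = input.length;
  const Df = input[0].length;
  const N = kernels.length;
  const Dk = kernels[0][0].length;
  const Dg = Df - Dk + 1; // output side length
  let mults = 0;
  const output = [];
  for (let n = 0; n < N; n++) {
    const plane = [];
    for (let y = 0; y < Dg; y++) {
      const row = [];
      for (let x = 0; x < Dg; x++) {
        let sum = 0;
        // Every kernel looks at all M channels at every output position.
        for (let m = 0; m < M; m++) {
          for (let i = 0; i < Dk; i++) {
            for (let j = 0; j < Dk; j++) {
              sum += input[m][y + i][x + j] * kernels[n][m][i][j];
              mults++;
            }
          }
        }
        row.push(sum);
      }
      plane.push(row);
    }
    output.push(plane);
  }
  return { output, mults };
}

// Closed-form cost: N x (Dg x Dg) x (Dk x Dk) x M
function standardConvCost(N, Dg, Dk, M) {
  return N * Dg * Dg * Dk * Dk * M;
}

// Builds a nested array of the given shape filled with a constant.
function filled(shape, value) {
  if (shape.length === 0) return value;
  return Array.from({ length: shape[0] }, () => filled(shape.slice(1), value));
}

const input = filled([3, 8, 8], 1);      // M = 3 channels, Df = 8
const kernels = filled([4, 3, 3, 3], 1); // N = 4 kernels of size 3x3x3
const { output, mults } = standardConv(input, kernels);
console.log(output.length, output[0].length, mults); // 4 6 3888
console.log(standardConvCost(32, 126, 3, 3));        // 13716864
```

The small 8x8x3 example gives 4x(6x6)x(3x3)x3 = 3888 counted multiplications, matching the formula, and the same formula reproduces the 13716864 figure for the 128x128 example without running the full loop.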

Depthwise separable convolution can be broken into two stages. The first stage is the depthwise convolution, where M kernels, each with a depth of 1, are applied separately to the M channels of the input tensor, decreasing the expense by a factor of N. Then, the pointwise convolution uses N "stick-like" 1x1xM kernels to combine the information carried by the channels of the depthwise output. The final output tensor has the same size as the output of the standard convolution. Adding the two stages up, the same scenario as above takes (126x126)x(3x3)x3 + 32x(126x126)x3 = 1952748 multiplications. Dividing this by the standard cost shows that, in this case, depthwise separable convolution needs only 14.2% of the multiplications!
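The two stages can be sketched the same way. This is a naive plain-JavaScript illustration (hypothetical helper names, nested-array tensors, stride 1, no padding), not the project's or TF.js's implementation; it counts the multiplications of both stages so they can be compared with the formula.

```javascript
// Naive depthwise separable convolution: stride 1, no padding.
// Stage 1 (depthwise): one Dk x Dk kernel per input channel, M kernels total.
// Stage 2 (pointwise): N kernels of shape 1 x 1 x M that mix the M channels.
function depthwiseSeparableConv(input, depthKernels, pointKernels) {
  const M = input.length;
  const Df = input[0].length;
  const Dk = depthKernels[0].length;
  const N = pointKernels.length;
  const Dg = Df - Dk + 1; // output side length
  let mults = 0;

  // Stage 1: filter each channel on its own -> M x Dg x Dg
  const depthOut = [];
  for (let m = 0; m < M; m++) {
    const plane = [];
    for (let y = 0; y < Dg; y++) {
      const row = [];
      for (let x = 0; x < Dg; x++) {
        let sum = 0;
        for (let i = 0; i < Dk; i++) {
          for (let j = 0; j < Dk; j++) {
            sum += input[m][y + i][x + j] * depthKernels[m][i][j];
            mults++;
          }
        }
        row.push(sum);
      }
      plane.push(row);
    }
    depthOut.push(plane);
  }

  // Stage 2: combine the M channels at every pixel -> N x Dg x Dg
  const output = [];
  for (let n = 0; n < N; n++) {
    const plane = [];
    for (let y = 0; y < Dg; y++) {
      const row = [];
      for (let x = 0; x < Dg; x++) {
        let sum = 0;
        for (let m = 0; m < M; m++) {
          sum += depthOut[m][y][x] * pointKernels[n][m];
          mults++;
        }
        row.push(sum);
      }
      plane.push(row);
    }
    output.push(plane);
  }
  return { output, mults };
}

// Closed-form cost: (Dg x Dg) x (Dk x Dk) x M + N x (Dg x Dg) x M
function separableConvCost(N, Dg, Dk, M) {
  return Dg * Dg * Dk * Dk * M + N * Dg * Dg * M;
}

// Builds a nested array of the given shape filled with a constant.
function filled(shape, value) {
  if (shape.length === 0) return value;
  return Array.from({ length: shape[0] }, () => filled(shape.slice(1), value));
}

const input = filled([3, 8, 8], 1);        // M = 3 channels, Df = 8
const depthKernels = filled([3, 3, 3], 1); // M = 3 depthwise kernels of 3x3
const pointKernels = filled([4, 3], 1);    // N = 4 pointwise kernels of 1x1x3
const { output, mults } = depthwiseSeparableConv(input, depthKernels, pointKernels);
console.log(output.length, output[0].length, mults); // 4 6 1404
console.log(separableConvCost(32, 126, 3, 3));       // 1952748
console.log((separableConvCost(32, 126, 3, 3) / (32 * 126 * 126 * 3 * 3 * 3)).toFixed(3)); // 0.142
```

With all-ones inputs and kernels, this produces 27 at every output position, just like a standard convolution with all-ones 3x3x3 kernels, and the output shape (4 x 6 x 6) matches. In general the two layers are not numerically equivalent: the separable version is a factorized layer with fewer free parameters.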

Of course, a generalized formula is always better. Based on the derived formula above, the ratio of the two costs simplifies to 1/N + 1/(Dk x Dk), where Dk is the kernel side length, so here it is 1/32 + 1/(3x3) ≈ 14.2%. Real convolutional neural networks use far more kernels (larger N) and sometimes larger kernels (larger Dk), both of which shrink this ratio further and only magnify the difference in expense between standard convolutional layers and depthwise separable convolutional layers.
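For reference, the general ratio can be written out as follows, with D_k denoting the kernel side length (my notation) and N, M, D_g as defined above. The D_g and M factors cancel, leaving only N and D_k:

```latex
\frac{\overbrace{D_k \cdot D_k \cdot M \cdot D_g \cdot D_g}^{\text{depthwise}}
      + \overbrace{N \cdot M \cdot D_g \cdot D_g}^{\text{pointwise}}}
     {\underbrace{N \cdot D_k \cdot D_k \cdot M \cdot D_g \cdot D_g}_{\text{standard}}}
  = \frac{1}{N} + \frac{1}{D_k^2},
\qquad
N = 32,\ D_k = 3:\quad \frac{1}{32} + \frac{1}{9} \approx 0.0313 + 0.1111 \approx 0.142
```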

However, the trade-off between accuracy and speed has always been one of the biggest challenges in real-time object detection. Depthwise separable convolutions reduce the number of parameters in each convolution, so for a small model, replacing standard 2D convolutions with depthwise separable ones may significantly decrease model capacity and make the model sub-optimal. MobileNet did not perform as well as YOLOv3 and SSD when I tried to use it for object detection. That is where the KNN classifier comes in: it works on top of the features that MobileNet efficiently extracts from the input data.
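To put the parameter reduction in numbers, here is a worked count for the same example layer (3x3 kernels, M = 3 input channels, N = 32 output channels), ignoring biases and writing D_k for the kernel side length. This is an illustration of one layer, not a measurement of MobileNet itself:

```latex
\begin{aligned}
\text{standard:}\quad & D_k^2 \cdot M \cdot N = 3^2 \cdot 3 \cdot 32 = 864 \\
\text{depthwise separable:}\quad & D_k^2 \cdot M + M \cdot N = 3^2 \cdot 3 + 3 \cdot 32 = 27 + 96 = 123
\end{aligned}
```

The ratio 123/864 ≈ 0.142 is the same 1/N + 1/D_k^2 as for the multiplication count, because the parameter count is just the per-output-pixel cost without the D_g x D_g factor.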

What is a K-Nearest Neighbor (KNN) Classifier?

The KNN classifier is far simpler than depthwise separable convolution, as it is a basic supervised machine learning technique. The image above shows three categories of training data; in our case, the training data comes from your clicks on the class buttons, with each click giving the classifier one labeled image. When a live webcam frame is passed into the network, MobileNet is responsible for "simplifying" the image into a collection of features. These features are then passed into the KNN classifier, which finds the K nearest labeled examples, counts how many of them belong to each category (5 categories in our case), and predicts the category with the most votes for every frame. Our model uses the default value k=3 according to the TensorFlow.js documentation.
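The voting step described above can be sketched in a few lines of plain JavaScript. This is a minimal illustration, not the TF.js `knn-classifier` package: it uses Euclidean distance for simplicity (the TF.js package may use a different similarity measure internally), and the tiny 2-D feature vectors and class names are made-up stand-ins for MobileNet activations.

```javascript
// Minimal k-nearest-neighbor classifier over plain feature vectors.
// Assumption: features are number arrays standing in for MobileNet outputs.
class KNNClassifier {
  constructor() {
    this.examples = []; // array of { features, label }
  }

  // Corresponds to one class-button click in the app's flow.
  addExample(features, label) {
    this.examples.push({ features, label });
  }

  // Euclidean distance between two vectors of equal length.
  static distance(a, b) {
    let sum = 0;
    for (let i = 0; i < a.length; i++) {
      const d = a[i] - b[i];
      sum += d * d;
    }
    return Math.sqrt(sum);
  }

  predict(features, k = 3) {
    if (this.examples.length === 0) throw new Error("no training examples");
    // Sort all labeled examples by distance and keep the k nearest.
    const nearest = this.examples
      .map((ex) => ({ label: ex.label, dist: KNNClassifier.distance(features, ex.features) }))
      .sort((a, b) => a.dist - b.dist)
      .slice(0, k);
    // Count votes. A Map keeps insertion order, so on a tie the class whose
    // closest neighbor appears first (i.e. is nearest) wins.
    const votes = new Map();
    for (const { label } of nearest) {
      votes.set(label, (votes.get(label) || 0) + 1);
    }
    let best = null;
    let bestCount = 0;
    for (const [label, count] of votes) {
      if (count > bestCount) {
        best = label;
        bestCount = count;
      }
    }
    return { label: best, votes: Object.fromEntries(votes) };
  }
}

const knn = new KNNClassifier();
knn.addExample([0, 0], "head");
knn.addExample([0.1, 0.2], "head");
knn.addExample([1, 1], "glasses");
knn.addExample([1.1, 0.9], "glasses");
knn.addExample([0.9, 1.2], "glasses");
console.log(knn.predict([0.2, 0.1]).label); // head
console.log(knn.predict([1, 1.1]).label);   // glasses
```

For the first query, the three nearest examples are two "head" points and one "glasses" point, so "head" wins 2 to 1. Because every new example is simply appended, "training" costs nothing, which is what makes the train-on-the-fly interaction possible; the price is that each prediction compares against every stored example.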

A related, more advanced topic is K-means clustering, an unsupervised machine learning algorithm that groups unlabeled data into clusters; I implemented it here in Matlab.


Try it Yourself

Refer to the demo video: every time a class button is pressed, a screenshot of the current video frame goes through MobileNet and is added to the KNN classifier's training set under the category you pressed. The algorithm continuously makes predictions on the live video stream and displays them above the buttons.

Set Up

Simply clone/fork the repo and open index.html in your browser. This version of the project loads TensorFlow.js through a script tag, so it does not have to be installed in your environment. The project was tested and fully functions in Google Chrome.

Languages

JavaScript 55.7%, HTML 30.4%, CSS 13.9%