Handwritten Digit Classifier

CNN trained on MNIST — running live in your browser

Try It

Draw a digit (0–9)

Description

What it is: A convolutional neural network trained on the 60,000-image MNIST training set of handwritten digits. The model reaches ~99% test accuracy in 5 epochs. After training, it is exported to ONNX format and loaded directly in the browser with ONNX Runtime Web: no server, no API, and all inference happens on your device.

Why it matters: Going from raw PyTorch training through ONNX export to in-browser inference covers the full production path for deploying a model without a backend. The same pipeline works for any model that can be exported to ONNX, provided ONNX Runtime Web supports its operators and its weights are small enough to download with the page.

Training platform: Google Colab (GPU/TPU). The script auto-detects the available accelerator and adapts accordingly. Weights are exported to mnist_classifier.onnx and hosted on HuggingFace.

CNN Architecture

  MNISTNet (~422,000 parameters)
  ══════════════════════════════════════════════════

  Input: (1, 28, 28)  ←  greyscale digit image
         │
         ▼
  ┌─────────────────────────────────────────────┐
  │  Conv2d(1 → 32, kernel 3×3, padding 1)      │
  │  ReLU                                        │
  │  MaxPool2d(2×2)                              │
  │  output: (32, 14, 14)                        │
  └──────────────────────┬──────────────────────┘
                         │
                         ▼
  ┌─────────────────────────────────────────────┐
  │  Conv2d(32 → 64, kernel 3×3, padding 1)     │
  │  ReLU                                        │
  │  MaxPool2d(2×2)                              │
  │  output: (64, 7, 7)                          │
  └──────────────────────┬──────────────────────┘
                         │
                         ▼
  ┌─────────────────────────────────────────────┐
  │  Flatten → 3,136 features                   │
  │  Dropout(0.25)                               │
  │  Linear(3136 → 128)  ReLU                   │
  │  Dropout(0.25)                               │
  │  Linear(128 → 10)                            │
  └──────────────────────┬──────────────────────┘
                         │
                         ▼
  Output: 10 logits — one per digit class (0–9)
  argmax → predicted digit

  ──────────────────────────────────────────────
  Training: Adam, lr=1e-3, CosineAnnealingLR schedule
  Data:     60,000 train / 10,000 test images
  Accuracy: ~99% on test set after 5 epochs
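Two conv+pool blocks extract spatial features and two linear layers classify them. Dropout sits in the classifier head, where most of the weights live (the 3136 → 128 layer alone holds about 401,000), and regularises the model against overfitting.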
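The parameter total follows directly from the layer shapes in the diagram. The arithmetic below tallies weights plus biases per layer, assuming every Conv2d and Linear layer keeps PyTorch's default bias term (ReLU, pooling, flatten, and dropout have no parameters). At 4 bytes per float32 weight, the total also lines up with the ~1.7 MB ONNX file.

```python
# Parameter count for the MNISTNet layers listed in the diagram.
# Assumption: every Conv2d/Linear layer has a bias (PyTorch's default).

def conv2d_params(c_in: int, c_out: int, k: int) -> int:
    # one k×k kernel per (input, output) channel pair, plus one bias per output channel
    return c_out * c_in * k * k + c_out

def linear_params(n_in: int, n_out: int) -> int:
    # dense weight matrix plus one bias per output unit
    return n_out * n_in + n_out

layers = {
    "conv1": conv2d_params(1, 32, 3),
    "conv2": conv2d_params(32, 64, 3),
    "fc1": linear_params(64 * 7 * 7, 128),  # 3,136 flattened features
    "fc2": linear_params(128, 10),
}

total = sum(layers.values())
float32_mb = total * 4 / 1e6  # 4 bytes per float32 weight, decimal megabytes

for name, n in layers.items():
    print(f"{name:>5}: {n:>7,}")
print(f"total: {total:,} parameters ≈ {float32_mb:.2f} MB as float32")
```

The first fully connected layer accounts for roughly 95% of all weights, which is why the dropout is concentrated in the classifier head rather than the conv blocks.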

Browser Inference Pipeline

  Canvas → Prediction Pipeline
  ══════════════════════════════════════════════════

  User draws on 280×280 canvas
         │
         ▼
  ┌───────────────────────────────────────────┐
  │  Downsample: 280×280 → 28×28              │
  │  Average each 10×10 block of pixels       │
  │  (reads alpha channel: stroke=1, bg=0)    │
  └─────────────────────┬─────────────────────┘
                        │
                        ▼
  ┌───────────────────────────────────────────┐
  │  Normalize                                │
  │  pixel = (value − 0.1307) / 0.3081       │
  │  (same stats used during training)        │
  └─────────────────────┬─────────────────────┘
                        │
                        ▼
  ┌───────────────────────────────────────────┐
  │  Float32Array [1, 1, 28, 28]              │
  │  fed to ONNX Runtime Web session          │
  │  (WASM backend — runs fully in browser)   │
  └─────────────────────┬─────────────────────┘
                        │
                        ▼
  10 logits → softmax → probabilities
  argmax → predicted digit + confidence
No server round-trip. The ONNX model is fetched once from HuggingFace, then all inference runs locally in WebAssembly.
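The canvas-to-prediction steps in the diagram are plain arithmetic. The sketch below reproduces them in Python for illustration; it assumes the canvas pixels arrive as a flat RGBA byte buffer in row-major order (the layout of a browser's `ImageData.data`) with the stroke encoded in the alpha byte. `preprocess`, `softmax`, and `predict` are hypothetical names, not the page's actual code.

```python
import math

CANVAS = 280                 # drawing surface is 280×280
TARGET = 28                  # model input is 28×28
BLOCK = CANVAS // TARGET     # each output pixel averages a 10×10 block
MEAN, STD = 0.1307, 0.3081   # MNIST normalisation stats used during training

def preprocess(rgba) -> list[float]:
    """Flat RGBA bytes (CANVAS*CANVAS*4, row-major) -> 784 normalised floats."""
    pixels = []
    for ty in range(TARGET):
        for tx in range(TARGET):
            total = 0
            for dy in range(BLOCK):
                for dx in range(BLOCK):
                    y, x = ty * BLOCK + dy, tx * BLOCK + dx
                    # alpha byte: 255 = stroke, 0 = background
                    total += rgba[(y * CANVAS + x) * 4 + 3]
            value = total / (BLOCK * BLOCK * 255)  # block average scaled to [0, 1]
            pixels.append((value - MEAN) / STD)
    return pixels  # row-major; reshape to [1, 1, 28, 28] for the model

def softmax(logits: list[float]) -> list[float]:
    m = max(logits)  # subtract the max logit for numerical stability
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def predict(logits: list[float]) -> tuple[int, float]:
    """Return (predicted digit, confidence) from the model's 10 logits."""
    probs = softmax(logits)
    digit = max(range(len(probs)), key=probs.__getitem__)  # argmax
    return digit, probs[digit]
```

Subtracting the maximum logit before exponentiating leaves the softmax result unchanged but keeps `math.exp` from overflowing on large logits.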

Dev Notes

Tech Stack

PyTorch + torchvision for training, onnxscript for ONNX export (required by PyTorch's new exporter in 2.x), onnxruntime-web for browser inference via WebAssembly.
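A minimal PyTorch sketch of the architecture in the diagram plus an export call, assuming PyTorch 2.5 or newer (where `torch.onnx.export` accepts `dynamo=True`) with onnxscript installed. The class layout and the `export_onnx` helper are illustrative; the original training script may be organised differently, and the Adam/cosine-annealing training loop is omitted.

```python
import torch
import torch.nn as nn

class MNISTNet(nn.Module):
    """Two conv+pool blocks followed by a dropout-regularised MLP head."""

    def __init__(self) -> None:
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),   # -> (32, 14, 14)
            nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),  # -> (64, 7, 7)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),                 # -> 3,136 features
            nn.Dropout(0.25),
            nn.Linear(64 * 7 * 7, 128), nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(128, 10),           # raw logits; softmax is applied at inference time
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.features(x))

def export_onnx(model: nn.Module, path: str = "mnist_classifier.onnx") -> None:
    model.eval()                        # export in inference mode so dropout is disabled
    dummy = torch.zeros(1, 1, 28, 28)   # one greyscale 28×28 image, NCHW
    # dynamo=True selects the torch.export-based exporter, which requires onnxscript.
    torch.onnx.export(model, (dummy,), path, dynamo=True)
```

Depending on the PyTorch version, the exported weights may land in a sidecar `.onnx.data` file, which is the situation covered under ONNX .data File.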

ONNX .data File

PyTorch 2.x's new ONNX exporter splits weights into a separate .onnx.data file by default. The training script consolidates them back into a single file using onnx.save_model(..., save_as_external_data=False).

WASM Paths

When onnxruntime-web is loaded as an ES module from a CDN, ort.env.wasm.wasmPaths must explicitly point back to the CDN location of the WASM files. Without this, the browser looks for the .wasm files next to the HTML page and fails silently.

Download

mnist_classifier.onnx

The trained CNN exported to ONNX format; this is the same file that runs in the browser above. Load it with onnxruntime in Python or any other ONNX-compatible runtime.

Format: ONNX, ~1.7 MB
