LeNet on MNIST + readme update (#12)

* LeNet on MNIST + readme update

* tanh + remove device toggle

* remove device entirely
Awni Hannun
2024-03-04 14:16:20 -08:00
committed by GitHub
parent dfc9f2fc01
commit 4ed4ec69e7
8 changed files with 56 additions and 86 deletions

View File

@@ -12,9 +12,6 @@ struct ContentView: View {
// the training loop
@State var trainer = Trainer()
// toggle for cpu/gpu training
@State var cpu = true
var body: some View {
VStack {
Spacer()
@@ -30,13 +27,10 @@ struct ContentView: View {
Button("Train") {
Task {
try! await trainer.run(device: cpu ? .cpu : .gpu)
try! await trainer.run()
}
}
Toggle("CPU", isOn: $cpu)
.frame(maxWidth: 150)
Spacer()
}
Spacer()
@@ -50,12 +44,10 @@ class Trainer {
var messages = [String]()
func run(device: Device = .cpu) async throws {
func run() async throws {
// Note: this is pretty close to the code in `mnist-tool`, just
// wrapped in an Observable to make it easy to display in SwiftUI
Device.setDefault(device: device)
// download & load the training data
let url = URL(fileURLWithPath: NSTemporaryDirectory(), isDirectory: true)
try await download(into: url)
@@ -67,9 +59,7 @@ class Trainer {
let testLabels = data[.init(.test, .labels)]!
// create the model with random weights
let model = MLP(
layers: 2, inputDimensions: trainImages.dim(-1), hiddenDimensions: 32,
outputDimensions: 10)
let model = LeNet()
eval(model.parameters())
// the training loop
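With the device toggle gone, `run()` simply uses MLX's default device, which is the GPU on Apple silicon. As a hedged aside (not part of this commit), the `Device.setDefault` call removed above remains available if you ever want to pin training to a particular device:

```swift
import MLX

// Hypothetical override, shown only for illustration: force the trainer
// onto the CPU before calling `trainer.run()`.
Device.setDefault(device: .cpu)
```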

View File

@@ -1,13 +1,13 @@
# MNISTTrainer
This is an example showing how to do model training on both macOS and iOS.
This will download the MNIST training data, create a new models and train
it. It will show the timing per epoch and the test accuracy as it trains.
This is an example of model training that works on both macOS and iOS.
The example will download the MNIST training data, create a LeNet, and train
it. It will show the epoch time and test accuracy as it trains.
You will need to set the Team on the MNISTTrainer target in order to build and
run on iOS.
Some notes about the setup:
- this will download test data over the network so MNISTTrainer -> Signing & Capabilities has the "Outgoing Connections (Client)" set in the App Sandbox
- the website it connects to uses http rather than https so it has a "App Transport Security Settings" in the Info.plist
- This will download test data over the network, so the MNISTTrainer target has "Outgoing Connections (Client)" enabled under Signing & Capabilities -> App Sandbox
- The website it connects to uses http rather than https, so there is an "App Transport Security Settings" entry in the Info.plist

View File

@@ -43,13 +43,13 @@ let files = [
name: "train-images-idx3-ubyte.gz",
offset: 16,
convert: {
$0.reshaped([-1, 28 * 28]).asType(.float32) / 255.0
$0.reshaped([-1, 28, 28, 1]).asType(.float32) / 255.0
}),
FileKind(.test, .images): LoadInfo(
name: "t10k-images-idx3-ubyte.gz",
offset: 16,
convert: {
$0.reshaped([-1, 28 * 28]).asType(.float32) / 255.0
$0.reshaped([-1, 28, 28, 1]).asType(.float32) / 255.0
}),
FileKind(.training, .labels): LoadInfo(
name: "train-labels-idx1-ubyte.gz",

View File

@@ -6,36 +6,43 @@ import MLXNN
// based on https://github.com/ml-explore/mlx-examples/blob/main/mnist/main.py
public class MLP: Module, UnaryLayer {
public class LeNet: Module, UnaryLayer {
@ModuleInfo var layers: [Linear]
@ModuleInfo var conv1: Conv2d
@ModuleInfo var conv2: Conv2d
@ModuleInfo var pool1: MaxPool2d
@ModuleInfo var pool2: MaxPool2d
@ModuleInfo var fc1: Linear
@ModuleInfo var fc2: Linear
@ModuleInfo var fc3: Linear
public init(layers: Int, inputDimensions: Int, hiddenDimensions: Int, outputDimensions: Int) {
let layerSizes =
[inputDimensions] + Array(repeating: hiddenDimensions, count: layers) + [
outputDimensions
]
self.layers = zip(layerSizes.dropLast(), layerSizes.dropFirst())
.map {
Linear($0, $1)
}
override public init() {
conv1 = Conv2d(inputChannels: 1, outputChannels: 6, kernelSize: 5, padding: 2)
conv2 = Conv2d(inputChannels: 6, outputChannels: 16, kernelSize: 5, padding: 0)
pool1 = MaxPool2d(kernelSize: 2, stride: 2)
pool2 = MaxPool2d(kernelSize: 2, stride: 2)
fc1 = Linear(16 * 5 * 5, 120)
fc2 = Linear(120, 84)
fc3 = Linear(84, 10)
}
public func callAsFunction(_ x: MLXArray) -> MLXArray {
var x = x
for l in layers.dropLast() {
x = relu(l(x))
}
return layers.last!(x)
x = pool1(tanh(conv1(x)))
x = pool2(tanh(conv2(x)))
x = flattened(x, start: 1)
x = tanh(fc1(x))
x = tanh(fc2(x))
x = fc3(x)
return x
}
}
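For reference, the `16 * 5 * 5` input size of `fc1` falls out of the shapes: a 28x28x1 image stays 28x28 after `conv1` (5x5 kernel, padding 2) with 6 channels, is halved to 14x14 by `pool1`, becomes 10x10 with 16 channels after `conv2` (5x5 kernel, no padding), and is halved again to 5x5 by `pool2`, so flattening yields 16 * 5 * 5 = 400 features. A quick shape check (hypothetical, not part of the commit):

```swift
import MLX
import MLXNN

// Run a dummy NHWC batch through the LeNet defined above.
let model = LeNet()
let dummy = MLXArray.zeros([4, 28, 28, 1])
print(model(dummy).shape)  // expected: [4, 10]
```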
public func loss(model: MLP, x: MLXArray, y: MLXArray) -> MLXArray {
public func loss(model: LeNet, x: MLXArray, y: MLXArray) -> MLXArray {
crossEntropy(logits: model(x), targets: y, reduction: .mean)
}
public func eval(model: MLP, x: MLXArray, y: MLXArray) -> MLXArray {
public func eval(model: LeNet, x: MLXArray, y: MLXArray) -> MLXArray {
mean(argMax(model(x), axis: 1) .== y)
}
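As a usage note (a sketch, not code from this commit): `loss` is the function that gets differentiated during training, while `eval` reports accuracy by comparing the argmax of the logits against the labels. Both can be called directly on a batch, assuming `x` and `y` are image/label arrays shaped as above:

```swift
// Hypothetical spot check of an untrained model on one batch.
let model = LeNet()
let l = loss(model: model, x: x, y: y)    // mean cross entropy, a scalar
let a = eval(model: model, x: x, y: y)    // fraction of correct predictions
print(l.item(Float.self), a.item(Float.self))
```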

View File

@@ -1,13 +1,11 @@
# MNIST
This is a port of the MNIST model and training code from:
- https://github.com/ml-explore/mlx-examples/blob/main/mnist
This is a port of the MNIST training code from the [Python MLX example](https://github.com/ml-explore/mlx-examples/blob/main/mnist). This example uses a [LeNet](https://en.wikipedia.org/wiki/LeNet) instead of an MLP.
It provides code to:
- download the test/train data
- provides the MNIST model (MLP)
- some functions to shuffle and batch the data
- Download the MNIST test/train data
- Build the LeNet
- Some functions to shuffle and batch the data
See [mnist-tool](../../Tools/mnist-tool) for an example of how to run this. The training loop also lives there.

View File

@@ -1,37 +1,20 @@
# MLX Swift Examples
Example [mlx-swift](https://github.com/ml-explore/mlx-swift) programs.
Example [MLX Swift](https://github.com/ml-explore/mlx-swift) programs.
## MNISTTrainer
- [MNISTTrainer](Applications/MNISTTrainer/README.md): An example that runs on
both iOS and macOS that downloads MNIST training data and trains a
[LeNet](https://en.wikipedia.org/wiki/LeNet).
An example that runs on both iOS and macOS that downloads MNIST training
data and trains an MNIST model.
- [LLMEval](Applications/LLMEval/README.md): An example that runs on both iOS
and macOS that downloads an LLM and tokenizer from Hugging Face and
generates text from a given prompt.
- [README](Applications/MNISTTrainer/README.md)
- [LinearModelTraining](Tools/LinearModelTraining/README.md): An example that
trains a simple linear model.
## LLMEval
An example that runs on both iOS and macOS that downloads a LLM model
weights and tokenizer configuration from Hugging Face and evaluates
a prompt in-process.
- [README](Applications/LLMEval/README.md)
## LinearModelTraining
A simple linear model and a training loop.
- [README](Tools/LinearModelTraining/README.md)
## llm-tool
A command line tool for generating text using a variety of Hugging Face models:
- [README](Tools/llm-tool/README.md)
## mnist-tool
A command line tool for training an MNIST (MLP) model:
- [README](Tools/mnist-tool/README.md)
- [llm-tool](Tools/llm-tool/README.md): A command line tool for generating text
using a variety of LLMs available on the Hugging Face hub.
- [mnist-tool](Tools/mnist-tool/README.md): A command line tool for training a
LeNet on MNIST.

View File

@@ -30,15 +30,11 @@ struct Train: AsyncParsableCommand {
@Option(name: .long, help: "The PRNG seed")
var seed: UInt64 = 0
@Option var layers = 2
@Option var hidden = 32
@Option var batchSize = 256
@Option var epochs = 20
@Option var learningRate: Float = 1e-1
@Option var classes = 10
@Option var device = DeviceType.cpu
@Option var device = DeviceType.gpu
@Flag var compile = false
@@ -62,9 +58,7 @@ struct Train: AsyncParsableCommand {
let testLabels = data[.init(.test, .labels)]!
// create the model
let model = MLP(
layers: layers, inputDimensions: trainImages.dim(-1), hiddenDimensions: hidden,
outputDimensions: classes)
let model = LeNet()
eval(model.parameters())
let lg = valueAndGrad(model: model, loss)
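`lg` pairs the loss value with its gradients for a batch; the rest of the loop (outside this hunk) hands those gradients to an optimizer. A minimal sketch of one training step, assuming `SGD` from `MLXOptimizers` and that the closure returned by `valueAndGrad` yields a `(loss, gradients)` pair:

```swift
import MLX
import MLXNN
import MLXOptimizers

// Hypothetical single step; the real tool also shuffles and batches the
// data, and honors the `compile` flag above.
let optimizer = SGD(learningRate: learningRate)

func step(_ x: MLXArray, _ y: MLXArray) -> MLXArray {
    let (lossValue, grads) = lg(model, x, y)          // assumed return shape
    optimizer.update(model: model, gradients: grads)  // apply the SGD update
    return lossValue
}
```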

View File

@@ -1,8 +1,6 @@
# mnist-tool
See other README:
- [MNIST](../../Libraries/MNIST/README.md)
See the [MNIST README.md](../../Libraries/MNIST/README.md).
### Building