From 9e18eaa4792b82b3c573003961a886ab76ee90bd Mon Sep 17 00:00:00 2001 From: Rounak Date: Mon, 18 Mar 2024 19:18:41 -0700 Subject: [PATCH] Add MNIST Digit Prediction/Inference (#22) * Add Prediction to MNISTTrainer --- Applications/MNISTTrainer/ContentView.swift | 44 ++++-- .../MNISTTrainer/PredictionView.swift | 127 ++++++++++++++++++ CONTRIBUTING.md | 2 +- mlx-swift-examples.xcodeproj/project.pbxproj | 4 + 4 files changed, 168 insertions(+), 9 deletions(-) create mode 100644 Applications/MNISTTrainer/PredictionView.swift diff --git a/Applications/MNISTTrainer/ContentView.swift b/Applications/MNISTTrainer/ContentView.swift index 4747386..4827d80 100644 --- a/Applications/MNISTTrainer/ContentView.swift +++ b/Applications/MNISTTrainer/ContentView.swift @@ -7,10 +7,9 @@ import MLXRandom import MNIST import SwiftUI -struct ContentView: View { +struct TrainingView: View { - // the training loop - @State var trainer = Trainer() + @Binding var trainer: Trainer var body: some View { VStack { @@ -24,10 +23,16 @@ struct ContentView: View { HStack { Spacer() - - Button("Train") { - Task { - try! await trainer.run() + switch trainer.state { + case .untrained: + Button("Train") { + Task { + try! await trainer.run() + } + } + case .trained(let model), .predict(let model): + Button("Draw a digit") { + trainer.state = .predict(model) } } @@ -39,9 +44,30 @@ struct ContentView: View { } } +struct ContentView: View { + // the training loop + @State var trainer = Trainer() + + var body: some View { + switch trainer.state { + case .untrained, .trained: + TrainingView(trainer: $trainer) + case .predict(let model): + PredictionView(model: model) + } + } +} + @Observable class Trainer { + enum State { + case untrained + case trained(LeNet) + case predict(LeNet) + } + + var state: State = .untrained var messages = [String]() func run() async throws { @@ -101,6 +127,8 @@ class Trainer { ) } } - + await MainActor.run { + state = .trained(model) + } } } diff --git a/Applications/MNISTTrainer/PredictionView.swift b/Applications/MNISTTrainer/PredictionView.swift new file mode 100644 index 0000000..67906f2 --- /dev/null +++ b/Applications/MNISTTrainer/PredictionView.swift @@ -0,0 +1,127 @@ +// +// PredictionView.swift +// MNISTTrainer +// +// Created by Rounak Jain on 3/9/24. +// + +import MLX +import MLXNN +import MNIST +import SwiftUI + +struct Canvas: View { + + @Binding var path: Path + @State var lastPoint: CGPoint? + + var body: some View { + path + .stroke(.white, lineWidth: 10) + .background(.black) + .gesture( + DragGesture(minimumDistance: 0.05) + .onChanged { touch in + add(point: touch.location) + } + .onEnded { touch in + lastPoint = nil + } + ) + } + + func add(point: CGPoint) { + var newPath = path + if let lastPoint { + newPath.move(to: lastPoint) + newPath.addLine(to: point) + } else { + newPath.move(to: point) + } + self.path = newPath + lastPoint = point + } +} + +extension Path { + mutating func center(to newMidPoint: CGPoint) { + let middleX = boundingRect.midX + let middleY = boundingRect.midY + self = offsetBy(dx: newMidPoint.x - middleX, dy: newMidPoint.y - middleY) + } +} + +struct PredictionView: View { + @State var path: Path = Path() + @State var prediction: Int? + let model: LeNet + let canvasSize = 150.0 + let mnistImageSize: CGSize = CGSize(width: 28, height: 28) + + var body: some View { + VStack { + if let prediction { + Text("You've drawn a \(prediction)") + } else { + Text("Draw a digit") + } + Canvas(path: $path) + .frame(width: canvasSize, height: canvasSize) + HStack { + Button("Predict") { + path.center(to: CGPoint(x: canvasSize / 2, y: canvasSize / 2)) + predict() + } + Button("Clear") { + path = Path() + prediction = nil + } + } + } + } + + @MainActor + func predict() { + let imageRenderer = ImageRenderer( + content: Canvas(path: $path).frame(width: 150, height: 150)) + guard + let pixelData = imageRenderer.cgImage?.grayscaleImage(with: mnistImageSize)?.pixelData() + else { + return + } + // modify input vector to match training in MNIST/Files.swift + let x = pixelData.reshaped([1, 28, 28, 1]).asType(.float32) / 255.0 + prediction = argMax(model(x)).item() + } +} + +extension CGImage { + func grayscaleImage(with newSize: CGSize) -> CGImage? { + let colorSpace = CGColorSpaceCreateDeviceGray() + let bitmapInfo = CGBitmapInfo(rawValue: CGImageAlphaInfo.none.rawValue) + + guard + let context = CGContext( + data: nil, + width: Int(newSize.width), + height: Int(newSize.height), + bitsPerComponent: 8, + bytesPerRow: Int(newSize.width), + space: colorSpace, + bitmapInfo: bitmapInfo.rawValue) + else { + return nil + } + context.draw(self, in: CGRect(x: 0, y: 0, width: newSize.width, height: newSize.width)) + return context.makeImage() + } + + func pixelData() -> MLXArray { + guard let data = self.dataProvider?.data else { + return [] + } + let bytePtr = CFDataGetBytePtr(data) + let count = CFDataGetLength(data) + return MLXArray(UnsafeBufferPointer(start: bytePtr, count: count)) + } +} diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index ffc4f68..e1fbe67 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -14,7 +14,7 @@ possible. You can also run the formatters manually as follows: ``` - swift-format format --in-place --recursive Libraries Tools + swift-format format --in-place --recursive Libraries Tools Applications ``` or run `pre-commit run --all-files` to check all files in the repo. diff --git a/mlx-swift-examples.xcodeproj/project.pbxproj b/mlx-swift-examples.xcodeproj/project.pbxproj index 029f636..f29d014 100644 --- a/mlx-swift-examples.xcodeproj/project.pbxproj +++ b/mlx-swift-examples.xcodeproj/project.pbxproj @@ -7,6 +7,7 @@ objects = { /* Begin PBXBuildFile section */ + 12305EAF2B9D864400C92FEE /* PredictionView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 12305EAE2B9D864400C92FEE /* PredictionView.swift */; }; 525C1E9D2B9A011000B5C356 /* Starcoder2.swift in Sources */ = {isa = PBXBuildFile; fileRef = 525C1E9C2B9A010F00B5C356 /* Starcoder2.swift */; }; 52A776182B94B5EE00AA6E80 /* Qwen2.swift in Sources */ = {isa = PBXBuildFile; fileRef = 52A776172B94B5EE00AA6E80 /* Qwen2.swift */; }; 81695B412BA373D300F260D8 /* MarkdownUI in Frameworks */ = {isa = PBXBuildFile; productRef = 81695B402BA373D300F260D8 /* MarkdownUI */; }; @@ -183,6 +184,7 @@ /* End PBXCopyFilesBuildPhase section */ /* Begin PBXFileReference section */ + 12305EAE2B9D864400C92FEE /* PredictionView.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = PredictionView.swift; sourceTree = ""; }; 525C1E9C2B9A010F00B5C356 /* Starcoder2.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Starcoder2.swift; sourceTree = ""; }; 52A776172B94B5EE00AA6E80 /* Qwen2.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Qwen2.swift; sourceTree = ""; }; C325DE3F2B648CDB00628871 /* README.md */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = ""; }; @@ -451,6 +453,7 @@ children = ( C3A8B3C32B92951E0002EFB8 /* Assets.xcassets */, C3A8B3C92B92951E0002EFB8 /* ContentView.swift */, + 12305EAE2B9D864400C92FEE /* PredictionView.swift */, C3A8B3C22B92951E0002EFB8 /* MNISTTrainer-Info.plist */, C3A8B3C72B92951E0002EFB8 /* MNISTTrainer.entitlements */, C3A8B3C42B92951E0002EFB8 /* MNISTTrainerApp.swift */, @@ -866,6 +869,7 @@ isa = PBXSourcesBuildPhase; buildActionMask = 2147483647; files = ( + 12305EAF2B9D864400C92FEE /* PredictionView.swift in Sources */, C3A8B3CC2B92951E0002EFB8 /* MNISTTrainerApp.swift in Sources */, C3A8B3CF2B92951E0002EFB8 /* ContentView.swift in Sources */, );