initial commit
This commit is contained in:
63
.circleci/config.yml
Normal file
63
.circleci/config.yml
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
version: 2.1
|
||||||
|
|
||||||
|
orbs:
|
||||||
|
apple: ml-explore/pr-approval@0.1.0
|
||||||
|
|
||||||
|
parameters:
|
||||||
|
nightly_build:
|
||||||
|
type: boolean
|
||||||
|
default: false
|
||||||
|
weekly_build:
|
||||||
|
type: boolean
|
||||||
|
default: false
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
|
||||||
|
mac_build_and_test:
|
||||||
|
macos:
|
||||||
|
xcode: 15.2.0
|
||||||
|
resource_class: macos.m1.medium.gen1
|
||||||
|
steps:
|
||||||
|
- checkout
|
||||||
|
- run: git submodule sync
|
||||||
|
- run: git submodule update --init
|
||||||
|
- run:
|
||||||
|
name: Run style checks
|
||||||
|
command: |
|
||||||
|
pip install pre-commit
|
||||||
|
brew install swift-format
|
||||||
|
pre-commit run --all
|
||||||
|
if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi
|
||||||
|
- run:
|
||||||
|
name: Build Examples
|
||||||
|
command: |
|
||||||
|
xcodebuild -version
|
||||||
|
xcrun --show-sdk-build-version
|
||||||
|
swift --version
|
||||||
|
xcodebuild -scheme llm-tool
|
||||||
|
xcodebuild -scheme mnist-tool
|
||||||
|
|
||||||
|
workflows:
|
||||||
|
build_and_test:
|
||||||
|
when:
|
||||||
|
and:
|
||||||
|
- matches:
|
||||||
|
pattern: "^(?!pull/)[-\\w]+$"
|
||||||
|
value: << pipeline.git.branch >>
|
||||||
|
- not: << pipeline.parameters.nightly_build >>
|
||||||
|
- not: << pipeline.parameters.weekly_build >>
|
||||||
|
jobs:
|
||||||
|
- mac_build_and_test
|
||||||
|
|
||||||
|
prb:
|
||||||
|
when:
|
||||||
|
matches:
|
||||||
|
pattern: "^pull/\\d+(/head)?$"
|
||||||
|
value: << pipeline.git.branch >>
|
||||||
|
jobs:
|
||||||
|
- hold:
|
||||||
|
type: approval
|
||||||
|
- apple/authenticate:
|
||||||
|
context: pr-approval
|
||||||
|
- mac_build_and_test:
|
||||||
|
requires: [ hold ]
|
||||||
90
.gitignore
vendored
Normal file
90
.gitignore
vendored
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
# Xcode
|
||||||
|
#
|
||||||
|
# gitignore contributors: remember to update Global/Xcode.gitignore, Objective-C.gitignore & Swift.gitignore
|
||||||
|
|
||||||
|
## User settings
|
||||||
|
xcuserdata/
|
||||||
|
|
||||||
|
## compatibility with Xcode 8 and earlier (ignoring not required starting Xcode 9)
|
||||||
|
*.xcscmblueprint
|
||||||
|
*.xccheckout
|
||||||
|
|
||||||
|
## compatibility with Xcode 3 and earlier (ignoring not required starting Xcode 4)
|
||||||
|
build/
|
||||||
|
DerivedData/
|
||||||
|
*.moved-aside
|
||||||
|
*.pbxuser
|
||||||
|
!default.pbxuser
|
||||||
|
*.mode1v3
|
||||||
|
!default.mode1v3
|
||||||
|
*.mode2v3
|
||||||
|
!default.mode2v3
|
||||||
|
*.perspectivev3
|
||||||
|
!default.perspectivev3
|
||||||
|
|
||||||
|
## Obj-C/Swift specific
|
||||||
|
*.hmap
|
||||||
|
|
||||||
|
## App packaging
|
||||||
|
*.ipa
|
||||||
|
*.dSYM.zip
|
||||||
|
*.dSYM
|
||||||
|
|
||||||
|
## Playgrounds
|
||||||
|
timeline.xctimeline
|
||||||
|
playground.xcworkspace
|
||||||
|
|
||||||
|
# Swift Package Manager
|
||||||
|
#
|
||||||
|
# Add this line if you want to avoid checking in source code from Swift Package Manager dependencies.
|
||||||
|
# Packages/
|
||||||
|
# Package.pins
|
||||||
|
# Package.resolved
|
||||||
|
# *.xcodeproj
|
||||||
|
#
|
||||||
|
# Xcode automatically generates this directory with a .xcworkspacedata file and xcuserdata
|
||||||
|
# hence it is not needed unless you have added a package configuration file to your project
|
||||||
|
# .swiftpm
|
||||||
|
|
||||||
|
.build/
|
||||||
|
|
||||||
|
# CocoaPods
|
||||||
|
#
|
||||||
|
# We recommend against adding the Pods directory to your .gitignore. However
|
||||||
|
# you should judge for yourself, the pros and cons are mentioned at:
|
||||||
|
# https://guides.cocoapods.org/using/using-cocoapods.html#should-i-check-the-pods-directory-into-source-control
|
||||||
|
#
|
||||||
|
# Pods/
|
||||||
|
#
|
||||||
|
# Add this line if you want to avoid checking in source code from the Xcode workspace
|
||||||
|
# *.xcworkspace
|
||||||
|
|
||||||
|
# Carthage
|
||||||
|
#
|
||||||
|
# Add this line if you want to avoid checking in source code from Carthage dependencies.
|
||||||
|
# Carthage/Checkouts
|
||||||
|
|
||||||
|
Carthage/Build/
|
||||||
|
|
||||||
|
# Accio dependency management
|
||||||
|
Dependencies/
|
||||||
|
.accio/
|
||||||
|
|
||||||
|
# fastlane
|
||||||
|
#
|
||||||
|
# It is recommended to not store the screenshots in the git repo.
|
||||||
|
# Instead, use fastlane to re-generate the screenshots whenever they are needed.
|
||||||
|
# For more information about the recommended setup visit:
|
||||||
|
# https://docs.fastlane.tools/best-practices/source-control/#source-control
|
||||||
|
|
||||||
|
fastlane/report.xml
|
||||||
|
fastlane/Preview.html
|
||||||
|
fastlane/screenshots/**/*.png
|
||||||
|
fastlane/test_output
|
||||||
|
|
||||||
|
# Code Injection
|
||||||
|
#
|
||||||
|
# After new code Injection tools there's a generated folder /iOSInjectionProject
|
||||||
|
# https://github.com/johnno1962/injectionforxcode
|
||||||
|
|
||||||
|
iOSInjectionProject/
|
||||||
6
.pre-commit-config.yaml
Normal file
6
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
repos:
|
||||||
|
- repo: https://github.com/slessans/pre-commit-swift-format
|
||||||
|
rev: ""
|
||||||
|
hooks:
|
||||||
|
- id: swift-format
|
||||||
|
args: ["--configuration", ".swift-format"]
|
||||||
7
.swift-format
Normal file
7
.swift-format
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
{
|
||||||
|
"version": 1,
|
||||||
|
"indentation": {
|
||||||
|
"spaces": 4
|
||||||
|
},
|
||||||
|
"spacesAroundRangeFormationOperators": true,
|
||||||
|
}
|
||||||
77
Libraries/LLM/Configuration.swift
Normal file
77
Libraries/LLM/Configuration.swift
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
|
||||||
|
public enum StringOrNumber: Codable, Equatable {
|
||||||
|
case string(String)
|
||||||
|
case float(Float)
|
||||||
|
|
||||||
|
public init(from decoder: Decoder) throws {
|
||||||
|
let values = try decoder.singleValueContainer()
|
||||||
|
|
||||||
|
if let v = try? values.decode(Float.self) {
|
||||||
|
self = .float(v)
|
||||||
|
} else {
|
||||||
|
let v = try values.decode(String.self)
|
||||||
|
self = .string(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public func encode(to encoder: Encoder) throws {
|
||||||
|
var container = encoder.singleValueContainer()
|
||||||
|
switch self {
|
||||||
|
case .string(let v): try container.encode(v)
|
||||||
|
case .float(let v): try container.encode(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public enum ModelType: String, Codable {
|
||||||
|
case mistral
|
||||||
|
case llama
|
||||||
|
case phi
|
||||||
|
case gemma
|
||||||
|
|
||||||
|
func createModel(configuration: URL) throws -> LLMModel {
|
||||||
|
switch self {
|
||||||
|
case .mistral, .llama:
|
||||||
|
let configuration = try JSONDecoder().decode(
|
||||||
|
LlamaConfiguration.self, from: Data(contentsOf: configuration))
|
||||||
|
return LlamaModel(configuration)
|
||||||
|
case .phi:
|
||||||
|
let configuration = try JSONDecoder().decode(
|
||||||
|
PhiConfiguration.self, from: Data(contentsOf: configuration))
|
||||||
|
return PhiModel(configuration)
|
||||||
|
case .gemma:
|
||||||
|
let configuration = try JSONDecoder().decode(
|
||||||
|
GemmaConfiguration.self, from: Data(contentsOf: configuration))
|
||||||
|
return GemmaModel(configuration)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public struct BaseConfiguration: Codable {
|
||||||
|
let modelType: ModelType
|
||||||
|
|
||||||
|
public struct Quantization: Codable {
|
||||||
|
public init(groupSize: Int, bits: Int) {
|
||||||
|
self.groupSize = groupSize
|
||||||
|
self.bits = bits
|
||||||
|
}
|
||||||
|
|
||||||
|
let groupSize: Int
|
||||||
|
let bits: Int
|
||||||
|
|
||||||
|
enum CodingKeys: String, CodingKey {
|
||||||
|
case groupSize = "group_size"
|
||||||
|
case bits = "bits"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var quantization: Quantization?
|
||||||
|
|
||||||
|
enum CodingKeys: String, CodingKey {
|
||||||
|
case modelType = "model_type"
|
||||||
|
case quantization
|
||||||
|
}
|
||||||
|
}
|
||||||
273
Libraries/LLM/Gemma.swift
Normal file
273
Libraries/LLM/Gemma.swift
Normal file
@@ -0,0 +1,273 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
import MLX
|
||||||
|
import MLXNN
|
||||||
|
|
||||||
|
// port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/gemma.py
|
||||||
|
|
||||||
|
// specialized norm for gemma
|
||||||
|
private class RMSNorm: Module, UnaryLayer {
|
||||||
|
|
||||||
|
let weight: MLXArray
|
||||||
|
let eps: Float
|
||||||
|
|
||||||
|
public init(dimensions: Int, eps: Float = 1e-5) {
|
||||||
|
self.weight = MLXArray.ones([dimensions])
|
||||||
|
self.eps = eps
|
||||||
|
super.init()
|
||||||
|
}
|
||||||
|
|
||||||
|
func norm(_ x: MLXArray) -> MLXArray {
|
||||||
|
let S = 1.0 / sqrt(Float(x.dim(-1)))
|
||||||
|
|
||||||
|
let n = (x * S).square().sum(axis: -1, keepDims: true)
|
||||||
|
return rsqrt(n + eps)
|
||||||
|
}
|
||||||
|
|
||||||
|
public func callAsFunction(_ x: MLXArray) -> MLXArray {
|
||||||
|
let output = norm(x.asType(Float.self)).asType(x.dtype)
|
||||||
|
return (1 + weight) * output
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private class Attention: Module {
|
||||||
|
|
||||||
|
let args: GemmaConfiguration
|
||||||
|
let repeats: Int
|
||||||
|
let scale: Float
|
||||||
|
|
||||||
|
@ModuleInfo(key: "q_proj") var wq: Linear
|
||||||
|
@ModuleInfo(key: "k_proj") var wk: Linear
|
||||||
|
@ModuleInfo(key: "v_proj") var wv: Linear
|
||||||
|
@ModuleInfo(key: "o_proj") var wo: Linear
|
||||||
|
|
||||||
|
let rope: RoPE
|
||||||
|
|
||||||
|
public init(_ args: GemmaConfiguration) {
|
||||||
|
self.args = args
|
||||||
|
|
||||||
|
let dim = args.hiddenSize
|
||||||
|
let heads = args.attentionHeads
|
||||||
|
let kvHeads = args.kvHeads
|
||||||
|
|
||||||
|
self.repeats = heads / kvHeads
|
||||||
|
|
||||||
|
let headDim = args.headDimensions
|
||||||
|
self.scale = pow(Float(headDim), -0.5)
|
||||||
|
|
||||||
|
self._wq.wrappedValue = Linear(dim, heads * headDim, bias: false)
|
||||||
|
self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: false)
|
||||||
|
self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: false)
|
||||||
|
self._wo.wrappedValue = Linear(heads * headDim, dim, bias: false)
|
||||||
|
|
||||||
|
self.rope = RoPE(
|
||||||
|
dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta)
|
||||||
|
}
|
||||||
|
|
||||||
|
public func callAsFunction(
|
||||||
|
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
|
||||||
|
) -> (MLXArray, (MLXArray, MLXArray)) {
|
||||||
|
let (B, L) = (x.dim(0), x.dim(1))
|
||||||
|
|
||||||
|
var queries = wq(x)
|
||||||
|
var keys = wk(x)
|
||||||
|
var values = wv(x)
|
||||||
|
|
||||||
|
// prepare the queries, keys and values for the attention computation
|
||||||
|
queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3)
|
||||||
|
keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
|
||||||
|
values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
|
||||||
|
|
||||||
|
if repeats > 1 {
|
||||||
|
keys = MLXArray.repeat(keys, count: repeats, axis: 1)
|
||||||
|
values = MLXArray.repeat(values, count: repeats, axis: 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if let (keyCache, valueCache) = cache {
|
||||||
|
queries = rope(queries, offset: keyCache.dim(2))
|
||||||
|
keys = rope(keys, offset: keyCache.dim(2))
|
||||||
|
keys = concatenated([keyCache, keys], axis: 2)
|
||||||
|
values = concatenated([valueCache, values], axis: 2)
|
||||||
|
} else {
|
||||||
|
queries = rope(queries)
|
||||||
|
keys = rope(keys)
|
||||||
|
}
|
||||||
|
|
||||||
|
var scores = (queries * self.scale).matmul(keys.transposed(0, 1, 3, 2))
|
||||||
|
if let mask {
|
||||||
|
scores = scores + mask
|
||||||
|
}
|
||||||
|
|
||||||
|
scores = softMax(scores.asType(.float32), axis: -1).asType(scores.dtype)
|
||||||
|
|
||||||
|
let output = matmul(scores, values).transposed(0, 2, 1, 3).reshaped(B, L, -1)
|
||||||
|
|
||||||
|
return (wo(output), (keys, values))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private class MLP: Module, UnaryLayer {
|
||||||
|
|
||||||
|
@ModuleInfo(key: "gate_proj") var gate: Linear
|
||||||
|
@ModuleInfo(key: "down_proj") var down: Linear
|
||||||
|
@ModuleInfo(key: "up_proj") var up: Linear
|
||||||
|
|
||||||
|
public init(dimensions: Int, hiddenDimensions: Int) {
|
||||||
|
self._gate.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
|
||||||
|
self._down.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false)
|
||||||
|
self._up.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
|
||||||
|
}
|
||||||
|
|
||||||
|
public func callAsFunction(_ x: MLXArray) -> MLXArray {
|
||||||
|
down(gelu(gate(x)) * up(x))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private class TransformerBlock: Module {
|
||||||
|
|
||||||
|
@ModuleInfo(key: "self_attn") var attention: Attention
|
||||||
|
let mlp: MLP
|
||||||
|
|
||||||
|
@ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
|
||||||
|
@ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm
|
||||||
|
|
||||||
|
public init(_ args: GemmaConfiguration) {
|
||||||
|
self._attention.wrappedValue = Attention(args)
|
||||||
|
self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize)
|
||||||
|
self._inputLayerNorm.wrappedValue = RMSNorm(
|
||||||
|
dimensions: args.hiddenSize, eps: args.rmsNormEps)
|
||||||
|
self._postAttentionLayerNorm.wrappedValue = RMSNorm(
|
||||||
|
dimensions: args.hiddenSize, eps: args.rmsNormEps)
|
||||||
|
}
|
||||||
|
|
||||||
|
public func callAsFunction(
|
||||||
|
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
|
||||||
|
) -> (MLXArray, (MLXArray, MLXArray)) {
|
||||||
|
var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache)
|
||||||
|
let h = x + r
|
||||||
|
r = mlp(postAttentionLayerNorm(h))
|
||||||
|
let out = h + r
|
||||||
|
return (out, cache)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public class GemmaModelInner: Module {
|
||||||
|
|
||||||
|
@ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
|
||||||
|
|
||||||
|
fileprivate let layers: [TransformerBlock]
|
||||||
|
fileprivate let norm: RMSNorm
|
||||||
|
|
||||||
|
let hiddenScale: Float
|
||||||
|
|
||||||
|
public init(_ args: GemmaConfiguration) {
|
||||||
|
precondition(args.vocabularySize > 0)
|
||||||
|
|
||||||
|
self._embedTokens.wrappedValue = Embedding(
|
||||||
|
embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)
|
||||||
|
|
||||||
|
self.hiddenScale = pow(Float(args.hiddenSize), 0.5)
|
||||||
|
|
||||||
|
self.layers = (0 ..< args.hiddenLayers)
|
||||||
|
.map { _ in
|
||||||
|
TransformerBlock(args)
|
||||||
|
}
|
||||||
|
self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
|
||||||
|
}
|
||||||
|
|
||||||
|
public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
|
||||||
|
MLXArray, [(MLXArray, MLXArray)]
|
||||||
|
) {
|
||||||
|
var h = embedTokens(inputs)
|
||||||
|
h = h * hiddenScale
|
||||||
|
|
||||||
|
var mask: MLXArray? = nil
|
||||||
|
if h.dim(1) > 1 {
|
||||||
|
mask = MultiHeadAttention.createAdditiveCausalMask(h.dim(1))
|
||||||
|
mask = mask?.asType(h.dtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
var newCache = [(MLXArray, MLXArray)]()
|
||||||
|
|
||||||
|
for (i, layer) in layers.enumerated() {
|
||||||
|
var cacheUpdate: (MLXArray, MLXArray)
|
||||||
|
(h, cacheUpdate) = layer(h, mask: mask, cache: cache?[i])
|
||||||
|
newCache.append(cacheUpdate)
|
||||||
|
}
|
||||||
|
|
||||||
|
return (norm(h), newCache)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public class GemmaModel: Module, LLMModel {
|
||||||
|
|
||||||
|
let model: GemmaModelInner
|
||||||
|
|
||||||
|
public init(_ args: GemmaConfiguration) {
|
||||||
|
self.model = GemmaModelInner(args)
|
||||||
|
}
|
||||||
|
|
||||||
|
public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
|
||||||
|
MLXArray, [(MLXArray, MLXArray)]
|
||||||
|
) {
|
||||||
|
var (out, cache) = model(inputs, cache: cache)
|
||||||
|
out = matmul(out, model.embedTokens.weight.T)
|
||||||
|
return (out, cache)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public struct GemmaConfiguration: Codable {
|
||||||
|
|
||||||
|
var hiddenSize: Int
|
||||||
|
var hiddenLayers: Int
|
||||||
|
var intermediateSize: Int
|
||||||
|
var attentionHeads: Int
|
||||||
|
var headDimensions: Int
|
||||||
|
var rmsNormEps: Float
|
||||||
|
var vocabularySize: Int
|
||||||
|
var kvHeads: Int
|
||||||
|
var ropeTheta: Float = 10_000
|
||||||
|
var ropeTraditional: Bool = false
|
||||||
|
|
||||||
|
enum CodingKeys: String, CodingKey {
|
||||||
|
case hiddenSize = "hidden_size"
|
||||||
|
case hiddenLayers = "num_hidden_layers"
|
||||||
|
case intermediateSize = "intermediate_size"
|
||||||
|
case attentionHeads = "num_attention_heads"
|
||||||
|
case headDimensions = "head_dim"
|
||||||
|
case rmsNormEps = "rms_norm_eps"
|
||||||
|
case vocabularySize = "vocab_size"
|
||||||
|
case kvHeads = "num_key_value_heads"
|
||||||
|
case ropeTheta = "rope_theta"
|
||||||
|
case ropeTraditional = "rope_traditional"
|
||||||
|
}
|
||||||
|
|
||||||
|
public init(from decoder: Decoder) throws {
|
||||||
|
// custom implementation to handle optional keys with required values
|
||||||
|
let container: KeyedDecodingContainer<CodingKeys> = try decoder.container(
|
||||||
|
keyedBy: CodingKeys.self)
|
||||||
|
|
||||||
|
self.hiddenSize = try container.decode(
|
||||||
|
Int.self, forKey: CodingKeys.hiddenSize)
|
||||||
|
self.hiddenLayers = try container.decode(
|
||||||
|
Int.self, forKey: CodingKeys.hiddenLayers)
|
||||||
|
self.intermediateSize = try container.decode(
|
||||||
|
Int.self, forKey: CodingKeys.intermediateSize)
|
||||||
|
self.attentionHeads = try container.decode(
|
||||||
|
Int.self, forKey: CodingKeys.attentionHeads)
|
||||||
|
self.headDimensions = try container.decode(
|
||||||
|
Int.self, forKey: CodingKeys.headDimensions)
|
||||||
|
self.rmsNormEps = try container.decode(
|
||||||
|
Float.self, forKey: CodingKeys.rmsNormEps)
|
||||||
|
self.vocabularySize = try container.decode(
|
||||||
|
Int.self, forKey: CodingKeys.vocabularySize)
|
||||||
|
self.kvHeads = try container.decode(Int.self, forKey: CodingKeys.kvHeads)
|
||||||
|
self.ropeTheta =
|
||||||
|
try container.decodeIfPresent(Float.self, forKey: CodingKeys.ropeTheta)
|
||||||
|
?? 10_000
|
||||||
|
self.ropeTraditional =
|
||||||
|
try container.decodeIfPresent(
|
||||||
|
Bool.self, forKey: CodingKeys.ropeTraditional) ?? false
|
||||||
|
}
|
||||||
|
}
|
||||||
1
Libraries/LLM/LLM.h
Normal file
1
Libraries/LLM/LLM.h
Normal file
@@ -0,0 +1 @@
|
|||||||
|
|
||||||
12
Libraries/LLM/LLMModel.swift
Normal file
12
Libraries/LLM/LLMModel.swift
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
import MLX
|
||||||
|
import MLXNN
|
||||||
|
|
||||||
|
// Interface for all LLM Models
|
||||||
|
public protocol LLMModel: Module {
|
||||||
|
func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
|
||||||
|
MLXArray, [(MLXArray, MLXArray)]
|
||||||
|
)
|
||||||
|
}
|
||||||
263
Libraries/LLM/Llama.swift
Normal file
263
Libraries/LLM/Llama.swift
Normal file
@@ -0,0 +1,263 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
import MLX
|
||||||
|
import MLXNN
|
||||||
|
|
||||||
|
// port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/llama.py
|
||||||
|
|
||||||
|
private class Attention: Module {
|
||||||
|
|
||||||
|
let args: LlamaConfiguration
|
||||||
|
let repeats: Int
|
||||||
|
let scale: Float
|
||||||
|
|
||||||
|
@ModuleInfo(key: "q_proj") var wq: Linear
|
||||||
|
@ModuleInfo(key: "k_proj") var wk: Linear
|
||||||
|
@ModuleInfo(key: "v_proj") var wv: Linear
|
||||||
|
@ModuleInfo(key: "o_proj") var wo: Linear
|
||||||
|
|
||||||
|
let rope: RoPE
|
||||||
|
|
||||||
|
public init(_ args: LlamaConfiguration) {
|
||||||
|
self.args = args
|
||||||
|
|
||||||
|
let dim = args.hiddenSize
|
||||||
|
let heads = args.attentionHeads
|
||||||
|
let kvHeads = args.kvHeads
|
||||||
|
|
||||||
|
self.repeats = heads / kvHeads
|
||||||
|
|
||||||
|
let headDim = args.hiddenSize / heads
|
||||||
|
self.scale = pow(Float(headDim), -0.5)
|
||||||
|
|
||||||
|
self._wq.wrappedValue = Linear(dim, heads * headDim, bias: false)
|
||||||
|
self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: false)
|
||||||
|
self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: false)
|
||||||
|
self._wo.wrappedValue = Linear(heads * headDim, dim, bias: false)
|
||||||
|
|
||||||
|
let ropeScale: Float
|
||||||
|
if let ropeScaling = args.ropeScaling, ropeScaling["type"] == .string("linear"),
|
||||||
|
let factor = ropeScaling["factor"]
|
||||||
|
{
|
||||||
|
switch factor {
|
||||||
|
case .string:
|
||||||
|
fatalError("ropeScaling.factor must be a float")
|
||||||
|
case .float(let v):
|
||||||
|
ropeScale = 1 / v
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ropeScale = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
self.rope = RoPE(
|
||||||
|
dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta,
|
||||||
|
scale: ropeScale)
|
||||||
|
}
|
||||||
|
|
||||||
|
public func callAsFunction(
|
||||||
|
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
|
||||||
|
) -> (MLXArray, (MLXArray, MLXArray)) {
|
||||||
|
let (B, L) = (x.dim(0), x.dim(1))
|
||||||
|
|
||||||
|
var queries = wq(x)
|
||||||
|
var keys = wk(x)
|
||||||
|
var values = wv(x)
|
||||||
|
|
||||||
|
// prepare the queries, keys and values for the attention computation
|
||||||
|
queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3)
|
||||||
|
keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
|
||||||
|
values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
|
||||||
|
|
||||||
|
if repeats > 1 {
|
||||||
|
keys = MLXArray.repeat(keys, count: repeats, axis: 1)
|
||||||
|
values = MLXArray.repeat(values, count: repeats, axis: 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if let (keyCache, valueCache) = cache {
|
||||||
|
queries = rope(queries, offset: keyCache.dim(2))
|
||||||
|
keys = rope(keys, offset: keyCache.dim(2))
|
||||||
|
keys = concatenated([keyCache, keys], axis: 2)
|
||||||
|
values = concatenated([valueCache, values], axis: 2)
|
||||||
|
} else {
|
||||||
|
queries = rope(queries)
|
||||||
|
keys = rope(keys)
|
||||||
|
}
|
||||||
|
|
||||||
|
var scores = (queries * self.scale).matmul(keys.transposed(0, 1, 3, 2))
|
||||||
|
if let mask {
|
||||||
|
scores = scores + mask
|
||||||
|
}
|
||||||
|
|
||||||
|
scores = softMax(scores.asType(.float32), axis: -1).asType(scores.dtype)
|
||||||
|
|
||||||
|
let output = matmul(scores, values).transposed(0, 2, 1, 3).reshaped(B, L, -1)
|
||||||
|
|
||||||
|
return (wo(output), (keys, values))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private class MLP: Module, UnaryLayer {
|
||||||
|
|
||||||
|
@ModuleInfo(key: "gate_proj") var gate: Linear
|
||||||
|
@ModuleInfo(key: "down_proj") var down: Linear
|
||||||
|
@ModuleInfo(key: "up_proj") var up: Linear
|
||||||
|
|
||||||
|
public init(dimensions: Int, hiddenDimensions: Int) {
|
||||||
|
self._gate.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
|
||||||
|
self._down.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false)
|
||||||
|
self._up.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
|
||||||
|
}
|
||||||
|
|
||||||
|
public func callAsFunction(_ x: MLXArray) -> MLXArray {
|
||||||
|
down(silu(gate(x)) * up(x))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private class TransformerBlock: Module {
|
||||||
|
|
||||||
|
@ModuleInfo(key: "self_attn") var attention: Attention
|
||||||
|
let mlp: MLP
|
||||||
|
|
||||||
|
@ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
|
||||||
|
@ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm
|
||||||
|
|
||||||
|
public init(_ args: LlamaConfiguration) {
|
||||||
|
self._attention.wrappedValue = Attention(args)
|
||||||
|
self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize)
|
||||||
|
self._inputLayerNorm.wrappedValue = RMSNorm(
|
||||||
|
dimensions: args.hiddenSize, eps: args.rmsNormEps)
|
||||||
|
self._postAttentionLayerNorm.wrappedValue = RMSNorm(
|
||||||
|
dimensions: args.hiddenSize, eps: args.rmsNormEps)
|
||||||
|
}
|
||||||
|
|
||||||
|
public func callAsFunction(
|
||||||
|
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
|
||||||
|
) -> (MLXArray, (MLXArray, MLXArray)) {
|
||||||
|
var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache)
|
||||||
|
let h = x + r
|
||||||
|
r = mlp(postAttentionLayerNorm(h))
|
||||||
|
let out = h + r
|
||||||
|
return (out, cache)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public class LlamaModelInner: Module {
|
||||||
|
|
||||||
|
@ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
|
||||||
|
|
||||||
|
fileprivate let layers: [TransformerBlock]
|
||||||
|
let norm: RMSNorm
|
||||||
|
|
||||||
|
public init(_ args: LlamaConfiguration) {
|
||||||
|
precondition(args.vocabularySize > 0)
|
||||||
|
|
||||||
|
self._embedTokens.wrappedValue = Embedding(
|
||||||
|
embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)
|
||||||
|
|
||||||
|
self.layers = (0 ..< args.hiddenLayers)
|
||||||
|
.map { _ in
|
||||||
|
TransformerBlock(args)
|
||||||
|
}
|
||||||
|
self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
|
||||||
|
}
|
||||||
|
|
||||||
|
public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
|
||||||
|
MLXArray, [(MLXArray, MLXArray)]
|
||||||
|
) {
|
||||||
|
var h = embedTokens(inputs)
|
||||||
|
|
||||||
|
var mask: MLXArray? = nil
|
||||||
|
if h.dim(1) > 1 {
|
||||||
|
mask = MultiHeadAttention.createAdditiveCausalMask(h.dim(1))
|
||||||
|
mask = mask?.asType(h.dtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
var newCache = [(MLXArray, MLXArray)]()
|
||||||
|
|
||||||
|
for (i, layer) in layers.enumerated() {
|
||||||
|
var cacheUpdate: (MLXArray, MLXArray)
|
||||||
|
(h, cacheUpdate) = layer(h, mask: mask, cache: cache?[i])
|
||||||
|
newCache.append(cacheUpdate)
|
||||||
|
}
|
||||||
|
|
||||||
|
return (norm(h), newCache)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public class LlamaModel: Module, LLMModel {
|
||||||
|
|
||||||
|
let model: LlamaModelInner
|
||||||
|
|
||||||
|
@ModuleInfo(key: "lm_head") var lmHead: Linear
|
||||||
|
|
||||||
|
public init(_ args: LlamaConfiguration) {
|
||||||
|
self.model = LlamaModelInner(args)
|
||||||
|
self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
|
||||||
|
}
|
||||||
|
|
||||||
|
public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
|
||||||
|
MLXArray, [(MLXArray, MLXArray)]
|
||||||
|
) {
|
||||||
|
let (out, cache) = model(inputs, cache: cache)
|
||||||
|
return (lmHead(out), cache)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public struct LlamaConfiguration: Codable {
|
||||||
|
|
||||||
|
var hiddenSize: Int
|
||||||
|
var hiddenLayers: Int
|
||||||
|
var intermediateSize: Int
|
||||||
|
var attentionHeads: Int
|
||||||
|
var rmsNormEps: Float
|
||||||
|
var vocabularySize: Int
|
||||||
|
var kvHeads: Int
|
||||||
|
var ropeTheta: Float = 10_000
|
||||||
|
var ropeTraditional: Bool = false
|
||||||
|
var ropeScaling: [String: StringOrNumber]? = nil
|
||||||
|
|
||||||
|
enum CodingKeys: String, CodingKey {
|
||||||
|
case hiddenSize = "hidden_size"
|
||||||
|
case hiddenLayers = "num_hidden_layers"
|
||||||
|
case intermediateSize = "intermediate_size"
|
||||||
|
case attentionHeads = "num_attention_heads"
|
||||||
|
case rmsNormEps = "rms_norm_eps"
|
||||||
|
case vocabularySize = "vocab_size"
|
||||||
|
case kvHeads = "num_key_value_heads"
|
||||||
|
case ropeTheta = "rope_theta"
|
||||||
|
case ropeTraditional = "rope_traditional"
|
||||||
|
case ropeScaling = "rope_scaling"
|
||||||
|
}
|
||||||
|
|
||||||
|
public init(from decoder: Decoder) throws {
|
||||||
|
// custom implementation to handle optional keys with required values
|
||||||
|
let container: KeyedDecodingContainer<LlamaConfiguration.CodingKeys> =
|
||||||
|
try decoder.container(
|
||||||
|
keyedBy: LlamaConfiguration.CodingKeys.self)
|
||||||
|
|
||||||
|
self.hiddenSize = try container.decode(
|
||||||
|
Int.self, forKey: LlamaConfiguration.CodingKeys.hiddenSize)
|
||||||
|
self.hiddenLayers = try container.decode(
|
||||||
|
Int.self, forKey: LlamaConfiguration.CodingKeys.hiddenLayers)
|
||||||
|
self.intermediateSize = try container.decode(
|
||||||
|
Int.self, forKey: LlamaConfiguration.CodingKeys.intermediateSize)
|
||||||
|
self.attentionHeads = try container.decode(
|
||||||
|
Int.self, forKey: LlamaConfiguration.CodingKeys.attentionHeads)
|
||||||
|
self.rmsNormEps = try container.decode(
|
||||||
|
Float.self, forKey: LlamaConfiguration.CodingKeys.rmsNormEps)
|
||||||
|
self.vocabularySize = try container.decode(
|
||||||
|
Int.self, forKey: LlamaConfiguration.CodingKeys.vocabularySize)
|
||||||
|
self.kvHeads = try container.decode(Int.self, forKey: LlamaConfiguration.CodingKeys.kvHeads)
|
||||||
|
self.ropeTheta =
|
||||||
|
try container.decodeIfPresent(
|
||||||
|
Float.self, forKey: LlamaConfiguration.CodingKeys.ropeTheta)
|
||||||
|
?? 10_000
|
||||||
|
self.ropeTraditional =
|
||||||
|
try container.decodeIfPresent(
|
||||||
|
Bool.self, forKey: LlamaConfiguration.CodingKeys.ropeTraditional) ?? false
|
||||||
|
self.ropeScaling = try container.decodeIfPresent(
|
||||||
|
[String: StringOrNumber].self, forKey: LlamaConfiguration.CodingKeys.ropeScaling)
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
302
Libraries/LLM/Phi.swift
Normal file
302
Libraries/LLM/Phi.swift
Normal file
@@ -0,0 +1,302 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
import MLX
|
||||||
|
import MLXNN
|
||||||
|
|
||||||
|
// https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/phi.py
|
||||||
|
|
||||||
|
// TODO: remove once open classes are in
|
||||||
|
|
||||||
|
public class MLXLayerNorm: Module, UnaryLayer {
|
||||||
|
|
||||||
|
let dimensions: Int
|
||||||
|
let eps: Float
|
||||||
|
|
||||||
|
let weight: MLXArray?
|
||||||
|
let bias: MLXArray?
|
||||||
|
|
||||||
|
/// Applies layer normalization [1] on the inputs.
|
||||||
|
///
|
||||||
|
/// See [LayerNorm python docs](https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary/mlx.nn.LayerNorm.html) for more information.
|
||||||
|
///
|
||||||
|
/// ### References
|
||||||
|
/// 1. [https://arxiv.org/abs/1607.06450](https://arxiv.org/abs/1607.06450)
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - dimensions: number of features in the input
|
||||||
|
/// - eps: value added to the denominator for numerical stability
|
||||||
|
/// - affine: if `true` adds a trainable `weight` and `bias`
|
||||||
|
public init(dimensions: Int, eps: Float = 1e-5, affine: Bool = true) {
|
||||||
|
self.dimensions = dimensions
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
if affine {
|
||||||
|
self.weight = MLXArray.ones([dimensions])
|
||||||
|
self.bias = MLXArray.zeros([dimensions])
|
||||||
|
} else {
|
||||||
|
self.weight = nil
|
||||||
|
self.bias = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public func callAsFunction(_ x: MLXArray) -> MLXArray {
|
||||||
|
let means = mean(x, axis: -1, keepDims: true)
|
||||||
|
let variance = variance(x, axis: -1, keepDims: true)
|
||||||
|
let x = (x - means) * rsqrt(variance + eps)
|
||||||
|
|
||||||
|
if let weight, let bias {
|
||||||
|
return weight * x + bias
|
||||||
|
} else {
|
||||||
|
return x
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private class LayerNorm: MLXLayerNorm {
|
||||||
|
override func callAsFunction(_ x: MLXArray) -> MLXArray {
|
||||||
|
super.callAsFunction(x.asType(Float.self)).asType(x.dtype)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private class PhiAttention: Module {
|
||||||
|
|
||||||
|
let args: PhiConfiguration
|
||||||
|
let heads: Int
|
||||||
|
let headDim: Int
|
||||||
|
let repeats: Int
|
||||||
|
|
||||||
|
@ModuleInfo(key: "q_proj") var wq: Linear
|
||||||
|
@ModuleInfo(key: "k_proj") var wk: Linear
|
||||||
|
@ModuleInfo(key: "v_proj") var wv: Linear
|
||||||
|
@ModuleInfo(key: "dense") var dense: Linear
|
||||||
|
|
||||||
|
let rope: RoPE
|
||||||
|
|
||||||
|
public init(_ args: PhiConfiguration) {
|
||||||
|
self.args = args
|
||||||
|
|
||||||
|
let hiddenSize = args.hiddenSize
|
||||||
|
self.heads = args.attentionHeads
|
||||||
|
self.headDim = args.hiddenSize / heads
|
||||||
|
let kvHeads = args.kvHeads
|
||||||
|
self.repeats = heads / kvHeads
|
||||||
|
|
||||||
|
if headDim * heads != hiddenSize {
|
||||||
|
fatalError("hidden_size must be divisible by num_heads")
|
||||||
|
}
|
||||||
|
|
||||||
|
self._wq.wrappedValue = Linear(hiddenSize, heads * headDim, bias: true)
|
||||||
|
self._wk.wrappedValue = Linear(hiddenSize, kvHeads * headDim, bias: true)
|
||||||
|
self._wv.wrappedValue = Linear(hiddenSize, kvHeads * headDim, bias: true)
|
||||||
|
self._dense.wrappedValue = Linear(heads * headDim, hiddenSize, bias: true)
|
||||||
|
|
||||||
|
self.rope = RoPE(
|
||||||
|
dimensions: Int(args.partialRotaryFactor * Float(headDim)), traditional: false,
|
||||||
|
base: args.ropeTheta)
|
||||||
|
}
|
||||||
|
|
||||||
|
public func callAsFunction(
|
||||||
|
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
|
||||||
|
) -> (MLXArray, (MLXArray, MLXArray)) {
|
||||||
|
let (B, L) = (x.dim(0), x.dim(1))
|
||||||
|
|
||||||
|
var queries = wq(x)
|
||||||
|
var keys = wk(x)
|
||||||
|
var values = wv(x)
|
||||||
|
|
||||||
|
// prepare the queries, keys and values for the attention computation
|
||||||
|
queries = queries.reshaped(B, L, heads, headDim).transposed(0, 2, 1, 3)
|
||||||
|
keys = keys.reshaped(B, L, args.kvHeads, headDim).transposed(0, 2, 1, 3)
|
||||||
|
values = values.reshaped(B, L, args.kvHeads, headDim).transposed(0, 2, 1, 3)
|
||||||
|
|
||||||
|
if repeats > 1 {
|
||||||
|
keys = MLXArray.repeat(keys, count: repeats, axis: 1)
|
||||||
|
values = MLXArray.repeat(values, count: repeats, axis: 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add RoPE to the queries and keys and combine them with the cache
|
||||||
|
if let (keyCache, valueCache) = cache {
|
||||||
|
queries = rope(queries, offset: keyCache.dim(2))
|
||||||
|
keys = rope(keys, offset: keyCache.dim(2))
|
||||||
|
keys = concatenated([keyCache, keys], axis: 2)
|
||||||
|
values = concatenated([valueCache, values], axis: 2)
|
||||||
|
} else {
|
||||||
|
queries = rope(queries)
|
||||||
|
keys = rope(keys)
|
||||||
|
}
|
||||||
|
|
||||||
|
queries = queries.asType(Float.self)
|
||||||
|
keys = keys.asType(Float.self)
|
||||||
|
|
||||||
|
// Finally perform the attention computation
|
||||||
|
let scale = sqrt(1 / Float(queries.dim(-1)))
|
||||||
|
var scores = (queries * scale).matmul(keys.transposed(0, 1, 3, 2))
|
||||||
|
if let mask {
|
||||||
|
scores = scores + mask
|
||||||
|
}
|
||||||
|
|
||||||
|
scores = softMax(scores, axis: -1).asType(values.dtype)
|
||||||
|
let valuesHat = matmul(scores, values).transposed(0, 2, 1, 3).reshaped(B, L, -1)
|
||||||
|
|
||||||
|
return (dense(valuesHat), (keys, values))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private class PhiMLP: Module, UnaryLayer {
|
||||||
|
|
||||||
|
@ModuleInfo var fc1: Linear
|
||||||
|
@ModuleInfo var fc2: Linear
|
||||||
|
@ModuleInfo var act: GELU
|
||||||
|
|
||||||
|
public init(_ config: PhiConfiguration) {
|
||||||
|
self.fc1 = Linear(config.hiddenSize, config.intermediateSize)
|
||||||
|
self.fc2 = Linear(config.intermediateSize, config.hiddenSize)
|
||||||
|
self.act = GELU(approximation: .precise)
|
||||||
|
}
|
||||||
|
|
||||||
|
public func callAsFunction(_ x: MLXArray) -> MLXArray {
|
||||||
|
fc2(act(fc1(x)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private class PhiDecoderLayer: Module {
|
||||||
|
|
||||||
|
@ModuleInfo(key: "self_attn") var selfAttention: PhiAttention
|
||||||
|
@ModuleInfo(key: "input_layernorm") var inputLayerNorm: LayerNorm
|
||||||
|
var mlp: PhiMLP
|
||||||
|
|
||||||
|
public init(_ config: PhiConfiguration) {
|
||||||
|
self._selfAttention.wrappedValue = PhiAttention(config)
|
||||||
|
self._inputLayerNorm.wrappedValue = LayerNorm(
|
||||||
|
dimensions: config.hiddenSize, eps: config.layerNormEps)
|
||||||
|
self.mlp = PhiMLP(config)
|
||||||
|
}
|
||||||
|
|
||||||
|
public func callAsFunction(
|
||||||
|
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
|
||||||
|
) -> (MLXArray, (MLXArray, MLXArray)) {
|
||||||
|
let h = inputLayerNorm(x)
|
||||||
|
let (attentionH, cache) = selfAttention(h, mask: mask, cache: cache)
|
||||||
|
let ffH = mlp(h)
|
||||||
|
return (attentionH + ffH + x, cache)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private class PhiModelInner: Module {
|
||||||
|
|
||||||
|
@ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
|
||||||
|
|
||||||
|
@ModuleInfo var layers: [PhiDecoderLayer]
|
||||||
|
@ModuleInfo(key: "final_layernorm") var finalLayerNorm: LayerNorm
|
||||||
|
|
||||||
|
public init(_ args: PhiConfiguration) {
|
||||||
|
self._embedTokens.wrappedValue = Embedding(
|
||||||
|
embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)
|
||||||
|
|
||||||
|
self.layers = (0 ..< args.hiddenLayers)
|
||||||
|
.map { _ in
|
||||||
|
PhiDecoderLayer(args)
|
||||||
|
}
|
||||||
|
self._finalLayerNorm.wrappedValue = LayerNorm(
|
||||||
|
dimensions: args.hiddenSize, eps: args.layerNormEps)
|
||||||
|
}
|
||||||
|
|
||||||
|
public func callAsFunction(
|
||||||
|
_ x: MLXArray, mask: MLXArray? = nil, cache: [(MLXArray, MLXArray)]? = nil
|
||||||
|
) -> (
|
||||||
|
MLXArray, [(MLXArray, MLXArray)]
|
||||||
|
) {
|
||||||
|
var x = embedTokens(x)
|
||||||
|
|
||||||
|
var newCache = [(MLXArray, MLXArray)]()
|
||||||
|
|
||||||
|
for (i, layer) in layers.enumerated() {
|
||||||
|
var cacheUpdate: (MLXArray, MLXArray)
|
||||||
|
(x, cacheUpdate) = layer(x, mask: mask, cache: cache?[i])
|
||||||
|
newCache.append(cacheUpdate)
|
||||||
|
}
|
||||||
|
|
||||||
|
return (finalLayerNorm(x), newCache)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public class PhiModel: Module, LLMModel {
|
||||||
|
|
||||||
|
fileprivate let model: PhiModelInner
|
||||||
|
|
||||||
|
@ModuleInfo(key: "lm_head") var lmHead: Linear
|
||||||
|
|
||||||
|
public init(_ args: PhiConfiguration) {
|
||||||
|
self.model = PhiModelInner(args)
|
||||||
|
self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: true)
|
||||||
|
}
|
||||||
|
|
||||||
|
public func callAsFunction(_ x: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
|
||||||
|
MLXArray, [(MLXArray, MLXArray)]
|
||||||
|
) {
|
||||||
|
var mask: MLXArray? = nil
|
||||||
|
if x.dim(1) > 1 {
|
||||||
|
mask = MultiHeadAttention.createAdditiveCausalMask(x.dim(1))
|
||||||
|
mask = mask?.asType(x.dtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
let (y, cache) = model(x, mask: mask, cache: cache)
|
||||||
|
return (lmHead(y), cache)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public struct PhiConfiguration: Codable {
|
||||||
|
var maxPositionalEmbeddings = 2048
|
||||||
|
var vocabularySize = 51200
|
||||||
|
var hiddenSize = 2560
|
||||||
|
var attentionHeads = 32
|
||||||
|
var hiddenLayers = 32
|
||||||
|
var kvHeads = 32
|
||||||
|
var partialRotaryFactor: Float = 0.4
|
||||||
|
var intermediateSize = 10240
|
||||||
|
var layerNormEps: Float = 1e-5
|
||||||
|
var ropeTheta: Float = 10_000
|
||||||
|
|
||||||
|
enum CodingKeys: String, CodingKey {
|
||||||
|
case maxPositionalEmbeddings = "max_position_embeddings"
|
||||||
|
case vocabularySize = "vocab_size"
|
||||||
|
case hiddenSize = "hidden_size"
|
||||||
|
case attentionHeads = "num_attention_heads"
|
||||||
|
case hiddenLayers = "num_hidden_layers"
|
||||||
|
case kvHeads = "num_key_value_heads"
|
||||||
|
case partialRotaryFactor = "partial_rotary_factor"
|
||||||
|
case intermediateSize = "intermediate_size"
|
||||||
|
case layerNormEps = "layer_norm_eps"
|
||||||
|
case ropeTheta = "rope_theta"
|
||||||
|
}
|
||||||
|
|
||||||
|
public init(from decoder: Decoder) throws {
|
||||||
|
let container: KeyedDecodingContainer<PhiConfiguration.CodingKeys> = try decoder.container(
|
||||||
|
keyedBy: PhiConfiguration.CodingKeys.self)
|
||||||
|
|
||||||
|
self.maxPositionalEmbeddings = try container.decode(
|
||||||
|
Int.self, forKey: PhiConfiguration.CodingKeys.maxPositionalEmbeddings)
|
||||||
|
self.vocabularySize = try container.decode(
|
||||||
|
Int.self, forKey: PhiConfiguration.CodingKeys.vocabularySize)
|
||||||
|
self.hiddenSize = try container.decode(
|
||||||
|
Int.self, forKey: PhiConfiguration.CodingKeys.hiddenSize)
|
||||||
|
self.attentionHeads = try container.decode(
|
||||||
|
Int.self, forKey: PhiConfiguration.CodingKeys.attentionHeads)
|
||||||
|
self.hiddenLayers = try container.decode(
|
||||||
|
Int.self, forKey: PhiConfiguration.CodingKeys.hiddenLayers)
|
||||||
|
self.kvHeads =
|
||||||
|
try container.decodeIfPresent(Int.self, forKey: PhiConfiguration.CodingKeys.kvHeads)
|
||||||
|
?? attentionHeads
|
||||||
|
self.partialRotaryFactor = try container.decode(
|
||||||
|
Float.self, forKey: PhiConfiguration.CodingKeys.partialRotaryFactor)
|
||||||
|
self.intermediateSize = try container.decode(
|
||||||
|
Int.self, forKey: PhiConfiguration.CodingKeys.intermediateSize)
|
||||||
|
self.layerNormEps = try container.decode(
|
||||||
|
Float.self, forKey: PhiConfiguration.CodingKeys.layerNormEps)
|
||||||
|
self.ropeTheta =
|
||||||
|
try container.decodeIfPresent(Float.self, forKey: PhiConfiguration.CodingKeys.ropeTheta)
|
||||||
|
?? 10_000
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
11
Libraries/LLM/README.md
Normal file
11
Libraries/LLM/README.md
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
# Llama
|
||||||
|
|
||||||
|
This is a port of the llama model from:
|
||||||
|
|
||||||
|
- https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/llama.py
|
||||||
|
|
||||||
|
You can use this to load models from huggingface, e.g.:
|
||||||
|
|
||||||
|
- https://huggingface.co/mlx-community/Mistral-7B-v0.1-hf-4bit-mlx
|
||||||
|
|
||||||
|
See [llm-tool](../../Tools/llm-tool)
|
||||||
110
Libraries/LLM/Util.swift
Normal file
110
Libraries/LLM/Util.swift
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
import AsyncAlgorithms
|
||||||
|
import Foundation
|
||||||
|
import Hub
|
||||||
|
import MLX
|
||||||
|
import MLXNN
|
||||||
|
import MLXRandom
|
||||||
|
import Tokenizers
|
||||||
|
|
||||||
|
/// Load and return the model and tokenizer
|
||||||
|
public func load(
|
||||||
|
hub: HubApi = HubApi(), name: String, progressHandler: @escaping (Progress) -> Void = { _ in }
|
||||||
|
) async throws -> (LLMModel, Tokenizer) {
|
||||||
|
// note: this doesn't have a way to pass the HubApi
|
||||||
|
let tokenizer = try await AutoTokenizer.from(pretrained: name)
|
||||||
|
|
||||||
|
// download the model weights and config
|
||||||
|
let repo = Hub.Repo(id: name)
|
||||||
|
let modelFiles = ["config.json", "weights.00.safetensors"]
|
||||||
|
let modelDirectory = try await hub.snapshot(
|
||||||
|
from: repo, matching: modelFiles, progressHandler: progressHandler)
|
||||||
|
|
||||||
|
// create the model (no weights loaded)
|
||||||
|
let configurationURL = modelDirectory.appending(component: "config.json")
|
||||||
|
let baseConfig = try JSONDecoder().decode(
|
||||||
|
BaseConfiguration.self, from: Data(contentsOf: configurationURL))
|
||||||
|
|
||||||
|
let model = try baseConfig.modelType.createModel(configuration: configurationURL)
|
||||||
|
|
||||||
|
// set up the model
|
||||||
|
if let quantization = baseConfig.quantization {
|
||||||
|
QuantizedLinear.quantize(
|
||||||
|
model: model, groupSize: quantization.groupSize, bits: quantization.bits)
|
||||||
|
}
|
||||||
|
|
||||||
|
// apply the loaded weights
|
||||||
|
let weights = try loadArrays(url: modelDirectory.appending(component: "weights.00.safetensors"))
|
||||||
|
let parameters = ModuleParameters.unflattened(weights)
|
||||||
|
try model.update(parameters: parameters, verify: [.all])
|
||||||
|
eval(model.parameters())
|
||||||
|
|
||||||
|
return (model, tokenizer)
|
||||||
|
}
|
||||||
|
|
||||||
|
private func sample(logits: MLXArray, temp: Float) -> MLXArray {
|
||||||
|
if temp == 0 {
|
||||||
|
return argMax(logits, axis: -1)
|
||||||
|
} else {
|
||||||
|
return categorical(logits * (1 / temp))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Synchronous generator of tokens.
|
||||||
|
///
|
||||||
|
/// Port of `generate_step()` from https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/utils.py
|
||||||
|
public struct TokenIterator: Sequence, IteratorProtocol {
|
||||||
|
let model: LLMModel
|
||||||
|
let temp: Float
|
||||||
|
|
||||||
|
var y: MLXArray
|
||||||
|
var cache: [(MLXArray, MLXArray)]
|
||||||
|
|
||||||
|
var first = true
|
||||||
|
|
||||||
|
public init(prompt: MLXArray, model: LLMModel, temp: Float = 0.0) {
|
||||||
|
self.model = model
|
||||||
|
self.temp = temp
|
||||||
|
self.y = prompt
|
||||||
|
self.cache = []
|
||||||
|
}
|
||||||
|
|
||||||
|
mutating public func next() -> MLXArray? {
|
||||||
|
var logits: MLXArray
|
||||||
|
(logits, cache) = model(expandedDimensions(y, axis: 0), cache: cache.isEmpty ? nil : cache)
|
||||||
|
y = sample(logits: logits[-1, axis: 1], temp: temp)
|
||||||
|
|
||||||
|
return y
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Async generator of tokens.
|
||||||
|
///
|
||||||
|
/// Port of `generate_step()` from https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/utils.py.
|
||||||
|
///
|
||||||
|
/// Note that because MLXArray is not thread safe this eval's the result and sends the TokenId back
|
||||||
|
/// to the caller.
|
||||||
|
public func generate(prompt: MLXArray, model: LLMModel, temp: Float = 0.0) -> (
|
||||||
|
Task<Void, Never>, AsyncBufferSequence<AsyncChannel<Int>>
|
||||||
|
) {
|
||||||
|
let channel = AsyncChannel<Int>()
|
||||||
|
let buffer = channel.buffer(policy: .bounded(10))
|
||||||
|
|
||||||
|
let task = Task {
|
||||||
|
var y = prompt
|
||||||
|
var cache = [(MLXArray, MLXArray)]()
|
||||||
|
|
||||||
|
while !Task.isCancelled {
|
||||||
|
var logits: MLXArray
|
||||||
|
(logits, cache) = model(
|
||||||
|
expandedDimensions(y, axis: 0), cache: cache.isEmpty ? nil : cache)
|
||||||
|
y = sample(logits: logits[-1, axis: 1], temp: temp)
|
||||||
|
eval(y)
|
||||||
|
|
||||||
|
await channel.send(y.item(Int.self))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return (task, buffer)
|
||||||
|
}
|
||||||
102
Libraries/MNIST/Files.swift
Normal file
102
Libraries/MNIST/Files.swift
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
import Gzip
|
||||||
|
import MLX
|
||||||
|
|
||||||
|
// based on https://github.com/ml-explore/mlx-examples/blob/main/mnist/mnist.py
|
||||||
|
|
||||||
|
public enum Use: String, Hashable {
|
||||||
|
case test
|
||||||
|
case training
|
||||||
|
}
|
||||||
|
|
||||||
|
public enum DataKind: String, Hashable {
|
||||||
|
case images
|
||||||
|
case labels
|
||||||
|
}
|
||||||
|
|
||||||
|
public struct FileKind: Hashable, CustomStringConvertible {
|
||||||
|
let use: Use
|
||||||
|
let data: DataKind
|
||||||
|
|
||||||
|
public init(_ use: Use, _ data: DataKind) {
|
||||||
|
self.use = use
|
||||||
|
self.data = data
|
||||||
|
}
|
||||||
|
|
||||||
|
public var description: String {
|
||||||
|
"\(use.rawValue)-\(data.rawValue)"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct LoadInfo {
|
||||||
|
let name: String
|
||||||
|
let offset: Int
|
||||||
|
let convert: (MLXArray) -> MLXArray
|
||||||
|
}
|
||||||
|
|
||||||
|
let baseURL = URL(string: "http://yann.lecun.com/exdb/mnist/")!
|
||||||
|
|
||||||
|
let files = [
|
||||||
|
FileKind(.training, .images): LoadInfo(
|
||||||
|
name: "train-images-idx3-ubyte.gz",
|
||||||
|
offset: 16,
|
||||||
|
convert: {
|
||||||
|
$0.reshaped([-1, 28 * 28]).asType(.float32) / 255.0
|
||||||
|
}),
|
||||||
|
FileKind(.test, .images): LoadInfo(
|
||||||
|
name: "t10k-images-idx3-ubyte.gz",
|
||||||
|
offset: 16,
|
||||||
|
convert: {
|
||||||
|
$0.reshaped([-1, 28 * 28]).asType(.float32) / 255.0
|
||||||
|
}),
|
||||||
|
FileKind(.training, .labels): LoadInfo(
|
||||||
|
name: "train-labels-idx1-ubyte.gz",
|
||||||
|
offset: 8,
|
||||||
|
convert: {
|
||||||
|
$0.asType(.uint32)
|
||||||
|
}),
|
||||||
|
FileKind(.test, .labels): LoadInfo(
|
||||||
|
name: "t10k-labels-idx1-ubyte.gz",
|
||||||
|
offset: 8,
|
||||||
|
convert: {
|
||||||
|
$0.asType(.uint32)
|
||||||
|
}),
|
||||||
|
]
|
||||||
|
|
||||||
|
public func download(into: URL) async throws {
|
||||||
|
for (_, info) in files {
|
||||||
|
let fileURL = into.appending(component: info.name)
|
||||||
|
if !FileManager.default.fileExists(atPath: fileURL.path()) {
|
||||||
|
print("Download: \(info.name)")
|
||||||
|
let url = baseURL.appending(component: info.name)
|
||||||
|
let (data, response) = try await URLSession.shared.data(from: url)
|
||||||
|
|
||||||
|
guard let httpResponse = response as? HTTPURLResponse else {
|
||||||
|
fatalError("Unable to download \(url), not an http response: \(response)")
|
||||||
|
}
|
||||||
|
guard httpResponse.statusCode == 200 else {
|
||||||
|
fatalError("Unable to download \(url): \(httpResponse)")
|
||||||
|
}
|
||||||
|
|
||||||
|
try data.write(to: fileURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public func load(from: URL) throws -> [FileKind: MLXArray] {
|
||||||
|
var result = [FileKind: MLXArray]()
|
||||||
|
|
||||||
|
for (key, info) in files {
|
||||||
|
let fileURL = from.appending(component: info.name)
|
||||||
|
let data = try Data(contentsOf: fileURL).gunzipped()
|
||||||
|
|
||||||
|
let array = MLXArray(
|
||||||
|
data.dropFirst(info.offset), [data.count - info.offset], type: UInt8.self)
|
||||||
|
|
||||||
|
result[key] = info.convert(array)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
1
Libraries/MNIST/MNIST.h
Normal file
1
Libraries/MNIST/MNIST.h
Normal file
@@ -0,0 +1 @@
|
|||||||
|
|
||||||
73
Libraries/MNIST/MNIST.swift
Normal file
73
Libraries/MNIST/MNIST.swift
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
import MLX
|
||||||
|
import MLXNN
|
||||||
|
|
||||||
|
// based on https://github.com/ml-explore/mlx-examples/blob/main/mnist/main.py
|
||||||
|
|
||||||
|
public class MLP: Module, UnaryLayer {
|
||||||
|
|
||||||
|
@ModuleInfo var layers: [Linear]
|
||||||
|
|
||||||
|
public init(layers: Int, inputDimensions: Int, hiddenDimensions: Int, outputDimensions: Int) {
|
||||||
|
let layerSizes =
|
||||||
|
[inputDimensions] + Array(repeating: hiddenDimensions, count: layers) + [
|
||||||
|
outputDimensions
|
||||||
|
]
|
||||||
|
|
||||||
|
self.layers = zip(layerSizes.dropLast(), layerSizes.dropFirst())
|
||||||
|
.map {
|
||||||
|
Linear($0, $1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public func callAsFunction(_ x: MLXArray) -> MLXArray {
|
||||||
|
var x = x
|
||||||
|
for l in layers.dropLast() {
|
||||||
|
x = relu(l(x))
|
||||||
|
}
|
||||||
|
return layers.last!(x)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public func loss(model: MLP, x: MLXArray, y: MLXArray) -> MLXArray {
|
||||||
|
crossEntropy(logits: model(x), targets: y, reduction: .mean)
|
||||||
|
}
|
||||||
|
|
||||||
|
public func eval(model: MLP, x: MLXArray, y: MLXArray) -> MLXArray {
|
||||||
|
mean(argMax(model(x), axis: 1) .== y)
|
||||||
|
}
|
||||||
|
|
||||||
|
private struct BatchSequence: Sequence, IteratorProtocol {
|
||||||
|
|
||||||
|
let batchSize: Int
|
||||||
|
let x: MLXArray
|
||||||
|
let y: MLXArray
|
||||||
|
|
||||||
|
let indexes: MLXArray
|
||||||
|
var index = 0
|
||||||
|
|
||||||
|
init(batchSize: Int, x: MLXArray, y: MLXArray, using generator: inout any RandomNumberGenerator)
|
||||||
|
{
|
||||||
|
self.batchSize = batchSize
|
||||||
|
self.x = x
|
||||||
|
self.y = y
|
||||||
|
self.indexes = MLXArray(Array(0 ..< y.size).shuffled(using: &generator))
|
||||||
|
}
|
||||||
|
|
||||||
|
mutating func next() -> (MLXArray, MLXArray)? {
|
||||||
|
guard index < y.size else { return nil }
|
||||||
|
|
||||||
|
let range = index ..< Swift.min(index + batchSize, y.size)
|
||||||
|
index += batchSize
|
||||||
|
let ids = indexes[range]
|
||||||
|
return (x[ids], y[ids])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public func iterateBatches(
|
||||||
|
batchSize: Int, x: MLXArray, y: MLXArray, using generator: inout any RandomNumberGenerator
|
||||||
|
) -> some Sequence<(MLXArray, MLXArray)> {
|
||||||
|
BatchSequence(batchSize: batchSize, x: x, y: y, using: &generator)
|
||||||
|
}
|
||||||
13
Libraries/MNIST/README.md
Normal file
13
Libraries/MNIST/README.md
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# MNIST
|
||||||
|
|
||||||
|
This is a port of the MNIST model and training code from:
|
||||||
|
|
||||||
|
- https://github.com/ml-explore/mlx-examples/blob/main/mnist
|
||||||
|
|
||||||
|
It provides code to:
|
||||||
|
|
||||||
|
- download the test/train data
|
||||||
|
- provides the MNIST model (MLP)
|
||||||
|
- some functions to shuffle and batch the data
|
||||||
|
|
||||||
|
See [mnist-tool](../../Tools/mnist-tool) for an example of how to run this. The training loop also lives there.
|
||||||
30
Libraries/MNIST/Random.swift
Normal file
30
Libraries/MNIST/Random.swift
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
|
||||||
|
// From https://github.com/apple/swift/blob/cb0fb1ea051631219c0b944b84c78571448d58c2/benchmark/utils/TestsUtils.swift#L254
|
||||||
|
//
|
||||||
|
// This is just a seedable RandomNumberGenerator for shuffle()
|
||||||
|
|
||||||
|
// This is a fixed-increment version of Java 8's SplittableRandom generator.
|
||||||
|
// It is a very fast generator passing BigCrush, with 64 bits of state.
|
||||||
|
// See http://dx.doi.org/10.1145/2714064.2660195 and
|
||||||
|
// http://docs.oracle.com/javase/8/docs/api/java/util/SplittableRandom.html
|
||||||
|
//
|
||||||
|
// Derived from public domain C implementation by Sebastiano Vigna
|
||||||
|
// See http://xoshiro.di.unimi.it/splitmix64.c
|
||||||
|
public struct SplitMix64: RandomNumberGenerator {
|
||||||
|
private var state: UInt64
|
||||||
|
|
||||||
|
public init(seed: UInt64) {
|
||||||
|
self.state = seed
|
||||||
|
}
|
||||||
|
|
||||||
|
public mutating func next() -> UInt64 {
|
||||||
|
self.state &+= 0x9e37_79b9_7f4a_7c15
|
||||||
|
var z: UInt64 = self.state
|
||||||
|
z = (z ^ (z &>> 30)) &* 0xbf58_476d_1ce4_e5b9
|
||||||
|
z = (z ^ (z &>> 27)) &* 0x94d0_49bb_1331_11eb
|
||||||
|
return z ^ (z &>> 31)
|
||||||
|
}
|
||||||
|
}
|
||||||
22
README.md
Normal file
22
README.md
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
# mlx-examples-swift
|
||||||
|
|
||||||
|
Example mlx-swift programs.
|
||||||
|
|
||||||
|
## LinearModelTraining
|
||||||
|
|
||||||
|
A simple linear model and a training loop.
|
||||||
|
|
||||||
|
- [README](Tools/LinearModelTraining/README.md)
|
||||||
|
|
||||||
|
## llm-tool
|
||||||
|
|
||||||
|
A command line tool for generating text using a Llama / Mistral model:
|
||||||
|
|
||||||
|
- [README](Tools/llm-tool/README.md)
|
||||||
|
|
||||||
|
## mnist-tool
|
||||||
|
|
||||||
|
A command line tool for training an MNIST (MLP) model:
|
||||||
|
|
||||||
|
- [README](Tools/mnist-tool/README.md)
|
||||||
|
|
||||||
113
Tools/LinearModelTraining/LinearModelTraining.swift
Normal file
113
Tools/LinearModelTraining/LinearModelTraining.swift
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
import ArgumentParser
|
||||||
|
import Foundation
|
||||||
|
import MLX
|
||||||
|
import MLXNN
|
||||||
|
import MLXOptimizers
|
||||||
|
import MLXRandom
|
||||||
|
|
||||||
|
extension MLX.DeviceType: ExpressibleByArgument {
|
||||||
|
public init?(argument: String) {
|
||||||
|
self.init(rawValue: argument)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@main
|
||||||
|
struct Train: AsyncParsableCommand {
|
||||||
|
|
||||||
|
@Option var epochs = 20
|
||||||
|
@Option var batchSize = 8
|
||||||
|
|
||||||
|
@Option var m: Float = 0.25
|
||||||
|
@Option var b: Float = 7
|
||||||
|
|
||||||
|
@Flag var compile = false
|
||||||
|
|
||||||
|
@Option var device = DeviceType.cpu
|
||||||
|
|
||||||
|
func run() async throws {
|
||||||
|
Device.setDefault(device: Device(device))
|
||||||
|
|
||||||
|
// A very simple model that implements the equation
|
||||||
|
// for a linear function: y = mx + b. This can be trained
|
||||||
|
// to match data -- in this case an unknown (to the model)
|
||||||
|
// linear function.
|
||||||
|
//
|
||||||
|
// This is a nice example because most people know how
|
||||||
|
// linear functions work and we can see how the slope
|
||||||
|
// and intercept converge.
|
||||||
|
class LinearFunctionModel: Module, UnaryLayer {
|
||||||
|
let m = MLXRandom.uniform(low: -5.0, high: 5.0)
|
||||||
|
let b = MLXRandom.uniform(low: -5.0, high: 5.0)
|
||||||
|
|
||||||
|
func callAsFunction(_ x: MLXArray) -> MLXArray {
|
||||||
|
m * x + b
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// measure the distance from the prediction (model(x)) and the
|
||||||
|
// ground truth (y). this gives feedback on how close the
|
||||||
|
// prediction is from matching the truth
|
||||||
|
func loss(model: LinearFunctionModel, x: MLXArray, y: MLXArray) -> MLXArray {
|
||||||
|
mseLoss(predictions: model(x), targets: y, reduction: .mean)
|
||||||
|
}
|
||||||
|
|
||||||
|
let model = LinearFunctionModel()
|
||||||
|
eval(model.parameters())
|
||||||
|
|
||||||
|
let lg = valueAndGrad(model: model, loss)
|
||||||
|
|
||||||
|
// the optimizer will use the gradients update the model parameters
|
||||||
|
let optimizer = SGD(learningRate: 1e-1)
|
||||||
|
|
||||||
|
// the function to train our model against -- it doesn't have
|
||||||
|
// to be linear, but matching what the model models is easy
|
||||||
|
// to understand
|
||||||
|
func f(_ x: MLXArray) -> MLXArray {
|
||||||
|
// these are the target parameters
|
||||||
|
let m = self.m
|
||||||
|
let b = self.b
|
||||||
|
|
||||||
|
// our actual function
|
||||||
|
return m * x + b
|
||||||
|
}
|
||||||
|
|
||||||
|
func step(_ x: MLXArray, _ y: MLXArray) -> MLXArray {
|
||||||
|
let (loss, grads) = lg(model, x, y)
|
||||||
|
optimizer.update(model: model, gradients: grads)
|
||||||
|
return loss
|
||||||
|
}
|
||||||
|
|
||||||
|
let resolvedStep =
|
||||||
|
self.compile
|
||||||
|
? MLX.compile(inputs: [model, optimizer], outputs: [model, optimizer], step) : step
|
||||||
|
|
||||||
|
for _ in 0 ..< epochs {
|
||||||
|
// we expect that the parameters will approach the targets
|
||||||
|
print("target: b = \(b), m = \(m)")
|
||||||
|
print("parameters: \(model.parameters())")
|
||||||
|
|
||||||
|
// generate random training data along with the ground truth.
|
||||||
|
// notice that the shape is [B, 1] where B is the batch
|
||||||
|
// dimension -- this allows us to train on several samples simultaneously
|
||||||
|
//
|
||||||
|
// note: a very large batch size will take longer to converge because
|
||||||
|
// the gradient will be representing too many samples down into
|
||||||
|
// a single float parameter.
|
||||||
|
let x = MLXRandom.uniform(low: -5.0, high: 5.0, [batchSize, 1])
|
||||||
|
let y = f(x)
|
||||||
|
eval(x, y)
|
||||||
|
|
||||||
|
// compute the loss and gradients. use the optimizer
|
||||||
|
// to adjust the parameters closer to the target
|
||||||
|
let loss = resolvedStep(x, y)
|
||||||
|
|
||||||
|
eval(model, optimizer)
|
||||||
|
|
||||||
|
// we should see this converge toward 0
|
||||||
|
print("loss: \(loss)")
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
14
Tools/LinearModelTraining/README.md
Normal file
14
Tools/LinearModelTraining/README.md
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
# LinearModelTraining
|
||||||
|
|
||||||
|
A command line tool that creates a Model that represents:
|
||||||
|
|
||||||
|
f(x) = mx + b
|
||||||
|
|
||||||
|
and trains it against an unknown linear function. Very
|
||||||
|
simple but illustrates:
|
||||||
|
|
||||||
|
- a very simple model with parameters
|
||||||
|
- a loss function
|
||||||
|
- the gradient
|
||||||
|
- use of an optimizers
|
||||||
|
- the training loop
|
||||||
102
Tools/Tutorial/Tutorial.swift
Normal file
102
Tools/Tutorial/Tutorial.swift
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
import MLX
|
||||||
|
|
||||||
|
/// mlx-swift tutorial based on:
|
||||||
|
/// https://github.com/ml-explore/mlx/blob/main/examples/cpp/tutorial.cpp
|
||||||
|
@main
|
||||||
|
struct Tutorial {
|
||||||
|
|
||||||
|
static func scalarBasics() {
|
||||||
|
// create a scalar array
|
||||||
|
let x = MLXArray(1.0)
|
||||||
|
|
||||||
|
// the datatype is .float32
|
||||||
|
let dtype = x.dtype
|
||||||
|
assert(dtype == .float32)
|
||||||
|
|
||||||
|
// get the value
|
||||||
|
let s = x.item(Float.self)
|
||||||
|
assert(s == 1.0)
|
||||||
|
|
||||||
|
// reading the value with a different type is a fatal error
|
||||||
|
// let i = x.item(Int.self)
|
||||||
|
|
||||||
|
// scalars have a size of 1
|
||||||
|
let size = x.size
|
||||||
|
assert(size == 1)
|
||||||
|
|
||||||
|
// scalars have 0 dimensions
|
||||||
|
let ndim = x.ndim
|
||||||
|
assert(ndim == 0)
|
||||||
|
|
||||||
|
// scalar shapes are empty arrays
|
||||||
|
let shape = x.shape
|
||||||
|
assert(shape == [])
|
||||||
|
}
|
||||||
|
|
||||||
|
static func arrayBasics() {
|
||||||
|
// make a multidimensional array.
|
||||||
|
//
|
||||||
|
// Note: the argument is a [Double] array literal, which is not
|
||||||
|
// a supported type, but we can explicitly convert it to [Float]
|
||||||
|
// when we create the MLXArray.
|
||||||
|
let x = MLXArray(converting: [1.0, 2.0, 3.0, 4.0], [2, 2])
|
||||||
|
|
||||||
|
// mlx is row-major by default so the first row of this array
|
||||||
|
// is [1.0, 2.0] and the second row is [3.0, 4.0]
|
||||||
|
print(x[0])
|
||||||
|
print(x[1])
|
||||||
|
|
||||||
|
// make an array of shape [2, 2] filled with ones
|
||||||
|
let y = MLXArray.ones([2, 2])
|
||||||
|
|
||||||
|
// pointwise add x and y
|
||||||
|
let z = x + y
|
||||||
|
|
||||||
|
// mlx is lazy by default. At this point `z` only
|
||||||
|
// has a shape and a type but no actual data
|
||||||
|
assert(z.dtype == .float32)
|
||||||
|
assert(z.shape == [2, 2])
|
||||||
|
|
||||||
|
// To actually run the computation you must evaluate `z`.
|
||||||
|
// Under the hood, mlx records operations in a graph.
|
||||||
|
// The variable `z` is a node in the graph which points to its operation
|
||||||
|
// and inputs. When `eval` is called on an array (or arrays), the array and
|
||||||
|
// all of its dependencies are recursively evaluated to produce the result.
|
||||||
|
// Once an array is evaluated, it has data and is detached from its inputs.
|
||||||
|
|
||||||
|
// Note: this is being called for demonstration purposes -- all reads
|
||||||
|
// ensure the array is evaluated.
|
||||||
|
z.eval()
|
||||||
|
|
||||||
|
// this implicitly evaluates z before converting to a description
|
||||||
|
print(z)
|
||||||
|
}
|
||||||
|
|
||||||
|
static func automaticDifferentiation() {
|
||||||
|
func fn(_ x: MLXArray) -> MLXArray {
|
||||||
|
x.square()
|
||||||
|
}
|
||||||
|
|
||||||
|
let gradFn = grad(fn)
|
||||||
|
|
||||||
|
let x = MLXArray(1.5)
|
||||||
|
let dfdx = gradFn(x)
|
||||||
|
print(dfdx)
|
||||||
|
|
||||||
|
assert(dfdx.item() == Float(2 * 1.5))
|
||||||
|
|
||||||
|
let df2dx2 = grad(grad(fn))(x)
|
||||||
|
print(df2dx2)
|
||||||
|
|
||||||
|
assert(df2dx2.item() == Float(2))
|
||||||
|
}
|
||||||
|
|
||||||
|
static func main() {
|
||||||
|
scalarBasics()
|
||||||
|
arrayBasics()
|
||||||
|
automaticDifferentiation()
|
||||||
|
}
|
||||||
|
}
|
||||||
190
Tools/llm-tool/LLMTool.swift
Normal file
190
Tools/llm-tool/LLMTool.swift
Normal file
@@ -0,0 +1,190 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
import ArgumentParser
|
||||||
|
import Foundation
|
||||||
|
import LLM
|
||||||
|
import MLX
|
||||||
|
import MLXRandom
|
||||||
|
|
||||||
|
struct LLMTool: AsyncParsableCommand {
|
||||||
|
static var configuration = CommandConfiguration(
|
||||||
|
abstract: "Command line tool for generating text using Llama models",
|
||||||
|
subcommands: [SyncGenerator.self, AsyncGenerator.self],
|
||||||
|
defaultSubcommand: SyncGenerator.self)
|
||||||
|
}
|
||||||
|
|
||||||
|
@main
|
||||||
|
struct SyncGenerator: AsyncParsableCommand {
|
||||||
|
|
||||||
|
static var configuration = CommandConfiguration(
|
||||||
|
commandName: "sync",
|
||||||
|
abstract: "Synchronous generator"
|
||||||
|
)
|
||||||
|
|
||||||
|
@Option(name: .long, help: "Name of the huggingface model")
|
||||||
|
var model: String = "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx"
|
||||||
|
|
||||||
|
@Option(name: .shortAndLong, help: "The message to be processed by the model")
|
||||||
|
var prompt = "compare swift and python"
|
||||||
|
|
||||||
|
@Option(name: .shortAndLong, help: "Maximum number of tokens to generate")
|
||||||
|
var maxTokens = 100
|
||||||
|
|
||||||
|
@Option(name: .shortAndLong, help: "The sampling temperature")
|
||||||
|
var temperature: Float = 0.0
|
||||||
|
|
||||||
|
@Option(name: .long, help: "The PRNG seed")
|
||||||
|
var seed: UInt64 = 0
|
||||||
|
|
||||||
|
@MainActor
|
||||||
|
func run() async throws {
|
||||||
|
MLXRandom.seed(seed)
|
||||||
|
|
||||||
|
let (model, tokenizer) = try await load(name: model)
|
||||||
|
|
||||||
|
print("Starting generation ...")
|
||||||
|
print(prompt, terminator: "")
|
||||||
|
|
||||||
|
var start = Date.timeIntervalSinceReferenceDate
|
||||||
|
var promptTime: TimeInterval = 0
|
||||||
|
|
||||||
|
let prompt = MLXArray(tokenizer.encode(text: prompt))
|
||||||
|
|
||||||
|
// collect the tokens and keep track of how much of the string
|
||||||
|
// we have printed already
|
||||||
|
var tokens = [Int]()
|
||||||
|
var printed = 0
|
||||||
|
|
||||||
|
for token in TokenIterator(prompt: prompt, model: model, temp: temperature) {
|
||||||
|
if tokens.isEmpty {
|
||||||
|
eval(token)
|
||||||
|
let now = Date.timeIntervalSinceReferenceDate
|
||||||
|
promptTime = now - start
|
||||||
|
start = now
|
||||||
|
}
|
||||||
|
|
||||||
|
let t = token.item(Int.self)
|
||||||
|
if t == tokenizer.unknownTokenId {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
tokens.append(t)
|
||||||
|
|
||||||
|
// print any new parts of the string
|
||||||
|
let fullOutput = tokenizer.decode(tokens: tokens)
|
||||||
|
let emitLength = fullOutput.count - printed
|
||||||
|
let suffix = fullOutput.suffix(emitLength)
|
||||||
|
print(suffix, terminator: "")
|
||||||
|
fflush(stdout)
|
||||||
|
|
||||||
|
printed = fullOutput.count
|
||||||
|
|
||||||
|
if tokens.count == maxTokens {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("------")
|
||||||
|
let now = Date.timeIntervalSinceReferenceDate
|
||||||
|
let generateTime = now - start
|
||||||
|
|
||||||
|
print(
|
||||||
|
"""
|
||||||
|
Prompt Tokens per second: \((Double(prompt.size) / promptTime).formatted())
|
||||||
|
Generation tokens per second: \((Double(tokens.count - 1) / generateTime).formatted())
|
||||||
|
""")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Example of an async generator.
|
||||||
|
///
|
||||||
|
/// Note that all of the computation is done on another thread and TokenId (Int32) are sent
|
||||||
|
/// rather than MLXArray.
|
||||||
|
struct AsyncGenerator: AsyncParsableCommand {
|
||||||
|
|
||||||
|
static var configuration = CommandConfiguration(
|
||||||
|
commandName: "async",
|
||||||
|
abstract: "async generator"
|
||||||
|
)
|
||||||
|
|
||||||
|
@Option(name: .long, help: "Name of the huggingface model")
|
||||||
|
var model: String = "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx"
|
||||||
|
|
||||||
|
@Option(name: .shortAndLong, help: "The message to be processed by the model")
|
||||||
|
var prompt = "compare swift and python"
|
||||||
|
|
||||||
|
@Option(name: .shortAndLong, help: "Maximum number of tokens to generate")
|
||||||
|
var maxTokens = 100
|
||||||
|
|
||||||
|
@Option(name: .shortAndLong, help: "The sampling temperature")
|
||||||
|
var temperature: Float = 0.0
|
||||||
|
|
||||||
|
@Option(name: .long, help: "The PRNG seed")
|
||||||
|
var seed: UInt64 = 0
|
||||||
|
|
||||||
|
@MainActor
|
||||||
|
func run() async throws {
|
||||||
|
MLXRandom.seed(seed)
|
||||||
|
|
||||||
|
let (model, tokenizer) = try await load(name: model)
|
||||||
|
|
||||||
|
print("Starting generation ...")
|
||||||
|
print(prompt, terminator: "")
|
||||||
|
|
||||||
|
var start = Date.timeIntervalSinceReferenceDate
|
||||||
|
var promptTime: TimeInterval = 0
|
||||||
|
|
||||||
|
let prompt = MLXArray(tokenizer.encode(text: prompt))
|
||||||
|
|
||||||
|
// collect the tokens and keep track of how much of the string
|
||||||
|
// we have printed already
|
||||||
|
var tokens = [Int]()
|
||||||
|
var printed = 0
|
||||||
|
|
||||||
|
let (task, channel) = generate(prompt: prompt, model: model, temp: temperature)
|
||||||
|
|
||||||
|
for await token in channel {
|
||||||
|
if tokens.isEmpty {
|
||||||
|
let now = Date.timeIntervalSinceReferenceDate
|
||||||
|
promptTime = now - start
|
||||||
|
start = now
|
||||||
|
}
|
||||||
|
|
||||||
|
if token == tokenizer.unknownTokenId {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
tokens.append(token)
|
||||||
|
|
||||||
|
// print any new parts of the string
|
||||||
|
let fullOutput = tokenizer.decode(tokens: tokens)
|
||||||
|
let emitLength = fullOutput.count - printed
|
||||||
|
let suffix = fullOutput.suffix(emitLength)
|
||||||
|
print(suffix, terminator: "")
|
||||||
|
fflush(stdout)
|
||||||
|
|
||||||
|
printed = fullOutput.count
|
||||||
|
|
||||||
|
if tokens.count == maxTokens {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// tell the task to stop
|
||||||
|
task.cancel()
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("------")
|
||||||
|
let now = Date.timeIntervalSinceReferenceDate
|
||||||
|
let generateTime = now - start
|
||||||
|
|
||||||
|
print(
|
||||||
|
"""
|
||||||
|
Prompt Tokens per second: \((Double(prompt.size) / promptTime).formatted())
|
||||||
|
Generation tokens per second: \((Double(tokens.count - 1) / generateTime).formatted())
|
||||||
|
""")
|
||||||
|
|
||||||
|
// wait for the task to complete -- since it is running async, it might
|
||||||
|
// be in the middle of running the model
|
||||||
|
try? await Task.sleep(for: .milliseconds(500))
|
||||||
|
}
|
||||||
|
}
|
||||||
38
Tools/llm-tool/README.md
Normal file
38
Tools/llm-tool/README.md
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
# llm-tool
|
||||||
|
|
||||||
|
See various READMEs:
|
||||||
|
|
||||||
|
- [Llama](../../Libraries/Llama/README.md)
|
||||||
|
|
||||||
|
### Building
|
||||||
|
|
||||||
|
Build the `llm-tool` scheme in Xcode.
|
||||||
|
|
||||||
|
### Running (Xcode)
|
||||||
|
|
||||||
|
To run this in Xcode simply press cmd-opt-r to set the scheme arguments. For example:
|
||||||
|
|
||||||
|
```
|
||||||
|
--model mlx-community/Mistral-7B-v0.1-hf-4bit-mlx
|
||||||
|
--prompt "swift programming language"
|
||||||
|
--max-tokens 50
|
||||||
|
```
|
||||||
|
|
||||||
|
Then cmd-r to run.
|
||||||
|
|
||||||
|
> Note: you may be prompted for access to your Documents directory -- this is where
|
||||||
|
the huggingface HubApi stores the downloaded files.
|
||||||
|
|
||||||
|
### Running (Command Line)
|
||||||
|
|
||||||
|
`llm-tool` can also be run from the command line if built from Xcode, but
|
||||||
|
the `DYLD_FRAMEWORK_PATH` must be set so that the frameworks and bundles can be found:
|
||||||
|
|
||||||
|
- [MLX troubleshooting](https://ml-explore.github.io/mlx-swift/MLX/documentation/mlx/troubleshooting)
|
||||||
|
|
||||||
|
The easiest way to do this is drag the Products/llm-tool into Terminal to get the path:
|
||||||
|
|
||||||
|
```
|
||||||
|
DYLD_FRAMEWORK_PATH=~/Library/Developer/Xcode/DerivedData/mlx-examples-swift-ceuohnhzsownvsbbleukxoksddja/Build/Products/Debug ~/Library/Developer/Xcode/DerivedData/mlx-examples-swift-ceuohnhzsownvsbbleukxoksddja/Build/Products/Debug/llm-tool --prompt "swift programming language"
|
||||||
|
```
|
||||||
|
|
||||||
108
Tools/mnist-tool/MNISTTool.swift
Normal file
108
Tools/mnist-tool/MNISTTool.swift
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
import ArgumentParser
|
||||||
|
import Foundation
|
||||||
|
import MLX
|
||||||
|
import MLXNN
|
||||||
|
import MLXOptimizers
|
||||||
|
import MLXRandom
|
||||||
|
import MNIST
|
||||||
|
|
||||||
|
@main
|
||||||
|
struct MNISTTool: AsyncParsableCommand {
|
||||||
|
static var configuration = CommandConfiguration(
|
||||||
|
abstract: "Command line tool for training mnist models",
|
||||||
|
subcommands: [Train.self],
|
||||||
|
defaultSubcommand: Train.self)
|
||||||
|
}
|
||||||
|
|
||||||
|
extension MLX.DeviceType: ExpressibleByArgument {
|
||||||
|
public init?(argument: String) {
|
||||||
|
self.init(rawValue: argument)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Train: AsyncParsableCommand {
|
||||||
|
|
||||||
|
@Option(name: .long, help: "Directory with the training data")
|
||||||
|
var data: String
|
||||||
|
|
||||||
|
@Option(name: .long, help: "The PRNG seed")
|
||||||
|
var seed: UInt64 = 0
|
||||||
|
|
||||||
|
@Option var layers = 2
|
||||||
|
@Option var hidden = 32
|
||||||
|
@Option var batchSize = 256
|
||||||
|
@Option var epochs = 20
|
||||||
|
@Option var learningRate: Float = 1e-1
|
||||||
|
|
||||||
|
@Option var classes = 10
|
||||||
|
|
||||||
|
@Option var device = DeviceType.cpu
|
||||||
|
|
||||||
|
@Flag var compile = false
|
||||||
|
|
||||||
|
func run() async throws {
|
||||||
|
Device.setDefault(device: Device(device))
|
||||||
|
|
||||||
|
MLXRandom.seed(seed)
|
||||||
|
var generator: RandomNumberGenerator = SplitMix64(seed: seed)
|
||||||
|
|
||||||
|
// load the data
|
||||||
|
let url = URL(filePath: data)
|
||||||
|
|
||||||
|
try FileManager.default.createDirectory(at: url, withIntermediateDirectories: true)
|
||||||
|
try await download(into: url)
|
||||||
|
|
||||||
|
let data = try load(from: url)
|
||||||
|
|
||||||
|
let trainImages = data[.init(.training, .images)]!
|
||||||
|
let trainLabels = data[.init(.training, .labels)]!
|
||||||
|
let testImages = data[.init(.test, .images)]!
|
||||||
|
let testLabels = data[.init(.test, .labels)]!
|
||||||
|
|
||||||
|
// create the model
|
||||||
|
let model = MLP(
|
||||||
|
layers: layers, inputDimensions: trainImages.dim(-1), hiddenDimensions: hidden,
|
||||||
|
outputDimensions: classes)
|
||||||
|
eval(model.parameters())
|
||||||
|
|
||||||
|
let lg = valueAndGrad(model: model, loss)
|
||||||
|
let optimizer = SGD(learningRate: learningRate)
|
||||||
|
|
||||||
|
func step(_ x: MLXArray, _ y: MLXArray) -> MLXArray {
|
||||||
|
let (loss, grads) = lg(model, x, y)
|
||||||
|
optimizer.update(model: model, gradients: grads)
|
||||||
|
return loss
|
||||||
|
}
|
||||||
|
|
||||||
|
let resolvedStep =
|
||||||
|
compile
|
||||||
|
? MLX.compile(inputs: [model, optimizer], outputs: [model, optimizer], step) : step
|
||||||
|
|
||||||
|
for e in 0 ..< epochs {
|
||||||
|
let start = Date.timeIntervalSinceReferenceDate
|
||||||
|
|
||||||
|
for (x, y) in iterateBatches(
|
||||||
|
batchSize: batchSize, x: trainImages, y: trainLabels, using: &generator)
|
||||||
|
{
|
||||||
|
_ = resolvedStep(x, y)
|
||||||
|
|
||||||
|
// eval the parameters so the next iteration is independent
|
||||||
|
eval(model, optimizer)
|
||||||
|
}
|
||||||
|
|
||||||
|
let accuracy = eval(model: model, x: testImages, y: testLabels)
|
||||||
|
|
||||||
|
let end = Date.timeIntervalSinceReferenceDate
|
||||||
|
|
||||||
|
print(
|
||||||
|
"""
|
||||||
|
Epoch \(e): test accuracy \(accuracy.item(Float.self).formatted())
|
||||||
|
Time: \((end - start).formatted())
|
||||||
|
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
36
Tools/mnist-tool/README.md
Normal file
36
Tools/mnist-tool/README.md
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
# mnist-tool
|
||||||
|
|
||||||
|
See other README:
|
||||||
|
|
||||||
|
- [MNIST](../../Libraries/MNIST/README.md)
|
||||||
|
|
||||||
|
### Building
|
||||||
|
|
||||||
|
`mnist-tool` has no dependencies outside of the package dependencies
|
||||||
|
represented in xcode.
|
||||||
|
|
||||||
|
When you run the tool it will download the test/train datasets and
|
||||||
|
store them in a specified directory (see run arguments -- default is /tmp).
|
||||||
|
|
||||||
|
Simply build the project in xcode.
|
||||||
|
|
||||||
|
### Running (Xcode)
|
||||||
|
|
||||||
|
To run this in Xcode simply press cmd-opt-r to set the scheme arguments. For example:
|
||||||
|
|
||||||
|
```
|
||||||
|
--data /tmp
|
||||||
|
```
|
||||||
|
|
||||||
|
Then cmd-r to run.
|
||||||
|
|
||||||
|
### Running (CommandLine)
|
||||||
|
|
||||||
|
`mnist-tool` can also be run from the command line if built from Xcode, but
|
||||||
|
the `DYLD_FRAMEWORK_PATH` must be set so that the frameworks and bundles can be found:
|
||||||
|
|
||||||
|
- [MLX troubleshooting](https://ml-explore.github.io/mlx-swift/MLX/documentation/mlx/troubleshooting)
|
||||||
|
|
||||||
|
```
|
||||||
|
DYLD_FRAMEWORK_PATH=~/Library/Developer/Xcode/DerivedData/mlx-examples-swift-ceuohnhzsownvsbbleukxoksddja/Build/Products/Debug ~/Library/Developer/Xcode/DerivedData/mlx-examples-swift-ceuohnhzsownvsbbleukxoksddja/Build/Products/Debug/mnist-tool --data /tmp
|
||||||
|
```
|
||||||
1716
mlx-swift-examples.xcodeproj/project.pbxproj
Normal file
1716
mlx-swift-examples.xcodeproj/project.pbxproj
Normal file
File diff suppressed because it is too large
Load Diff
7
mlx-swift-examples.xcodeproj/project.xcworkspace/contents.xcworkspacedata
generated
Normal file
7
mlx-swift-examples.xcodeproj/project.xcworkspace/contents.xcworkspacedata
generated
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<Workspace
|
||||||
|
version = "1.0">
|
||||||
|
<FileRef
|
||||||
|
location = "self:">
|
||||||
|
</FileRef>
|
||||||
|
</Workspace>
|
||||||
@@ -0,0 +1,8 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||||
|
<plist version="1.0">
|
||||||
|
<dict>
|
||||||
|
<key>IDEDidComputeMac32BitWarning</key>
|
||||||
|
<true/>
|
||||||
|
</dict>
|
||||||
|
</plist>
|
||||||
@@ -0,0 +1,68 @@
|
|||||||
|
{
|
||||||
|
"pins" : [
|
||||||
|
{
|
||||||
|
"identity" : "gzipswift",
|
||||||
|
"kind" : "remoteSourceControl",
|
||||||
|
"location" : "https://github.com/1024jp/GzipSwift",
|
||||||
|
"state" : {
|
||||||
|
"revision" : "731037f6cc2be2ec01562f6597c1d0aa3fe6fd05",
|
||||||
|
"version" : "6.0.1"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"identity" : "mlx-swift",
|
||||||
|
"kind" : "remoteSourceControl",
|
||||||
|
"location" : "https://github.com/ml-explore/mlx-swift",
|
||||||
|
"state" : {
|
||||||
|
"branch" : "main",
|
||||||
|
"revision" : "cadf5f8187ac0894e66cd288217e2eda9f2c933d"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"identity" : "swift-argument-parser",
|
||||||
|
"kind" : "remoteSourceControl",
|
||||||
|
"location" : "https://github.com/apple/swift-argument-parser.git",
|
||||||
|
"state" : {
|
||||||
|
"revision" : "c8ed701b513cf5177118a175d85fbbbcd707ab41",
|
||||||
|
"version" : "1.3.0"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"identity" : "swift-async-algorithms",
|
||||||
|
"kind" : "remoteSourceControl",
|
||||||
|
"location" : "https://github.com/apple/swift-async-algorithms",
|
||||||
|
"state" : {
|
||||||
|
"revision" : "da4e36f86544cdf733a40d59b3a2267e3a7bbf36",
|
||||||
|
"version" : "1.0.0"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"identity" : "swift-collections",
|
||||||
|
"kind" : "remoteSourceControl",
|
||||||
|
"location" : "https://github.com/apple/swift-collections.git",
|
||||||
|
"state" : {
|
||||||
|
"revision" : "d029d9d39c87bed85b1c50adee7c41795261a192",
|
||||||
|
"version" : "1.0.6"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"identity" : "swift-numerics",
|
||||||
|
"kind" : "remoteSourceControl",
|
||||||
|
"location" : "https://github.com/apple/swift-numerics",
|
||||||
|
"state" : {
|
||||||
|
"revision" : "0a5bc04095a675662cf24757cc0640aa2204253b",
|
||||||
|
"version" : "1.0.2"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"identity" : "swift-transformers",
|
||||||
|
"kind" : "remoteSourceControl",
|
||||||
|
"location" : "https://github.com/huggingface/swift-transformers",
|
||||||
|
"state" : {
|
||||||
|
"revision" : "564442fba36b0b694d730a62d0593e5f54043b55",
|
||||||
|
"version" : "0.1.2"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"version" : 2
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user