From b6d1e14465b72b52b8d15a59337120814aec259b Mon Sep 17 00:00:00 2001 From: David Koski Date: Thu, 22 Feb 2024 10:41:02 -0800 Subject: [PATCH] initial commit --- .circleci/config.yml | 63 + .gitignore | 90 + .pre-commit-config.yaml | 6 + .swift-format | 7 + Libraries/LLM/Configuration.swift | 77 + Libraries/LLM/Gemma.swift | 273 +++ Libraries/LLM/LLM.h | 1 + Libraries/LLM/LLMModel.swift | 12 + Libraries/LLM/Llama.swift | 263 +++ Libraries/LLM/Phi.swift | 302 +++ Libraries/LLM/README.md | 11 + Libraries/LLM/Util.swift | 110 ++ Libraries/MNIST/Files.swift | 102 + Libraries/MNIST/MNIST.h | 1 + Libraries/MNIST/MNIST.swift | 73 + Libraries/MNIST/README.md | 13 + Libraries/MNIST/Random.swift | 30 + README.md | 22 + .../LinearModelTraining.swift | 113 ++ Tools/LinearModelTraining/README.md | 14 + Tools/Tutorial/Tutorial.swift | 102 + Tools/llm-tool/LLMTool.swift | 190 ++ Tools/llm-tool/README.md | 38 + Tools/mnist-tool/MNISTTool.swift | 108 ++ Tools/mnist-tool/README.md | 36 + mlx-swift-examples.xcodeproj/project.pbxproj | 1716 +++++++++++++++++ .../contents.xcworkspacedata | 7 + .../xcshareddata/IDEWorkspaceChecks.plist | 8 + .../xcshareddata/swiftpm/Package.resolved | 68 + 29 files changed, 3856 insertions(+) create mode 100644 .circleci/config.yml create mode 100644 .gitignore create mode 100644 .pre-commit-config.yaml create mode 100644 .swift-format create mode 100644 Libraries/LLM/Configuration.swift create mode 100644 Libraries/LLM/Gemma.swift create mode 100644 Libraries/LLM/LLM.h create mode 100644 Libraries/LLM/LLMModel.swift create mode 100644 Libraries/LLM/Llama.swift create mode 100644 Libraries/LLM/Phi.swift create mode 100644 Libraries/LLM/README.md create mode 100644 Libraries/LLM/Util.swift create mode 100644 Libraries/MNIST/Files.swift create mode 100644 Libraries/MNIST/MNIST.h create mode 100644 Libraries/MNIST/MNIST.swift create mode 100644 Libraries/MNIST/README.md create mode 100644 Libraries/MNIST/Random.swift create mode 100644 README.md create mode 100644 Tools/LinearModelTraining/LinearModelTraining.swift create mode 100644 Tools/LinearModelTraining/README.md create mode 100644 Tools/Tutorial/Tutorial.swift create mode 100644 Tools/llm-tool/LLMTool.swift create mode 100644 Tools/llm-tool/README.md create mode 100644 Tools/mnist-tool/MNISTTool.swift create mode 100644 Tools/mnist-tool/README.md create mode 100644 mlx-swift-examples.xcodeproj/project.pbxproj create mode 100644 mlx-swift-examples.xcodeproj/project.xcworkspace/contents.xcworkspacedata create mode 100644 mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist create mode 100644 mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved diff --git a/.circleci/config.yml b/.circleci/config.yml new file mode 100644 index 0000000..eaea674 --- /dev/null +++ b/.circleci/config.yml @@ -0,0 +1,63 @@ +version: 2.1 + +orbs: + apple: ml-explore/pr-approval@0.1.0 + +parameters: + nightly_build: + type: boolean + default: false + weekly_build: + type: boolean + default: false + +jobs: + + mac_build_and_test: + macos: + xcode: 15.2.0 + resource_class: macos.m1.medium.gen1 + steps: + - checkout + - run: git submodule sync + - run: git submodule update --init + - run: + name: Run style checks + command: | + pip install pre-commit + brew install swift-format + pre-commit run --all + if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi + - run: + name: Build Examples + command: | + xcodebuild -version + xcrun --show-sdk-build-version + swift --version + xcodebuild -scheme llm-tool + xcodebuild -scheme mnist-tool + +workflows: + build_and_test: + when: + and: + - matches: + pattern: "^(?!pull/)[-\\w]+$" + value: << pipeline.git.branch >> + - not: << pipeline.parameters.nightly_build >> + - not: << pipeline.parameters.weekly_build >> + jobs: + - mac_build_and_test + + prb: + when: + matches: + pattern: "^pull/\\d+(/head)?$" + value: << pipeline.git.branch >> + jobs: + - hold: + type: approval + - apple/authenticate: + context: pr-approval + - mac_build_and_test: + requires: [ hold ] diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..330d167 --- /dev/null +++ b/.gitignore @@ -0,0 +1,90 @@ +# Xcode +# +# gitignore contributors: remember to update Global/Xcode.gitignore, Objective-C.gitignore & Swift.gitignore + +## User settings +xcuserdata/ + +## compatibility with Xcode 8 and earlier (ignoring not required starting Xcode 9) +*.xcscmblueprint +*.xccheckout + +## compatibility with Xcode 3 and earlier (ignoring not required starting Xcode 4) +build/ +DerivedData/ +*.moved-aside +*.pbxuser +!default.pbxuser +*.mode1v3 +!default.mode1v3 +*.mode2v3 +!default.mode2v3 +*.perspectivev3 +!default.perspectivev3 + +## Obj-C/Swift specific +*.hmap + +## App packaging +*.ipa +*.dSYM.zip +*.dSYM + +## Playgrounds +timeline.xctimeline +playground.xcworkspace + +# Swift Package Manager +# +# Add this line if you want to avoid checking in source code from Swift Package Manager dependencies. +# Packages/ +# Package.pins +# Package.resolved +# *.xcodeproj +# +# Xcode automatically generates this directory with a .xcworkspacedata file and xcuserdata +# hence it is not needed unless you have added a package configuration file to your project +# .swiftpm + +.build/ + +# CocoaPods +# +# We recommend against adding the Pods directory to your .gitignore. However +# you should judge for yourself, the pros and cons are mentioned at: +# https://guides.cocoapods.org/using/using-cocoapods.html#should-i-check-the-pods-directory-into-source-control +# +# Pods/ +# +# Add this line if you want to avoid checking in source code from the Xcode workspace +# *.xcworkspace + +# Carthage +# +# Add this line if you want to avoid checking in source code from Carthage dependencies. +# Carthage/Checkouts + +Carthage/Build/ + +# Accio dependency management +Dependencies/ +.accio/ + +# fastlane +# +# It is recommended to not store the screenshots in the git repo. +# Instead, use fastlane to re-generate the screenshots whenever they are needed. +# For more information about the recommended setup visit: +# https://docs.fastlane.tools/best-practices/source-control/#source-control + +fastlane/report.xml +fastlane/Preview.html +fastlane/screenshots/**/*.png +fastlane/test_output + +# Code Injection +# +# After new code Injection tools there's a generated folder /iOSInjectionProject +# https://github.com/johnno1962/injectionforxcode + +iOSInjectionProject/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..c12932f --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,6 @@ +repos: +- repo: https://github.com/slessans/pre-commit-swift-format + rev: "" + hooks: + - id: swift-format + args: ["--configuration", ".swift-format"] diff --git a/.swift-format b/.swift-format new file mode 100644 index 0000000..8892e9f --- /dev/null +++ b/.swift-format @@ -0,0 +1,7 @@ +{ + "version": 1, + "indentation": { + "spaces": 4 + }, + "spacesAroundRangeFormationOperators": true, +} diff --git a/Libraries/LLM/Configuration.swift b/Libraries/LLM/Configuration.swift new file mode 100644 index 0000000..bae4f3f --- /dev/null +++ b/Libraries/LLM/Configuration.swift @@ -0,0 +1,77 @@ +// Copyright © 2024 Apple Inc. + +import Foundation + +public enum StringOrNumber: Codable, Equatable { + case string(String) + case float(Float) + + public init(from decoder: Decoder) throws { + let values = try decoder.singleValueContainer() + + if let v = try? values.decode(Float.self) { + self = .float(v) + } else { + let v = try values.decode(String.self) + self = .string(v) + } + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.singleValueContainer() + switch self { + case .string(let v): try container.encode(v) + case .float(let v): try container.encode(v) + } + } +} + +public enum ModelType: String, Codable { + case mistral + case llama + case phi + case gemma + + func createModel(configuration: URL) throws -> LLMModel { + switch self { + case .mistral, .llama: + let configuration = try JSONDecoder().decode( + LlamaConfiguration.self, from: Data(contentsOf: configuration)) + return LlamaModel(configuration) + case .phi: + let configuration = try JSONDecoder().decode( + PhiConfiguration.self, from: Data(contentsOf: configuration)) + return PhiModel(configuration) + case .gemma: + let configuration = try JSONDecoder().decode( + GemmaConfiguration.self, from: Data(contentsOf: configuration)) + return GemmaModel(configuration) + } + } +} + +public struct BaseConfiguration: Codable { + let modelType: ModelType + + public struct Quantization: Codable { + public init(groupSize: Int, bits: Int) { + self.groupSize = groupSize + self.bits = bits + } + + let groupSize: Int + let bits: Int + + enum CodingKeys: String, CodingKey { + case groupSize = "group_size" + case bits = "bits" + } + } + + var quantization: Quantization? + + enum CodingKeys: String, CodingKey { + case modelType = "model_type" + case quantization + } +} diff --git a/Libraries/LLM/Gemma.swift b/Libraries/LLM/Gemma.swift new file mode 100644 index 0000000..b934629 --- /dev/null +++ b/Libraries/LLM/Gemma.swift @@ -0,0 +1,273 @@ +// Copyright © 2024 Apple Inc. + +import Foundation +import MLX +import MLXNN + +// port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/gemma.py + +// specialized norm for gemma +private class RMSNorm: Module, UnaryLayer { + + let weight: MLXArray + let eps: Float + + public init(dimensions: Int, eps: Float = 1e-5) { + self.weight = MLXArray.ones([dimensions]) + self.eps = eps + super.init() + } + + func norm(_ x: MLXArray) -> MLXArray { + let S = 1.0 / sqrt(Float(x.dim(-1))) + + let n = (x * S).square().sum(axis: -1, keepDims: true) + return rsqrt(n + eps) + } + + public func callAsFunction(_ x: MLXArray) -> MLXArray { + let output = norm(x.asType(Float.self)).asType(x.dtype) + return (1 + weight) * output + } +} + +private class Attention: Module { + + let args: GemmaConfiguration + let repeats: Int + let scale: Float + + @ModuleInfo(key: "q_proj") var wq: Linear + @ModuleInfo(key: "k_proj") var wk: Linear + @ModuleInfo(key: "v_proj") var wv: Linear + @ModuleInfo(key: "o_proj") var wo: Linear + + let rope: RoPE + + public init(_ args: GemmaConfiguration) { + self.args = args + + let dim = args.hiddenSize + let heads = args.attentionHeads + let kvHeads = args.kvHeads + + self.repeats = heads / kvHeads + + let headDim = args.headDimensions + self.scale = pow(Float(headDim), -0.5) + + self._wq.wrappedValue = Linear(dim, heads * headDim, bias: false) + self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: false) + self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: false) + self._wo.wrappedValue = Linear(heads * headDim, dim, bias: false) + + self.rope = RoPE( + dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta) + } + + public func callAsFunction( + _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil + ) -> (MLXArray, (MLXArray, MLXArray)) { + let (B, L) = (x.dim(0), x.dim(1)) + + var queries = wq(x) + var keys = wk(x) + var values = wv(x) + + // prepare the queries, keys and values for the attention computation + queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3) + keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3) + values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3) + + if repeats > 1 { + keys = MLXArray.repeat(keys, count: repeats, axis: 1) + values = MLXArray.repeat(values, count: repeats, axis: 1) + } + + if let (keyCache, valueCache) = cache { + queries = rope(queries, offset: keyCache.dim(2)) + keys = rope(keys, offset: keyCache.dim(2)) + keys = concatenated([keyCache, keys], axis: 2) + values = concatenated([valueCache, values], axis: 2) + } else { + queries = rope(queries) + keys = rope(keys) + } + + var scores = (queries * self.scale).matmul(keys.transposed(0, 1, 3, 2)) + if let mask { + scores = scores + mask + } + + scores = softMax(scores.asType(.float32), axis: -1).asType(scores.dtype) + + let output = matmul(scores, values).transposed(0, 2, 1, 3).reshaped(B, L, -1) + + return (wo(output), (keys, values)) + } +} + +private class MLP: Module, UnaryLayer { + + @ModuleInfo(key: "gate_proj") var gate: Linear + @ModuleInfo(key: "down_proj") var down: Linear + @ModuleInfo(key: "up_proj") var up: Linear + + public init(dimensions: Int, hiddenDimensions: Int) { + self._gate.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false) + self._down.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false) + self._up.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false) + } + + public func callAsFunction(_ x: MLXArray) -> MLXArray { + down(gelu(gate(x)) * up(x)) + } +} + +private class TransformerBlock: Module { + + @ModuleInfo(key: "self_attn") var attention: Attention + let mlp: MLP + + @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm + @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm + + public init(_ args: GemmaConfiguration) { + self._attention.wrappedValue = Attention(args) + self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize) + self._inputLayerNorm.wrappedValue = RMSNorm( + dimensions: args.hiddenSize, eps: args.rmsNormEps) + self._postAttentionLayerNorm.wrappedValue = RMSNorm( + dimensions: args.hiddenSize, eps: args.rmsNormEps) + } + + public func callAsFunction( + _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil + ) -> (MLXArray, (MLXArray, MLXArray)) { + var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache) + let h = x + r + r = mlp(postAttentionLayerNorm(h)) + let out = h + r + return (out, cache) + } +} + +public class GemmaModelInner: Module { + + @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding + + fileprivate let layers: [TransformerBlock] + fileprivate let norm: RMSNorm + + let hiddenScale: Float + + public init(_ args: GemmaConfiguration) { + precondition(args.vocabularySize > 0) + + self._embedTokens.wrappedValue = Embedding( + embeddingCount: args.vocabularySize, dimensions: args.hiddenSize) + + self.hiddenScale = pow(Float(args.hiddenSize), 0.5) + + self.layers = (0 ..< args.hiddenLayers) + .map { _ in + TransformerBlock(args) + } + self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps) + } + + public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> ( + MLXArray, [(MLXArray, MLXArray)] + ) { + var h = embedTokens(inputs) + h = h * hiddenScale + + var mask: MLXArray? = nil + if h.dim(1) > 1 { + mask = MultiHeadAttention.createAdditiveCausalMask(h.dim(1)) + mask = mask?.asType(h.dtype) + } + + var newCache = [(MLXArray, MLXArray)]() + + for (i, layer) in layers.enumerated() { + var cacheUpdate: (MLXArray, MLXArray) + (h, cacheUpdate) = layer(h, mask: mask, cache: cache?[i]) + newCache.append(cacheUpdate) + } + + return (norm(h), newCache) + } +} + +public class GemmaModel: Module, LLMModel { + + let model: GemmaModelInner + + public init(_ args: GemmaConfiguration) { + self.model = GemmaModelInner(args) + } + + public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> ( + MLXArray, [(MLXArray, MLXArray)] + ) { + var (out, cache) = model(inputs, cache: cache) + out = matmul(out, model.embedTokens.weight.T) + return (out, cache) + } +} + +public struct GemmaConfiguration: Codable { + + var hiddenSize: Int + var hiddenLayers: Int + var intermediateSize: Int + var attentionHeads: Int + var headDimensions: Int + var rmsNormEps: Float + var vocabularySize: Int + var kvHeads: Int + var ropeTheta: Float = 10_000 + var ropeTraditional: Bool = false + + enum CodingKeys: String, CodingKey { + case hiddenSize = "hidden_size" + case hiddenLayers = "num_hidden_layers" + case intermediateSize = "intermediate_size" + case attentionHeads = "num_attention_heads" + case headDimensions = "head_dim" + case rmsNormEps = "rms_norm_eps" + case vocabularySize = "vocab_size" + case kvHeads = "num_key_value_heads" + case ropeTheta = "rope_theta" + case ropeTraditional = "rope_traditional" + } + + public init(from decoder: Decoder) throws { + // custom implementation to handle optional keys with required values + let container: KeyedDecodingContainer = try decoder.container( + keyedBy: CodingKeys.self) + + self.hiddenSize = try container.decode( + Int.self, forKey: CodingKeys.hiddenSize) + self.hiddenLayers = try container.decode( + Int.self, forKey: CodingKeys.hiddenLayers) + self.intermediateSize = try container.decode( + Int.self, forKey: CodingKeys.intermediateSize) + self.attentionHeads = try container.decode( + Int.self, forKey: CodingKeys.attentionHeads) + self.headDimensions = try container.decode( + Int.self, forKey: CodingKeys.headDimensions) + self.rmsNormEps = try container.decode( + Float.self, forKey: CodingKeys.rmsNormEps) + self.vocabularySize = try container.decode( + Int.self, forKey: CodingKeys.vocabularySize) + self.kvHeads = try container.decode(Int.self, forKey: CodingKeys.kvHeads) + self.ropeTheta = + try container.decodeIfPresent(Float.self, forKey: CodingKeys.ropeTheta) + ?? 10_000 + self.ropeTraditional = + try container.decodeIfPresent( + Bool.self, forKey: CodingKeys.ropeTraditional) ?? false + } +} diff --git a/Libraries/LLM/LLM.h b/Libraries/LLM/LLM.h new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/Libraries/LLM/LLM.h @@ -0,0 +1 @@ + diff --git a/Libraries/LLM/LLMModel.swift b/Libraries/LLM/LLMModel.swift new file mode 100644 index 0000000..7bb6f8e --- /dev/null +++ b/Libraries/LLM/LLMModel.swift @@ -0,0 +1,12 @@ +// Copyright © 2024 Apple Inc. + +import Foundation +import MLX +import MLXNN + +// Interface for all LLM Models +public protocol LLMModel: Module { + func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> ( + MLXArray, [(MLXArray, MLXArray)] + ) +} diff --git a/Libraries/LLM/Llama.swift b/Libraries/LLM/Llama.swift new file mode 100644 index 0000000..ac64551 --- /dev/null +++ b/Libraries/LLM/Llama.swift @@ -0,0 +1,263 @@ +// Copyright © 2024 Apple Inc. + +import Foundation +import MLX +import MLXNN + +// port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/llama.py + +private class Attention: Module { + + let args: LlamaConfiguration + let repeats: Int + let scale: Float + + @ModuleInfo(key: "q_proj") var wq: Linear + @ModuleInfo(key: "k_proj") var wk: Linear + @ModuleInfo(key: "v_proj") var wv: Linear + @ModuleInfo(key: "o_proj") var wo: Linear + + let rope: RoPE + + public init(_ args: LlamaConfiguration) { + self.args = args + + let dim = args.hiddenSize + let heads = args.attentionHeads + let kvHeads = args.kvHeads + + self.repeats = heads / kvHeads + + let headDim = args.hiddenSize / heads + self.scale = pow(Float(headDim), -0.5) + + self._wq.wrappedValue = Linear(dim, heads * headDim, bias: false) + self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: false) + self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: false) + self._wo.wrappedValue = Linear(heads * headDim, dim, bias: false) + + let ropeScale: Float + if let ropeScaling = args.ropeScaling, ropeScaling["type"] == .string("linear"), + let factor = ropeScaling["factor"] + { + switch factor { + case .string: + fatalError("ropeScaling.factor must be a float") + case .float(let v): + ropeScale = 1 / v + } + } else { + ropeScale = 1 + } + + self.rope = RoPE( + dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta, + scale: ropeScale) + } + + public func callAsFunction( + _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil + ) -> (MLXArray, (MLXArray, MLXArray)) { + let (B, L) = (x.dim(0), x.dim(1)) + + var queries = wq(x) + var keys = wk(x) + var values = wv(x) + + // prepare the queries, keys and values for the attention computation + queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3) + keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3) + values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3) + + if repeats > 1 { + keys = MLXArray.repeat(keys, count: repeats, axis: 1) + values = MLXArray.repeat(values, count: repeats, axis: 1) + } + + if let (keyCache, valueCache) = cache { + queries = rope(queries, offset: keyCache.dim(2)) + keys = rope(keys, offset: keyCache.dim(2)) + keys = concatenated([keyCache, keys], axis: 2) + values = concatenated([valueCache, values], axis: 2) + } else { + queries = rope(queries) + keys = rope(keys) + } + + var scores = (queries * self.scale).matmul(keys.transposed(0, 1, 3, 2)) + if let mask { + scores = scores + mask + } + + scores = softMax(scores.asType(.float32), axis: -1).asType(scores.dtype) + + let output = matmul(scores, values).transposed(0, 2, 1, 3).reshaped(B, L, -1) + + return (wo(output), (keys, values)) + } +} + +private class MLP: Module, UnaryLayer { + + @ModuleInfo(key: "gate_proj") var gate: Linear + @ModuleInfo(key: "down_proj") var down: Linear + @ModuleInfo(key: "up_proj") var up: Linear + + public init(dimensions: Int, hiddenDimensions: Int) { + self._gate.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false) + self._down.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false) + self._up.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false) + } + + public func callAsFunction(_ x: MLXArray) -> MLXArray { + down(silu(gate(x)) * up(x)) + } +} + +private class TransformerBlock: Module { + + @ModuleInfo(key: "self_attn") var attention: Attention + let mlp: MLP + + @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm + @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm + + public init(_ args: LlamaConfiguration) { + self._attention.wrappedValue = Attention(args) + self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize) + self._inputLayerNorm.wrappedValue = RMSNorm( + dimensions: args.hiddenSize, eps: args.rmsNormEps) + self._postAttentionLayerNorm.wrappedValue = RMSNorm( + dimensions: args.hiddenSize, eps: args.rmsNormEps) + } + + public func callAsFunction( + _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil + ) -> (MLXArray, (MLXArray, MLXArray)) { + var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache) + let h = x + r + r = mlp(postAttentionLayerNorm(h)) + let out = h + r + return (out, cache) + } +} + +public class LlamaModelInner: Module { + + @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding + + fileprivate let layers: [TransformerBlock] + let norm: RMSNorm + + public init(_ args: LlamaConfiguration) { + precondition(args.vocabularySize > 0) + + self._embedTokens.wrappedValue = Embedding( + embeddingCount: args.vocabularySize, dimensions: args.hiddenSize) + + self.layers = (0 ..< args.hiddenLayers) + .map { _ in + TransformerBlock(args) + } + self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps) + } + + public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> ( + MLXArray, [(MLXArray, MLXArray)] + ) { + var h = embedTokens(inputs) + + var mask: MLXArray? = nil + if h.dim(1) > 1 { + mask = MultiHeadAttention.createAdditiveCausalMask(h.dim(1)) + mask = mask?.asType(h.dtype) + } + + var newCache = [(MLXArray, MLXArray)]() + + for (i, layer) in layers.enumerated() { + var cacheUpdate: (MLXArray, MLXArray) + (h, cacheUpdate) = layer(h, mask: mask, cache: cache?[i]) + newCache.append(cacheUpdate) + } + + return (norm(h), newCache) + } +} + +public class LlamaModel: Module, LLMModel { + + let model: LlamaModelInner + + @ModuleInfo(key: "lm_head") var lmHead: Linear + + public init(_ args: LlamaConfiguration) { + self.model = LlamaModelInner(args) + self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false) + } + + public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> ( + MLXArray, [(MLXArray, MLXArray)] + ) { + let (out, cache) = model(inputs, cache: cache) + return (lmHead(out), cache) + } +} + +public struct LlamaConfiguration: Codable { + + var hiddenSize: Int + var hiddenLayers: Int + var intermediateSize: Int + var attentionHeads: Int + var rmsNormEps: Float + var vocabularySize: Int + var kvHeads: Int + var ropeTheta: Float = 10_000 + var ropeTraditional: Bool = false + var ropeScaling: [String: StringOrNumber]? = nil + + enum CodingKeys: String, CodingKey { + case hiddenSize = "hidden_size" + case hiddenLayers = "num_hidden_layers" + case intermediateSize = "intermediate_size" + case attentionHeads = "num_attention_heads" + case rmsNormEps = "rms_norm_eps" + case vocabularySize = "vocab_size" + case kvHeads = "num_key_value_heads" + case ropeTheta = "rope_theta" + case ropeTraditional = "rope_traditional" + case ropeScaling = "rope_scaling" + } + + public init(from decoder: Decoder) throws { + // custom implementation to handle optional keys with required values + let container: KeyedDecodingContainer = + try decoder.container( + keyedBy: LlamaConfiguration.CodingKeys.self) + + self.hiddenSize = try container.decode( + Int.self, forKey: LlamaConfiguration.CodingKeys.hiddenSize) + self.hiddenLayers = try container.decode( + Int.self, forKey: LlamaConfiguration.CodingKeys.hiddenLayers) + self.intermediateSize = try container.decode( + Int.self, forKey: LlamaConfiguration.CodingKeys.intermediateSize) + self.attentionHeads = try container.decode( + Int.self, forKey: LlamaConfiguration.CodingKeys.attentionHeads) + self.rmsNormEps = try container.decode( + Float.self, forKey: LlamaConfiguration.CodingKeys.rmsNormEps) + self.vocabularySize = try container.decode( + Int.self, forKey: LlamaConfiguration.CodingKeys.vocabularySize) + self.kvHeads = try container.decode(Int.self, forKey: LlamaConfiguration.CodingKeys.kvHeads) + self.ropeTheta = + try container.decodeIfPresent( + Float.self, forKey: LlamaConfiguration.CodingKeys.ropeTheta) + ?? 10_000 + self.ropeTraditional = + try container.decodeIfPresent( + Bool.self, forKey: LlamaConfiguration.CodingKeys.ropeTraditional) ?? false + self.ropeScaling = try container.decodeIfPresent( + [String: StringOrNumber].self, forKey: LlamaConfiguration.CodingKeys.ropeScaling) + + } +} diff --git a/Libraries/LLM/Phi.swift b/Libraries/LLM/Phi.swift new file mode 100644 index 0000000..1f7d6b5 --- /dev/null +++ b/Libraries/LLM/Phi.swift @@ -0,0 +1,302 @@ +// Copyright © 2024 Apple Inc. + +import Foundation +import MLX +import MLXNN + +// https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/phi.py + +// TODO: remove once open classes are in + +public class MLXLayerNorm: Module, UnaryLayer { + + let dimensions: Int + let eps: Float + + let weight: MLXArray? + let bias: MLXArray? + + /// Applies layer normalization [1] on the inputs. + /// + /// See [LayerNorm python docs](https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary/mlx.nn.LayerNorm.html) for more information. + /// + /// ### References + /// 1. [https://arxiv.org/abs/1607.06450](https://arxiv.org/abs/1607.06450) + /// + /// - Parameters: + /// - dimensions: number of features in the input + /// - eps: value added to the denominator for numerical stability + /// - affine: if `true` adds a trainable `weight` and `bias` + public init(dimensions: Int, eps: Float = 1e-5, affine: Bool = true) { + self.dimensions = dimensions + self.eps = eps + + if affine { + self.weight = MLXArray.ones([dimensions]) + self.bias = MLXArray.zeros([dimensions]) + } else { + self.weight = nil + self.bias = nil + } + } + + public func callAsFunction(_ x: MLXArray) -> MLXArray { + let means = mean(x, axis: -1, keepDims: true) + let variance = variance(x, axis: -1, keepDims: true) + let x = (x - means) * rsqrt(variance + eps) + + if let weight, let bias { + return weight * x + bias + } else { + return x + } + } +} + +private class LayerNorm: MLXLayerNorm { + override func callAsFunction(_ x: MLXArray) -> MLXArray { + super.callAsFunction(x.asType(Float.self)).asType(x.dtype) + } +} + +private class PhiAttention: Module { + + let args: PhiConfiguration + let heads: Int + let headDim: Int + let repeats: Int + + @ModuleInfo(key: "q_proj") var wq: Linear + @ModuleInfo(key: "k_proj") var wk: Linear + @ModuleInfo(key: "v_proj") var wv: Linear + @ModuleInfo(key: "dense") var dense: Linear + + let rope: RoPE + + public init(_ args: PhiConfiguration) { + self.args = args + + let hiddenSize = args.hiddenSize + self.heads = args.attentionHeads + self.headDim = args.hiddenSize / heads + let kvHeads = args.kvHeads + self.repeats = heads / kvHeads + + if headDim * heads != hiddenSize { + fatalError("hidden_size must be divisible by num_heads") + } + + self._wq.wrappedValue = Linear(hiddenSize, heads * headDim, bias: true) + self._wk.wrappedValue = Linear(hiddenSize, kvHeads * headDim, bias: true) + self._wv.wrappedValue = Linear(hiddenSize, kvHeads * headDim, bias: true) + self._dense.wrappedValue = Linear(heads * headDim, hiddenSize, bias: true) + + self.rope = RoPE( + dimensions: Int(args.partialRotaryFactor * Float(headDim)), traditional: false, + base: args.ropeTheta) + } + + public func callAsFunction( + _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil + ) -> (MLXArray, (MLXArray, MLXArray)) { + let (B, L) = (x.dim(0), x.dim(1)) + + var queries = wq(x) + var keys = wk(x) + var values = wv(x) + + // prepare the queries, keys and values for the attention computation + queries = queries.reshaped(B, L, heads, headDim).transposed(0, 2, 1, 3) + keys = keys.reshaped(B, L, args.kvHeads, headDim).transposed(0, 2, 1, 3) + values = values.reshaped(B, L, args.kvHeads, headDim).transposed(0, 2, 1, 3) + + if repeats > 1 { + keys = MLXArray.repeat(keys, count: repeats, axis: 1) + values = MLXArray.repeat(values, count: repeats, axis: 1) + } + + // Add RoPE to the queries and keys and combine them with the cache + if let (keyCache, valueCache) = cache { + queries = rope(queries, offset: keyCache.dim(2)) + keys = rope(keys, offset: keyCache.dim(2)) + keys = concatenated([keyCache, keys], axis: 2) + values = concatenated([valueCache, values], axis: 2) + } else { + queries = rope(queries) + keys = rope(keys) + } + + queries = queries.asType(Float.self) + keys = keys.asType(Float.self) + + // Finally perform the attention computation + let scale = sqrt(1 / Float(queries.dim(-1))) + var scores = (queries * scale).matmul(keys.transposed(0, 1, 3, 2)) + if let mask { + scores = scores + mask + } + + scores = softMax(scores, axis: -1).asType(values.dtype) + let valuesHat = matmul(scores, values).transposed(0, 2, 1, 3).reshaped(B, L, -1) + + return (dense(valuesHat), (keys, values)) + } +} + +private class PhiMLP: Module, UnaryLayer { + + @ModuleInfo var fc1: Linear + @ModuleInfo var fc2: Linear + @ModuleInfo var act: GELU + + public init(_ config: PhiConfiguration) { + self.fc1 = Linear(config.hiddenSize, config.intermediateSize) + self.fc2 = Linear(config.intermediateSize, config.hiddenSize) + self.act = GELU(approximation: .precise) + } + + public func callAsFunction(_ x: MLXArray) -> MLXArray { + fc2(act(fc1(x))) + } +} + +private class PhiDecoderLayer: Module { + + @ModuleInfo(key: "self_attn") var selfAttention: PhiAttention + @ModuleInfo(key: "input_layernorm") var inputLayerNorm: LayerNorm + var mlp: PhiMLP + + public init(_ config: PhiConfiguration) { + self._selfAttention.wrappedValue = PhiAttention(config) + self._inputLayerNorm.wrappedValue = LayerNorm( + dimensions: config.hiddenSize, eps: config.layerNormEps) + self.mlp = PhiMLP(config) + } + + public func callAsFunction( + _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil + ) -> (MLXArray, (MLXArray, MLXArray)) { + let h = inputLayerNorm(x) + let (attentionH, cache) = selfAttention(h, mask: mask, cache: cache) + let ffH = mlp(h) + return (attentionH + ffH + x, cache) + } +} + +private class PhiModelInner: Module { + + @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding + + @ModuleInfo var layers: [PhiDecoderLayer] + @ModuleInfo(key: "final_layernorm") var finalLayerNorm: LayerNorm + + public init(_ args: PhiConfiguration) { + self._embedTokens.wrappedValue = Embedding( + embeddingCount: args.vocabularySize, dimensions: args.hiddenSize) + + self.layers = (0 ..< args.hiddenLayers) + .map { _ in + PhiDecoderLayer(args) + } + self._finalLayerNorm.wrappedValue = LayerNorm( + dimensions: args.hiddenSize, eps: args.layerNormEps) + } + + public func callAsFunction( + _ x: MLXArray, mask: MLXArray? = nil, cache: [(MLXArray, MLXArray)]? = nil + ) -> ( + MLXArray, [(MLXArray, MLXArray)] + ) { + var x = embedTokens(x) + + var newCache = [(MLXArray, MLXArray)]() + + for (i, layer) in layers.enumerated() { + var cacheUpdate: (MLXArray, MLXArray) + (x, cacheUpdate) = layer(x, mask: mask, cache: cache?[i]) + newCache.append(cacheUpdate) + } + + return (finalLayerNorm(x), newCache) + } +} + +public class PhiModel: Module, LLMModel { + + fileprivate let model: PhiModelInner + + @ModuleInfo(key: "lm_head") var lmHead: Linear + + public init(_ args: PhiConfiguration) { + self.model = PhiModelInner(args) + self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: true) + } + + public func callAsFunction(_ x: MLXArray, cache: [(MLXArray, MLXArray)]?) -> ( + MLXArray, [(MLXArray, MLXArray)] + ) { + var mask: MLXArray? = nil + if x.dim(1) > 1 { + mask = MultiHeadAttention.createAdditiveCausalMask(x.dim(1)) + mask = mask?.asType(x.dtype) + } + + let (y, cache) = model(x, mask: mask, cache: cache) + return (lmHead(y), cache) + } +} + +public struct PhiConfiguration: Codable { + var maxPositionalEmbeddings = 2048 + var vocabularySize = 51200 + var hiddenSize = 2560 + var attentionHeads = 32 + var hiddenLayers = 32 + var kvHeads = 32 + var partialRotaryFactor: Float = 0.4 + var intermediateSize = 10240 + var layerNormEps: Float = 1e-5 + var ropeTheta: Float = 10_000 + + enum CodingKeys: String, CodingKey { + case maxPositionalEmbeddings = "max_position_embeddings" + case vocabularySize = "vocab_size" + case hiddenSize = "hidden_size" + case attentionHeads = "num_attention_heads" + case hiddenLayers = "num_hidden_layers" + case kvHeads = "num_key_value_heads" + case partialRotaryFactor = "partial_rotary_factor" + case intermediateSize = "intermediate_size" + case layerNormEps = "layer_norm_eps" + case ropeTheta = "rope_theta" + } + + public init(from decoder: Decoder) throws { + let container: KeyedDecodingContainer = try decoder.container( + keyedBy: PhiConfiguration.CodingKeys.self) + + self.maxPositionalEmbeddings = try container.decode( + Int.self, forKey: PhiConfiguration.CodingKeys.maxPositionalEmbeddings) + self.vocabularySize = try container.decode( + Int.self, forKey: PhiConfiguration.CodingKeys.vocabularySize) + self.hiddenSize = try container.decode( + Int.self, forKey: PhiConfiguration.CodingKeys.hiddenSize) + self.attentionHeads = try container.decode( + Int.self, forKey: PhiConfiguration.CodingKeys.attentionHeads) + self.hiddenLayers = try container.decode( + Int.self, forKey: PhiConfiguration.CodingKeys.hiddenLayers) + self.kvHeads = + try container.decodeIfPresent(Int.self, forKey: PhiConfiguration.CodingKeys.kvHeads) + ?? attentionHeads + self.partialRotaryFactor = try container.decode( + Float.self, forKey: PhiConfiguration.CodingKeys.partialRotaryFactor) + self.intermediateSize = try container.decode( + Int.self, forKey: PhiConfiguration.CodingKeys.intermediateSize) + self.layerNormEps = try container.decode( + Float.self, forKey: PhiConfiguration.CodingKeys.layerNormEps) + self.ropeTheta = + try container.decodeIfPresent(Float.self, forKey: PhiConfiguration.CodingKeys.ropeTheta) + ?? 10_000 + + } +} diff --git a/Libraries/LLM/README.md b/Libraries/LLM/README.md new file mode 100644 index 0000000..9ca29b3 --- /dev/null +++ b/Libraries/LLM/README.md @@ -0,0 +1,11 @@ +# Llama + +This is a port of the llama model from: + +- https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/llama.py + +You can use this to load models from huggingface, e.g.: + +- https://huggingface.co/mlx-community/Mistral-7B-v0.1-hf-4bit-mlx + +See [llm-tool](../../Tools/llm-tool) diff --git a/Libraries/LLM/Util.swift b/Libraries/LLM/Util.swift new file mode 100644 index 0000000..c579179 --- /dev/null +++ b/Libraries/LLM/Util.swift @@ -0,0 +1,110 @@ +// Copyright © 2024 Apple Inc. + +import AsyncAlgorithms +import Foundation +import Hub +import MLX +import MLXNN +import MLXRandom +import Tokenizers + +/// Load and return the model and tokenizer +public func load( + hub: HubApi = HubApi(), name: String, progressHandler: @escaping (Progress) -> Void = { _ in } +) async throws -> (LLMModel, Tokenizer) { + // note: this doesn't have a way to pass the HubApi + let tokenizer = try await AutoTokenizer.from(pretrained: name) + + // download the model weights and config + let repo = Hub.Repo(id: name) + let modelFiles = ["config.json", "weights.00.safetensors"] + let modelDirectory = try await hub.snapshot( + from: repo, matching: modelFiles, progressHandler: progressHandler) + + // create the model (no weights loaded) + let configurationURL = modelDirectory.appending(component: "config.json") + let baseConfig = try JSONDecoder().decode( + BaseConfiguration.self, from: Data(contentsOf: configurationURL)) + + let model = try baseConfig.modelType.createModel(configuration: configurationURL) + + // set up the model + if let quantization = baseConfig.quantization { + QuantizedLinear.quantize( + model: model, groupSize: quantization.groupSize, bits: quantization.bits) + } + + // apply the loaded weights + let weights = try loadArrays(url: modelDirectory.appending(component: "weights.00.safetensors")) + let parameters = ModuleParameters.unflattened(weights) + try model.update(parameters: parameters, verify: [.all]) + eval(model.parameters()) + + return (model, tokenizer) +} + +private func sample(logits: MLXArray, temp: Float) -> MLXArray { + if temp == 0 { + return argMax(logits, axis: -1) + } else { + return categorical(logits * (1 / temp)) + } +} + +/// Synchronous generator of tokens. +/// +/// Port of `generate_step()` from https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/utils.py +public struct TokenIterator: Sequence, IteratorProtocol { + let model: LLMModel + let temp: Float + + var y: MLXArray + var cache: [(MLXArray, MLXArray)] + + var first = true + + public init(prompt: MLXArray, model: LLMModel, temp: Float = 0.0) { + self.model = model + self.temp = temp + self.y = prompt + self.cache = [] + } + + mutating public func next() -> MLXArray? { + var logits: MLXArray + (logits, cache) = model(expandedDimensions(y, axis: 0), cache: cache.isEmpty ? nil : cache) + y = sample(logits: logits[-1, axis: 1], temp: temp) + + return y + } +} + +/// Async generator of tokens. +/// +/// Port of `generate_step()` from https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/utils.py. +/// +/// Note that because MLXArray is not thread safe this eval's the result and sends the TokenId back +/// to the caller. +public func generate(prompt: MLXArray, model: LLMModel, temp: Float = 0.0) -> ( + Task, AsyncBufferSequence> +) { + let channel = AsyncChannel() + let buffer = channel.buffer(policy: .bounded(10)) + + let task = Task { + var y = prompt + var cache = [(MLXArray, MLXArray)]() + + while !Task.isCancelled { + var logits: MLXArray + (logits, cache) = model( + expandedDimensions(y, axis: 0), cache: cache.isEmpty ? nil : cache) + y = sample(logits: logits[-1, axis: 1], temp: temp) + eval(y) + + await channel.send(y.item(Int.self)) + } + } + + return (task, buffer) +} diff --git a/Libraries/MNIST/Files.swift b/Libraries/MNIST/Files.swift new file mode 100644 index 0000000..84957c2 --- /dev/null +++ b/Libraries/MNIST/Files.swift @@ -0,0 +1,102 @@ +// Copyright © 2024 Apple Inc. + +import Foundation +import Gzip +import MLX + +// based on https://github.com/ml-explore/mlx-examples/blob/main/mnist/mnist.py + +public enum Use: String, Hashable { + case test + case training +} + +public enum DataKind: String, Hashable { + case images + case labels +} + +public struct FileKind: Hashable, CustomStringConvertible { + let use: Use + let data: DataKind + + public init(_ use: Use, _ data: DataKind) { + self.use = use + self.data = data + } + + public var description: String { + "\(use.rawValue)-\(data.rawValue)" + } +} + +struct LoadInfo { + let name: String + let offset: Int + let convert: (MLXArray) -> MLXArray +} + +let baseURL = URL(string: "http://yann.lecun.com/exdb/mnist/")! + +let files = [ + FileKind(.training, .images): LoadInfo( + name: "train-images-idx3-ubyte.gz", + offset: 16, + convert: { + $0.reshaped([-1, 28 * 28]).asType(.float32) / 255.0 + }), + FileKind(.test, .images): LoadInfo( + name: "t10k-images-idx3-ubyte.gz", + offset: 16, + convert: { + $0.reshaped([-1, 28 * 28]).asType(.float32) / 255.0 + }), + FileKind(.training, .labels): LoadInfo( + name: "train-labels-idx1-ubyte.gz", + offset: 8, + convert: { + $0.asType(.uint32) + }), + FileKind(.test, .labels): LoadInfo( + name: "t10k-labels-idx1-ubyte.gz", + offset: 8, + convert: { + $0.asType(.uint32) + }), +] + +public func download(into: URL) async throws { + for (_, info) in files { + let fileURL = into.appending(component: info.name) + if !FileManager.default.fileExists(atPath: fileURL.path()) { + print("Download: \(info.name)") + let url = baseURL.appending(component: info.name) + let (data, response) = try await URLSession.shared.data(from: url) + + guard let httpResponse = response as? HTTPURLResponse else { + fatalError("Unable to download \(url), not an http response: \(response)") + } + guard httpResponse.statusCode == 200 else { + fatalError("Unable to download \(url): \(httpResponse)") + } + + try data.write(to: fileURL) + } + } +} + +public func load(from: URL) throws -> [FileKind: MLXArray] { + var result = [FileKind: MLXArray]() + + for (key, info) in files { + let fileURL = from.appending(component: info.name) + let data = try Data(contentsOf: fileURL).gunzipped() + + let array = MLXArray( + data.dropFirst(info.offset), [data.count - info.offset], type: UInt8.self) + + result[key] = info.convert(array) + } + + return result +} diff --git a/Libraries/MNIST/MNIST.h b/Libraries/MNIST/MNIST.h new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/Libraries/MNIST/MNIST.h @@ -0,0 +1 @@ + diff --git a/Libraries/MNIST/MNIST.swift b/Libraries/MNIST/MNIST.swift new file mode 100644 index 0000000..78d912b --- /dev/null +++ b/Libraries/MNIST/MNIST.swift @@ -0,0 +1,73 @@ +// Copyright © 2024 Apple Inc. + +import Foundation +import MLX +import MLXNN + +// based on https://github.com/ml-explore/mlx-examples/blob/main/mnist/main.py + +public class MLP: Module, UnaryLayer { + + @ModuleInfo var layers: [Linear] + + public init(layers: Int, inputDimensions: Int, hiddenDimensions: Int, outputDimensions: Int) { + let layerSizes = + [inputDimensions] + Array(repeating: hiddenDimensions, count: layers) + [ + outputDimensions + ] + + self.layers = zip(layerSizes.dropLast(), layerSizes.dropFirst()) + .map { + Linear($0, $1) + } + } + + public func callAsFunction(_ x: MLXArray) -> MLXArray { + var x = x + for l in layers.dropLast() { + x = relu(l(x)) + } + return layers.last!(x) + } +} + +public func loss(model: MLP, x: MLXArray, y: MLXArray) -> MLXArray { + crossEntropy(logits: model(x), targets: y, reduction: .mean) +} + +public func eval(model: MLP, x: MLXArray, y: MLXArray) -> MLXArray { + mean(argMax(model(x), axis: 1) .== y) +} + +private struct BatchSequence: Sequence, IteratorProtocol { + + let batchSize: Int + let x: MLXArray + let y: MLXArray + + let indexes: MLXArray + var index = 0 + + init(batchSize: Int, x: MLXArray, y: MLXArray, using generator: inout any RandomNumberGenerator) + { + self.batchSize = batchSize + self.x = x + self.y = y + self.indexes = MLXArray(Array(0 ..< y.size).shuffled(using: &generator)) + } + + mutating func next() -> (MLXArray, MLXArray)? { + guard index < y.size else { return nil } + + let range = index ..< Swift.min(index + batchSize, y.size) + index += batchSize + let ids = indexes[range] + return (x[ids], y[ids]) + } +} + +public func iterateBatches( + batchSize: Int, x: MLXArray, y: MLXArray, using generator: inout any RandomNumberGenerator +) -> some Sequence<(MLXArray, MLXArray)> { + BatchSequence(batchSize: batchSize, x: x, y: y, using: &generator) +} diff --git a/Libraries/MNIST/README.md b/Libraries/MNIST/README.md new file mode 100644 index 0000000..5d7918d --- /dev/null +++ b/Libraries/MNIST/README.md @@ -0,0 +1,13 @@ +# MNIST + +This is a port of the MNIST model and training code from: + +- https://github.com/ml-explore/mlx-examples/blob/main/mnist + +It provides code to: + +- download the test/train data +- provides the MNIST model (MLP) +- some functions to shuffle and batch the data + +See [mnist-tool](../../Tools/mnist-tool) for an example of how to run this. The training loop also lives there. diff --git a/Libraries/MNIST/Random.swift b/Libraries/MNIST/Random.swift new file mode 100644 index 0000000..66fb7e7 --- /dev/null +++ b/Libraries/MNIST/Random.swift @@ -0,0 +1,30 @@ +// Copyright © 2024 Apple Inc. + +import Foundation + +// From https://github.com/apple/swift/blob/cb0fb1ea051631219c0b944b84c78571448d58c2/benchmark/utils/TestsUtils.swift#L254 +// +// This is just a seedable RandomNumberGenerator for shuffle() + +// This is a fixed-increment version of Java 8's SplittableRandom generator. +// It is a very fast generator passing BigCrush, with 64 bits of state. +// See http://dx.doi.org/10.1145/2714064.2660195 and +// http://docs.oracle.com/javase/8/docs/api/java/util/SplittableRandom.html +// +// Derived from public domain C implementation by Sebastiano Vigna +// See http://xoshiro.di.unimi.it/splitmix64.c +public struct SplitMix64: RandomNumberGenerator { + private var state: UInt64 + + public init(seed: UInt64) { + self.state = seed + } + + public mutating func next() -> UInt64 { + self.state &+= 0x9e37_79b9_7f4a_7c15 + var z: UInt64 = self.state + z = (z ^ (z &>> 30)) &* 0xbf58_476d_1ce4_e5b9 + z = (z ^ (z &>> 27)) &* 0x94d0_49bb_1331_11eb + return z ^ (z &>> 31) + } +} diff --git a/README.md b/README.md new file mode 100644 index 0000000..e803ba1 --- /dev/null +++ b/README.md @@ -0,0 +1,22 @@ +# mlx-examples-swift + +Example mlx-swift programs. + +## LinearModelTraining + +A simple linear model and a training loop. + +- [README](Tools/LinearModelTraining/README.md) + +## llm-tool + +A command line tool for generating text using a Llama / Mistral model: + +- [README](Tools/llm-tool/README.md) + +## mnist-tool + +A command line tool for training an MNIST (MLP) model: + +- [README](Tools/mnist-tool/README.md) + diff --git a/Tools/LinearModelTraining/LinearModelTraining.swift b/Tools/LinearModelTraining/LinearModelTraining.swift new file mode 100644 index 0000000..cf25724 --- /dev/null +++ b/Tools/LinearModelTraining/LinearModelTraining.swift @@ -0,0 +1,113 @@ +// Copyright © 2024 Apple Inc. + +import ArgumentParser +import Foundation +import MLX +import MLXNN +import MLXOptimizers +import MLXRandom + +extension MLX.DeviceType: ExpressibleByArgument { + public init?(argument: String) { + self.init(rawValue: argument) + } +} + +@main +struct Train: AsyncParsableCommand { + + @Option var epochs = 20 + @Option var batchSize = 8 + + @Option var m: Float = 0.25 + @Option var b: Float = 7 + + @Flag var compile = false + + @Option var device = DeviceType.cpu + + func run() async throws { + Device.setDefault(device: Device(device)) + + // A very simple model that implements the equation + // for a linear function: y = mx + b. This can be trained + // to match data -- in this case an unknown (to the model) + // linear function. + // + // This is a nice example because most people know how + // linear functions work and we can see how the slope + // and intercept converge. + class LinearFunctionModel: Module, UnaryLayer { + let m = MLXRandom.uniform(low: -5.0, high: 5.0) + let b = MLXRandom.uniform(low: -5.0, high: 5.0) + + func callAsFunction(_ x: MLXArray) -> MLXArray { + m * x + b + } + } + + // measure the distance from the prediction (model(x)) and the + // ground truth (y). this gives feedback on how close the + // prediction is from matching the truth + func loss(model: LinearFunctionModel, x: MLXArray, y: MLXArray) -> MLXArray { + mseLoss(predictions: model(x), targets: y, reduction: .mean) + } + + let model = LinearFunctionModel() + eval(model.parameters()) + + let lg = valueAndGrad(model: model, loss) + + // the optimizer will use the gradients update the model parameters + let optimizer = SGD(learningRate: 1e-1) + + // the function to train our model against -- it doesn't have + // to be linear, but matching what the model models is easy + // to understand + func f(_ x: MLXArray) -> MLXArray { + // these are the target parameters + let m = self.m + let b = self.b + + // our actual function + return m * x + b + } + + func step(_ x: MLXArray, _ y: MLXArray) -> MLXArray { + let (loss, grads) = lg(model, x, y) + optimizer.update(model: model, gradients: grads) + return loss + } + + let resolvedStep = + self.compile + ? MLX.compile(inputs: [model, optimizer], outputs: [model, optimizer], step) : step + + for _ in 0 ..< epochs { + // we expect that the parameters will approach the targets + print("target: b = \(b), m = \(m)") + print("parameters: \(model.parameters())") + + // generate random training data along with the ground truth. + // notice that the shape is [B, 1] where B is the batch + // dimension -- this allows us to train on several samples simultaneously + // + // note: a very large batch size will take longer to converge because + // the gradient will be representing too many samples down into + // a single float parameter. + let x = MLXRandom.uniform(low: -5.0, high: 5.0, [batchSize, 1]) + let y = f(x) + eval(x, y) + + // compute the loss and gradients. use the optimizer + // to adjust the parameters closer to the target + let loss = resolvedStep(x, y) + + eval(model, optimizer) + + // we should see this converge toward 0 + print("loss: \(loss)") + } + + } +} diff --git a/Tools/LinearModelTraining/README.md b/Tools/LinearModelTraining/README.md new file mode 100644 index 0000000..ed8d68b --- /dev/null +++ b/Tools/LinearModelTraining/README.md @@ -0,0 +1,14 @@ +# LinearModelTraining + +A command line tool that creates a Model that represents: + + f(x) = mx + b + +and trains it against an unknown linear function. Very +simple but illustrates: + +- a very simple model with parameters +- a loss function +- the gradient +- use of an optimizers +- the training loop diff --git a/Tools/Tutorial/Tutorial.swift b/Tools/Tutorial/Tutorial.swift new file mode 100644 index 0000000..1cd5ada --- /dev/null +++ b/Tools/Tutorial/Tutorial.swift @@ -0,0 +1,102 @@ +// Copyright © 2024 Apple Inc. + +import Foundation +import MLX + +/// mlx-swift tutorial based on: +/// https://github.com/ml-explore/mlx/blob/main/examples/cpp/tutorial.cpp +@main +struct Tutorial { + + static func scalarBasics() { + // create a scalar array + let x = MLXArray(1.0) + + // the datatype is .float32 + let dtype = x.dtype + assert(dtype == .float32) + + // get the value + let s = x.item(Float.self) + assert(s == 1.0) + + // reading the value with a different type is a fatal error + // let i = x.item(Int.self) + + // scalars have a size of 1 + let size = x.size + assert(size == 1) + + // scalars have 0 dimensions + let ndim = x.ndim + assert(ndim == 0) + + // scalar shapes are empty arrays + let shape = x.shape + assert(shape == []) + } + + static func arrayBasics() { + // make a multidimensional array. + // + // Note: the argument is a [Double] array literal, which is not + // a supported type, but we can explicitly convert it to [Float] + // when we create the MLXArray. + let x = MLXArray(converting: [1.0, 2.0, 3.0, 4.0], [2, 2]) + + // mlx is row-major by default so the first row of this array + // is [1.0, 2.0] and the second row is [3.0, 4.0] + print(x[0]) + print(x[1]) + + // make an array of shape [2, 2] filled with ones + let y = MLXArray.ones([2, 2]) + + // pointwise add x and y + let z = x + y + + // mlx is lazy by default. At this point `z` only + // has a shape and a type but no actual data + assert(z.dtype == .float32) + assert(z.shape == [2, 2]) + + // To actually run the computation you must evaluate `z`. + // Under the hood, mlx records operations in a graph. + // The variable `z` is a node in the graph which points to its operation + // and inputs. When `eval` is called on an array (or arrays), the array and + // all of its dependencies are recursively evaluated to produce the result. + // Once an array is evaluated, it has data and is detached from its inputs. + + // Note: this is being called for demonstration purposes -- all reads + // ensure the array is evaluated. + z.eval() + + // this implicitly evaluates z before converting to a description + print(z) + } + + static func automaticDifferentiation() { + func fn(_ x: MLXArray) -> MLXArray { + x.square() + } + + let gradFn = grad(fn) + + let x = MLXArray(1.5) + let dfdx = gradFn(x) + print(dfdx) + + assert(dfdx.item() == Float(2 * 1.5)) + + let df2dx2 = grad(grad(fn))(x) + print(df2dx2) + + assert(df2dx2.item() == Float(2)) + } + + static func main() { + scalarBasics() + arrayBasics() + automaticDifferentiation() + } +} diff --git a/Tools/llm-tool/LLMTool.swift b/Tools/llm-tool/LLMTool.swift new file mode 100644 index 0000000..74f23f5 --- /dev/null +++ b/Tools/llm-tool/LLMTool.swift @@ -0,0 +1,190 @@ +// Copyright © 2024 Apple Inc. + +import ArgumentParser +import Foundation +import LLM +import MLX +import MLXRandom + +struct LLMTool: AsyncParsableCommand { + static var configuration = CommandConfiguration( + abstract: "Command line tool for generating text using Llama models", + subcommands: [SyncGenerator.self, AsyncGenerator.self], + defaultSubcommand: SyncGenerator.self) +} + +@main +struct SyncGenerator: AsyncParsableCommand { + + static var configuration = CommandConfiguration( + commandName: "sync", + abstract: "Synchronous generator" + ) + + @Option(name: .long, help: "Name of the huggingface model") + var model: String = "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx" + + @Option(name: .shortAndLong, help: "The message to be processed by the model") + var prompt = "compare swift and python" + + @Option(name: .shortAndLong, help: "Maximum number of tokens to generate") + var maxTokens = 100 + + @Option(name: .shortAndLong, help: "The sampling temperature") + var temperature: Float = 0.0 + + @Option(name: .long, help: "The PRNG seed") + var seed: UInt64 = 0 + + @MainActor + func run() async throws { + MLXRandom.seed(seed) + + let (model, tokenizer) = try await load(name: model) + + print("Starting generation ...") + print(prompt, terminator: "") + + var start = Date.timeIntervalSinceReferenceDate + var promptTime: TimeInterval = 0 + + let prompt = MLXArray(tokenizer.encode(text: prompt)) + + // collect the tokens and keep track of how much of the string + // we have printed already + var tokens = [Int]() + var printed = 0 + + for token in TokenIterator(prompt: prompt, model: model, temp: temperature) { + if tokens.isEmpty { + eval(token) + let now = Date.timeIntervalSinceReferenceDate + promptTime = now - start + start = now + } + + let t = token.item(Int.self) + if t == tokenizer.unknownTokenId { + break + } + tokens.append(t) + + // print any new parts of the string + let fullOutput = tokenizer.decode(tokens: tokens) + let emitLength = fullOutput.count - printed + let suffix = fullOutput.suffix(emitLength) + print(suffix, terminator: "") + fflush(stdout) + + printed = fullOutput.count + + if tokens.count == maxTokens { + break + } + } + + print() + print("------") + let now = Date.timeIntervalSinceReferenceDate + let generateTime = now - start + + print( + """ + Prompt Tokens per second: \((Double(prompt.size) / promptTime).formatted()) + Generation tokens per second: \((Double(tokens.count - 1) / generateTime).formatted()) + """) + } +} + +/// Example of an async generator. +/// +/// Note that all of the computation is done on another thread and TokenId (Int32) are sent +/// rather than MLXArray. +struct AsyncGenerator: AsyncParsableCommand { + + static var configuration = CommandConfiguration( + commandName: "async", + abstract: "async generator" + ) + + @Option(name: .long, help: "Name of the huggingface model") + var model: String = "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx" + + @Option(name: .shortAndLong, help: "The message to be processed by the model") + var prompt = "compare swift and python" + + @Option(name: .shortAndLong, help: "Maximum number of tokens to generate") + var maxTokens = 100 + + @Option(name: .shortAndLong, help: "The sampling temperature") + var temperature: Float = 0.0 + + @Option(name: .long, help: "The PRNG seed") + var seed: UInt64 = 0 + + @MainActor + func run() async throws { + MLXRandom.seed(seed) + + let (model, tokenizer) = try await load(name: model) + + print("Starting generation ...") + print(prompt, terminator: "") + + var start = Date.timeIntervalSinceReferenceDate + var promptTime: TimeInterval = 0 + + let prompt = MLXArray(tokenizer.encode(text: prompt)) + + // collect the tokens and keep track of how much of the string + // we have printed already + var tokens = [Int]() + var printed = 0 + + let (task, channel) = generate(prompt: prompt, model: model, temp: temperature) + + for await token in channel { + if tokens.isEmpty { + let now = Date.timeIntervalSinceReferenceDate + promptTime = now - start + start = now + } + + if token == tokenizer.unknownTokenId { + break + } + tokens.append(token) + + // print any new parts of the string + let fullOutput = tokenizer.decode(tokens: tokens) + let emitLength = fullOutput.count - printed + let suffix = fullOutput.suffix(emitLength) + print(suffix, terminator: "") + fflush(stdout) + + printed = fullOutput.count + + if tokens.count == maxTokens { + break + } + } + + // tell the task to stop + task.cancel() + + print() + print("------") + let now = Date.timeIntervalSinceReferenceDate + let generateTime = now - start + + print( + """ + Prompt Tokens per second: \((Double(prompt.size) / promptTime).formatted()) + Generation tokens per second: \((Double(tokens.count - 1) / generateTime).formatted()) + """) + + // wait for the task to complete -- since it is running async, it might + // be in the middle of running the model + try? await Task.sleep(for: .milliseconds(500)) + } +} diff --git a/Tools/llm-tool/README.md b/Tools/llm-tool/README.md new file mode 100644 index 0000000..81b2151 --- /dev/null +++ b/Tools/llm-tool/README.md @@ -0,0 +1,38 @@ +# llm-tool + +See various READMEs: + +- [Llama](../../Libraries/Llama/README.md) + +### Building + +Build the `llm-tool` scheme in Xcode. + +### Running (Xcode) + +To run this in Xcode simply press cmd-opt-r to set the scheme arguments. For example: + +``` +--model mlx-community/Mistral-7B-v0.1-hf-4bit-mlx +--prompt "swift programming language" +--max-tokens 50 +``` + +Then cmd-r to run. + +> Note: you may be prompted for access to your Documents directory -- this is where +the huggingface HubApi stores the downloaded files. + +### Running (Command Line) + +`llm-tool` can also be run from the command line if built from Xcode, but +the `DYLD_FRAMEWORK_PATH` must be set so that the frameworks and bundles can be found: + +- [MLX troubleshooting](https://ml-explore.github.io/mlx-swift/MLX/documentation/mlx/troubleshooting) + +The easiest way to do this is drag the Products/llm-tool into Terminal to get the path: + +``` +DYLD_FRAMEWORK_PATH=~/Library/Developer/Xcode/DerivedData/mlx-examples-swift-ceuohnhzsownvsbbleukxoksddja/Build/Products/Debug ~/Library/Developer/Xcode/DerivedData/mlx-examples-swift-ceuohnhzsownvsbbleukxoksddja/Build/Products/Debug/llm-tool --prompt "swift programming language" +``` + diff --git a/Tools/mnist-tool/MNISTTool.swift b/Tools/mnist-tool/MNISTTool.swift new file mode 100644 index 0000000..68fd4b3 --- /dev/null +++ b/Tools/mnist-tool/MNISTTool.swift @@ -0,0 +1,108 @@ +// Copyright © 2024 Apple Inc. + +import ArgumentParser +import Foundation +import MLX +import MLXNN +import MLXOptimizers +import MLXRandom +import MNIST + +@main +struct MNISTTool: AsyncParsableCommand { + static var configuration = CommandConfiguration( + abstract: "Command line tool for training mnist models", + subcommands: [Train.self], + defaultSubcommand: Train.self) +} + +extension MLX.DeviceType: ExpressibleByArgument { + public init?(argument: String) { + self.init(rawValue: argument) + } +} + +struct Train: AsyncParsableCommand { + + @Option(name: .long, help: "Directory with the training data") + var data: String + + @Option(name: .long, help: "The PRNG seed") + var seed: UInt64 = 0 + + @Option var layers = 2 + @Option var hidden = 32 + @Option var batchSize = 256 + @Option var epochs = 20 + @Option var learningRate: Float = 1e-1 + + @Option var classes = 10 + + @Option var device = DeviceType.cpu + + @Flag var compile = false + + func run() async throws { + Device.setDefault(device: Device(device)) + + MLXRandom.seed(seed) + var generator: RandomNumberGenerator = SplitMix64(seed: seed) + + // load the data + let url = URL(filePath: data) + + try FileManager.default.createDirectory(at: url, withIntermediateDirectories: true) + try await download(into: url) + + let data = try load(from: url) + + let trainImages = data[.init(.training, .images)]! + let trainLabels = data[.init(.training, .labels)]! + let testImages = data[.init(.test, .images)]! + let testLabels = data[.init(.test, .labels)]! + + // create the model + let model = MLP( + layers: layers, inputDimensions: trainImages.dim(-1), hiddenDimensions: hidden, + outputDimensions: classes) + eval(model.parameters()) + + let lg = valueAndGrad(model: model, loss) + let optimizer = SGD(learningRate: learningRate) + + func step(_ x: MLXArray, _ y: MLXArray) -> MLXArray { + let (loss, grads) = lg(model, x, y) + optimizer.update(model: model, gradients: grads) + return loss + } + + let resolvedStep = + compile + ? MLX.compile(inputs: [model, optimizer], outputs: [model, optimizer], step) : step + + for e in 0 ..< epochs { + let start = Date.timeIntervalSinceReferenceDate + + for (x, y) in iterateBatches( + batchSize: batchSize, x: trainImages, y: trainLabels, using: &generator) + { + _ = resolvedStep(x, y) + + // eval the parameters so the next iteration is independent + eval(model, optimizer) + } + + let accuracy = eval(model: model, x: testImages, y: testLabels) + + let end = Date.timeIntervalSinceReferenceDate + + print( + """ + Epoch \(e): test accuracy \(accuracy.item(Float.self).formatted()) + Time: \((end - start).formatted()) + + """ + ) + } + } +} diff --git a/Tools/mnist-tool/README.md b/Tools/mnist-tool/README.md new file mode 100644 index 0000000..3bd745a --- /dev/null +++ b/Tools/mnist-tool/README.md @@ -0,0 +1,36 @@ +# mnist-tool + +See other README: + +- [MNIST](../../Libraries/MNIST/README.md) + +### Building + +`mnist-tool` has no dependencies outside of the package dependencies +represented in xcode. + +When you run the tool it will download the test/train datasets and +store them in a specified directory (see run arguments -- default is /tmp). + +Simply build the project in xcode. + +### Running (Xcode) + +To run this in Xcode simply press cmd-opt-r to set the scheme arguments. For example: + +``` +--data /tmp +``` + +Then cmd-r to run. + +### Running (CommandLine) + +`mnist-tool` can also be run from the command line if built from Xcode, but +the `DYLD_FRAMEWORK_PATH` must be set so that the frameworks and bundles can be found: + +- [MLX troubleshooting](https://ml-explore.github.io/mlx-swift/MLX/documentation/mlx/troubleshooting) + +``` +DYLD_FRAMEWORK_PATH=~/Library/Developer/Xcode/DerivedData/mlx-examples-swift-ceuohnhzsownvsbbleukxoksddja/Build/Products/Debug ~/Library/Developer/Xcode/DerivedData/mlx-examples-swift-ceuohnhzsownvsbbleukxoksddja/Build/Products/Debug/mnist-tool --data /tmp +``` diff --git a/mlx-swift-examples.xcodeproj/project.pbxproj b/mlx-swift-examples.xcodeproj/project.pbxproj new file mode 100644 index 0000000..555ad92 --- /dev/null +++ b/mlx-swift-examples.xcodeproj/project.pbxproj @@ -0,0 +1,1716 @@ +// !$*UTF8*$! +{ + archiveVersion = 1; + classes = { + }; + objectVersion = 56; + objects = { + +/* Begin PBXBuildFile section */ + C3288D762B6D9313009FF608 /* LinearModelTraining.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3288D752B6D9313009FF608 /* LinearModelTraining.swift */; }; + C3288D7B2B6D9339009FF608 /* ArgumentParser in Frameworks */ = {isa = PBXBuildFile; productRef = C3288D7A2B6D9339009FF608 /* ArgumentParser */; }; + C34E48F52B696F0B00FCB841 /* LLMTool.swift in Sources */ = {isa = PBXBuildFile; fileRef = C34E48F42B696F0B00FCB841 /* LLMTool.swift */; }; + C34E49102B69A92900FCB841 /* MNIST.h in Headers */ = {isa = PBXBuildFile; fileRef = C34E490F2B69A92900FCB841 /* MNIST.h */; settings = {ATTRIBUTES = (Public, ); }; }; + C34E49152B69C1E300FCB841 /* Files.swift in Sources */ = {isa = PBXBuildFile; fileRef = C34E49142B69C1E300FCB841 /* Files.swift */; }; + C34E491C2B69C43600FCB841 /* Gzip in Frameworks */ = {isa = PBXBuildFile; productRef = C34E491B2B69C43600FCB841 /* Gzip */; }; + C34E49242B6A026F00FCB841 /* MNISTTool.swift in Sources */ = {isa = PBXBuildFile; fileRef = C34E49232B6A026F00FCB841 /* MNISTTool.swift */; }; + C34E49292B6A028100FCB841 /* ArgumentParser in Frameworks */ = {isa = PBXBuildFile; productRef = C34E49282B6A028100FCB841 /* ArgumentParser */; }; + C34E492A2B6A028800FCB841 /* MNIST.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C34E490D2B69A92900FCB841 /* MNIST.framework */; }; + C34E492B2B6A028800FCB841 /* MNIST.framework in Embed Frameworks */ = {isa = PBXBuildFile; fileRef = C34E490D2B69A92900FCB841 /* MNIST.framework */; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; }; + C382DE8A2B630889000F8F03 /* AsyncAlgorithms in Frameworks */ = {isa = PBXBuildFile; productRef = C382DE892B630889000F8F03 /* AsyncAlgorithms */; }; + C38935C82B869C7A0037B833 /* LLM.h in Headers */ = {isa = PBXBuildFile; fileRef = C38935C72B869C7A0037B833 /* LLM.h */; settings = {ATTRIBUTES = (Public, ); }; }; + C38935CC2B869C870037B833 /* Llama.swift in Sources */ = {isa = PBXBuildFile; fileRef = C34E48EE2B696E6500FCB841 /* Llama.swift */; }; + C38935CD2B869C870037B833 /* Configuration.swift in Sources */ = {isa = PBXBuildFile; fileRef = C34E48EF2B696E6500FCB841 /* Configuration.swift */; }; + C38935CE2B869C870037B833 /* Util.swift in Sources */ = {isa = PBXBuildFile; fileRef = C34E48ED2B696E6500FCB841 /* Util.swift */; }; + C38935D02B869CC40037B833 /* MLX in Frameworks */ = {isa = PBXBuildFile; productRef = C38935CF2B869CC40037B833 /* MLX */; }; + C38935D22B869CC40037B833 /* MLXNN in Frameworks */ = {isa = PBXBuildFile; productRef = C38935D12B869CC40037B833 /* MLXNN */; }; + C38935D42B869CC40037B833 /* MLXRandom in Frameworks */ = {isa = PBXBuildFile; productRef = C38935D32B869CC40037B833 /* MLXRandom */; }; + C38935D62B869CC40037B833 /* Transformers in Frameworks */ = {isa = PBXBuildFile; productRef = C38935D52B869CC40037B833 /* Transformers */; }; + C38935D72B869CCD0037B833 /* LLM.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C38935C52B869C7A0037B833 /* LLM.framework */; }; + C38935D82B869CCD0037B833 /* LLM.framework in Embed Frameworks */ = {isa = PBXBuildFile; fileRef = C38935C52B869C7A0037B833 /* LLM.framework */; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; }; + C38935DD2B869CEC0037B833 /* AsyncAlgorithms in Frameworks */ = {isa = PBXBuildFile; productRef = C38935DC2B869CEC0037B833 /* AsyncAlgorithms */; }; + C38935DF2B869DD00037B833 /* Phi.swift in Sources */ = {isa = PBXBuildFile; fileRef = C38935DE2B869DD00037B833 /* Phi.swift */; }; + C38935E12B869F420037B833 /* LLMModel.swift in Sources */ = {isa = PBXBuildFile; fileRef = C38935E02B869F420037B833 /* LLMModel.swift */; }; + C38935E32B86C0FE0037B833 /* Gemma.swift in Sources */ = {isa = PBXBuildFile; fileRef = C38935E22B86C0FE0037B833 /* Gemma.swift */; }; + C392737D2B606A1D00368D5D /* Tutorial.swift in Sources */ = {isa = PBXBuildFile; fileRef = C392737C2B606A1D00368D5D /* Tutorial.swift */; }; + C3932D572B6A060B00A81055 /* MNIST.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3932D562B6A060B00A81055 /* MNIST.swift */; }; + C3932D592B6A0BE400A81055 /* Random.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3932D582B6A0BE400A81055 /* Random.swift */; }; + C397C59C2B62C6D0004B084D /* ArgumentParser in Frameworks */ = {isa = PBXBuildFile; productRef = C397C59B2B62C6D0004B084D /* ArgumentParser */; }; + C3FBCB212B8520B80007E490 /* MLX in Frameworks */ = {isa = PBXBuildFile; productRef = C3FBCB202B8520B80007E490 /* MLX */; }; + C3FBCB292B8520DA0007E490 /* MLX in Frameworks */ = {isa = PBXBuildFile; productRef = C3FBCB282B8520DA0007E490 /* MLX */; }; + C3FBCB2B2B8520DA0007E490 /* MLXNN in Frameworks */ = {isa = PBXBuildFile; productRef = C3FBCB2A2B8520DA0007E490 /* MLXNN */; }; + C3FBCB2D2B8520E80007E490 /* MLXOptimizers in Frameworks */ = {isa = PBXBuildFile; productRef = C3FBCB2C2B8520E80007E490 /* MLXOptimizers */; }; + C3FBCB2F2B8520F20007E490 /* MLX in Frameworks */ = {isa = PBXBuildFile; productRef = C3FBCB2E2B8520F20007E490 /* MLX */; }; + C3FBCB312B8520F20007E490 /* MLXNN in Frameworks */ = {isa = PBXBuildFile; productRef = C3FBCB302B8520F20007E490 /* MLXNN */; }; + C3FBCB332B8520F20007E490 /* MLXOptimizers in Frameworks */ = {isa = PBXBuildFile; productRef = C3FBCB322B8520F20007E490 /* MLXOptimizers */; }; + C3FBCB352B8520F20007E490 /* MLXRandom in Frameworks */ = {isa = PBXBuildFile; productRef = C3FBCB342B8520F20007E490 /* MLXRandom */; }; +/* End PBXBuildFile section */ + +/* Begin PBXContainerItemProxy section */ + C34E492C2B6A028800FCB841 /* PBXContainerItemProxy */ = { + isa = PBXContainerItemProxy; + containerPortal = C39273682B60697700368D5D /* Project object */; + proxyType = 1; + remoteGlobalIDString = C34E490C2B69A92900FCB841; + remoteInfo = MNIST; + }; + C38935D92B869CCD0037B833 /* PBXContainerItemProxy */ = { + isa = PBXContainerItemProxy; + containerPortal = C39273682B60697700368D5D /* Project object */; + proxyType = 1; + remoteGlobalIDString = C38935C42B869C7A0037B833; + remoteInfo = LLM; + }; +/* End PBXContainerItemProxy section */ + +/* Begin PBXCopyFilesBuildPhase section */ + C3288D712B6D9313009FF608 /* CopyFiles */ = { + isa = PBXCopyFilesBuildPhase; + buildActionMask = 2147483647; + dstPath = /usr/share/man/man1/; + dstSubfolderSpec = 0; + files = ( + ); + runOnlyForDeploymentPostprocessing = 1; + }; + C34E491F2B6A026F00FCB841 /* CopyFiles */ = { + isa = PBXCopyFilesBuildPhase; + buildActionMask = 2147483647; + dstPath = /usr/share/man/man1/; + dstSubfolderSpec = 0; + files = ( + ); + runOnlyForDeploymentPostprocessing = 1; + }; + C34E492E2B6A028800FCB841 /* Embed Frameworks */ = { + isa = PBXCopyFilesBuildPhase; + buildActionMask = 2147483647; + dstPath = ""; + dstSubfolderSpec = 10; + files = ( + C34E492B2B6A028800FCB841 /* MNIST.framework in Embed Frameworks */, + ); + name = "Embed Frameworks"; + runOnlyForDeploymentPostprocessing = 0; + }; + C38935DB2B869CCE0037B833 /* Embed Frameworks */ = { + isa = PBXCopyFilesBuildPhase; + buildActionMask = 2147483647; + dstPath = ""; + dstSubfolderSpec = 10; + files = ( + C38935D82B869CCD0037B833 /* LLM.framework in Embed Frameworks */, + ); + name = "Embed Frameworks"; + runOnlyForDeploymentPostprocessing = 0; + }; + C39273722B606A0A00368D5D /* CopyFiles */ = { + isa = PBXCopyFilesBuildPhase; + buildActionMask = 2147483647; + dstPath = /usr/share/man/man1/; + dstSubfolderSpec = 0; + files = ( + ); + runOnlyForDeploymentPostprocessing = 1; + }; + C397C5892B62C6A9004B084D /* CopyFiles */ = { + isa = PBXCopyFilesBuildPhase; + buildActionMask = 2147483647; + dstPath = /usr/share/man/man1/; + dstSubfolderSpec = 0; + files = ( + ); + runOnlyForDeploymentPostprocessing = 1; + }; +/* End PBXCopyFilesBuildPhase section */ + +/* Begin PBXFileReference section */ + C325DE3F2B648CDB00628871 /* README.md */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = ""; }; + C3288D732B6D9313009FF608 /* LinearModelTraining */ = {isa = PBXFileReference; explicitFileType = "compiled.mach-o.executable"; includeInIndex = 0; path = LinearModelTraining; sourceTree = BUILT_PRODUCTS_DIR; }; + C3288D752B6D9313009FF608 /* LinearModelTraining.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LinearModelTraining.swift; sourceTree = ""; }; + C3288D842B6D94BD009FF608 /* README.md */ = {isa = PBXFileReference; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = ""; }; + C34E48ED2B696E6500FCB841 /* Util.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Util.swift; sourceTree = ""; }; + C34E48EE2B696E6500FCB841 /* Llama.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Llama.swift; sourceTree = ""; }; + C34E48EF2B696E6500FCB841 /* Configuration.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Configuration.swift; sourceTree = ""; }; + C34E48F42B696F0B00FCB841 /* LLMTool.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = LLMTool.swift; sourceTree = ""; }; + C34E48F62B69832600FCB841 /* README.md */ = {isa = PBXFileReference; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = ""; }; + C34E48F92B69930300FCB841 /* README.md */ = {isa = PBXFileReference; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = ""; }; + C34E490D2B69A92900FCB841 /* MNIST.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = MNIST.framework; sourceTree = BUILT_PRODUCTS_DIR; }; + C34E490F2B69A92900FCB841 /* MNIST.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = MNIST.h; sourceTree = ""; }; + C34E49142B69C1E300FCB841 /* Files.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Files.swift; sourceTree = ""; }; + C34E49212B6A026F00FCB841 /* mnist-tool */ = {isa = PBXFileReference; explicitFileType = "compiled.mach-o.executable"; includeInIndex = 0; path = "mnist-tool"; sourceTree = BUILT_PRODUCTS_DIR; }; + C34E49232B6A026F00FCB841 /* MNISTTool.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MNISTTool.swift; sourceTree = ""; }; + C38935C52B869C7A0037B833 /* LLM.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = LLM.framework; sourceTree = BUILT_PRODUCTS_DIR; }; + C38935C72B869C7A0037B833 /* LLM.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = LLM.h; sourceTree = ""; }; + C38935DE2B869DD00037B833 /* Phi.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Phi.swift; sourceTree = ""; }; + C38935E02B869F420037B833 /* LLMModel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LLMModel.swift; sourceTree = ""; }; + C38935E22B86C0FE0037B833 /* Gemma.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Gemma.swift; sourceTree = ""; }; + C39273742B606A0A00368D5D /* Tutorial */ = {isa = PBXFileReference; explicitFileType = "compiled.mach-o.executable"; includeInIndex = 0; path = Tutorial; sourceTree = BUILT_PRODUCTS_DIR; }; + C392737C2B606A1D00368D5D /* Tutorial.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Tutorial.swift; sourceTree = ""; }; + C3932D562B6A060B00A81055 /* MNIST.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MNIST.swift; sourceTree = ""; }; + C3932D582B6A0BE400A81055 /* Random.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Random.swift; sourceTree = ""; }; + C397C58B2B62C6A9004B084D /* llm-tool */ = {isa = PBXFileReference; explicitFileType = "compiled.mach-o.executable"; includeInIndex = 0; path = "llm-tool"; sourceTree = BUILT_PRODUCTS_DIR; }; + C3C3240B2B6CA689007D2D9A /* README.md */ = {isa = PBXFileReference; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = ""; }; + C3C3240C2B6CA792007D2D9A /* README.md */ = {isa = PBXFileReference; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = ""; }; +/* End PBXFileReference section */ + +/* Begin PBXFrameworksBuildPhase section */ + C3288D702B6D9313009FF608 /* Frameworks */ = { + isa = PBXFrameworksBuildPhase; + buildActionMask = 2147483647; + files = ( + C3FBCB332B8520F20007E490 /* MLXOptimizers in Frameworks */, + C3FBCB312B8520F20007E490 /* MLXNN in Frameworks */, + C3FBCB2F2B8520F20007E490 /* MLX in Frameworks */, + C3FBCB352B8520F20007E490 /* MLXRandom in Frameworks */, + C3288D7B2B6D9339009FF608 /* ArgumentParser in Frameworks */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; + C34E490A2B69A92900FCB841 /* Frameworks */ = { + isa = PBXFrameworksBuildPhase; + buildActionMask = 2147483647; + files = ( + C3FBCB2B2B8520DA0007E490 /* MLXNN in Frameworks */, + C3FBCB292B8520DA0007E490 /* MLX in Frameworks */, + C34E491C2B69C43600FCB841 /* Gzip in Frameworks */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; + C34E491E2B6A026F00FCB841 /* Frameworks */ = { + isa = PBXFrameworksBuildPhase; + buildActionMask = 2147483647; + files = ( + C3FBCB2D2B8520E80007E490 /* MLXOptimizers in Frameworks */, + C34E492A2B6A028800FCB841 /* MNIST.framework in Frameworks */, + C34E49292B6A028100FCB841 /* ArgumentParser in Frameworks */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; + C38935C22B869C7A0037B833 /* Frameworks */ = { + isa = PBXFrameworksBuildPhase; + buildActionMask = 2147483647; + files = ( + C38935D22B869CC40037B833 /* MLXNN in Frameworks */, + C38935D42B869CC40037B833 /* MLXRandom in Frameworks */, + C38935D62B869CC40037B833 /* Transformers in Frameworks */, + C38935D02B869CC40037B833 /* MLX in Frameworks */, + C38935DD2B869CEC0037B833 /* AsyncAlgorithms in Frameworks */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; + C39273712B606A0A00368D5D /* Frameworks */ = { + isa = PBXFrameworksBuildPhase; + buildActionMask = 2147483647; + files = ( + C3FBCB212B8520B80007E490 /* MLX in Frameworks */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; + C397C5882B62C6A9004B084D /* Frameworks */ = { + isa = PBXFrameworksBuildPhase; + buildActionMask = 2147483647; + files = ( + C397C59C2B62C6D0004B084D /* ArgumentParser in Frameworks */, + C38935D72B869CCD0037B833 /* LLM.framework in Frameworks */, + C382DE8A2B630889000F8F03 /* AsyncAlgorithms in Frameworks */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXFrameworksBuildPhase section */ + +/* Begin PBXGroup section */ + C3288D742B6D9313009FF608 /* LinearModelTraining */ = { + isa = PBXGroup; + children = ( + C3288D752B6D9313009FF608 /* LinearModelTraining.swift */, + C3288D842B6D94BD009FF608 /* README.md */, + ); + path = LinearModelTraining; + sourceTree = ""; + }; + C34E48F32B696F0B00FCB841 /* llm-tool */ = { + isa = PBXGroup; + children = ( + C34E48F42B696F0B00FCB841 /* LLMTool.swift */, + C34E48F92B69930300FCB841 /* README.md */, + ); + path = "llm-tool"; + sourceTree = ""; + }; + C34E490E2B69A92900FCB841 /* MNIST */ = { + isa = PBXGroup; + children = ( + C34E490F2B69A92900FCB841 /* MNIST.h */, + C34E49142B69C1E300FCB841 /* Files.swift */, + C3932D562B6A060B00A81055 /* MNIST.swift */, + C3932D582B6A0BE400A81055 /* Random.swift */, + C3C3240C2B6CA792007D2D9A /* README.md */, + ); + path = MNIST; + sourceTree = ""; + }; + C34E49222B6A026F00FCB841 /* mnist-tool */ = { + isa = PBXGroup; + children = ( + C34E49232B6A026F00FCB841 /* MNISTTool.swift */, + C3C3240B2B6CA689007D2D9A /* README.md */, + ); + path = "mnist-tool"; + sourceTree = ""; + }; + C38935C62B869C7A0037B833 /* LLM */ = { + isa = PBXGroup; + children = ( + C34E48EF2B696E6500FCB841 /* Configuration.swift */, + C34E48EE2B696E6500FCB841 /* Llama.swift */, + C38935E22B86C0FE0037B833 /* Gemma.swift */, + C38935C72B869C7A0037B833 /* LLM.h */, + C38935E02B869F420037B833 /* LLMModel.swift */, + C38935DE2B869DD00037B833 /* Phi.swift */, + C34E48F62B69832600FCB841 /* README.md */, + C34E48ED2B696E6500FCB841 /* Util.swift */, + ); + path = LLM; + sourceTree = ""; + }; + C39273672B60697700368D5D = { + isa = PBXGroup; + children = ( + C325DE3F2B648CDB00628871 /* README.md */, + C39273822B606A9200368D5D /* Libraries */, + C39273812B606A7400368D5D /* Tools */, + C39273752B606A0A00368D5D /* Products */, + C392737E2B606A2C00368D5D /* Frameworks */, + ); + sourceTree = ""; + }; + C39273752B606A0A00368D5D /* Products */ = { + isa = PBXGroup; + children = ( + C39273742B606A0A00368D5D /* Tutorial */, + C397C58B2B62C6A9004B084D /* llm-tool */, + C34E490D2B69A92900FCB841 /* MNIST.framework */, + C34E49212B6A026F00FCB841 /* mnist-tool */, + C3288D732B6D9313009FF608 /* LinearModelTraining */, + C38935C52B869C7A0037B833 /* LLM.framework */, + ); + name = Products; + sourceTree = ""; + }; + C39273762B606A0A00368D5D /* Tutorial */ = { + isa = PBXGroup; + children = ( + C392737C2B606A1D00368D5D /* Tutorial.swift */, + ); + path = Tutorial; + sourceTree = ""; + }; + C392737E2B606A2C00368D5D /* Frameworks */ = { + isa = PBXGroup; + children = ( + ); + name = Frameworks; + sourceTree = ""; + }; + C39273812B606A7400368D5D /* Tools */ = { + isa = PBXGroup; + children = ( + C3288D742B6D9313009FF608 /* LinearModelTraining */, + C34E49222B6A026F00FCB841 /* mnist-tool */, + C34E48F32B696F0B00FCB841 /* llm-tool */, + C39273762B606A0A00368D5D /* Tutorial */, + ); + path = Tools; + sourceTree = ""; + }; + C39273822B606A9200368D5D /* Libraries */ = { + isa = PBXGroup; + children = ( + C38935C62B869C7A0037B833 /* LLM */, + C34E490E2B69A92900FCB841 /* MNIST */, + ); + path = Libraries; + sourceTree = ""; + }; +/* End PBXGroup section */ + +/* Begin PBXHeadersBuildPhase section */ + C34E49082B69A92900FCB841 /* Headers */ = { + isa = PBXHeadersBuildPhase; + buildActionMask = 2147483647; + files = ( + C34E49102B69A92900FCB841 /* MNIST.h in Headers */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; + C38935C02B869C7A0037B833 /* Headers */ = { + isa = PBXHeadersBuildPhase; + buildActionMask = 2147483647; + files = ( + C38935C82B869C7A0037B833 /* LLM.h in Headers */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXHeadersBuildPhase section */ + +/* Begin PBXNativeTarget section */ + C3288D722B6D9313009FF608 /* LinearModelTraining */ = { + isa = PBXNativeTarget; + buildConfigurationList = C3288D792B6D9313009FF608 /* Build configuration list for PBXNativeTarget "LinearModelTraining" */; + buildPhases = ( + C3288D6F2B6D9313009FF608 /* Sources */, + C3288D702B6D9313009FF608 /* Frameworks */, + C3288D712B6D9313009FF608 /* CopyFiles */, + ); + buildRules = ( + ); + dependencies = ( + ); + name = LinearModelTraining; + packageProductDependencies = ( + C3288D7A2B6D9339009FF608 /* ArgumentParser */, + C3FBCB2E2B8520F20007E490 /* MLX */, + C3FBCB302B8520F20007E490 /* MLXNN */, + C3FBCB322B8520F20007E490 /* MLXOptimizers */, + C3FBCB342B8520F20007E490 /* MLXRandom */, + ); + productName = LinearFunctionModelTraining; + productReference = C3288D732B6D9313009FF608 /* LinearModelTraining */; + productType = "com.apple.product-type.tool"; + }; + C34E490C2B69A92900FCB841 /* MNIST */ = { + isa = PBXNativeTarget; + buildConfigurationList = C34E49112B69A92900FCB841 /* Build configuration list for PBXNativeTarget "MNIST" */; + buildPhases = ( + C34E49082B69A92900FCB841 /* Headers */, + C34E49092B69A92900FCB841 /* Sources */, + C34E490A2B69A92900FCB841 /* Frameworks */, + C34E490B2B69A92900FCB841 /* Resources */, + ); + buildRules = ( + ); + dependencies = ( + ); + name = MNIST; + packageProductDependencies = ( + C34E491B2B69C43600FCB841 /* Gzip */, + C3FBCB282B8520DA0007E490 /* MLX */, + C3FBCB2A2B8520DA0007E490 /* MLXNN */, + ); + productName = MNIST; + productReference = C34E490D2B69A92900FCB841 /* MNIST.framework */; + productType = "com.apple.product-type.framework"; + }; + C34E49202B6A026F00FCB841 /* mnist-tool */ = { + isa = PBXNativeTarget; + buildConfigurationList = C34E49252B6A026F00FCB841 /* Build configuration list for PBXNativeTarget "mnist-tool" */; + buildPhases = ( + C34E491D2B6A026F00FCB841 /* Sources */, + C34E491E2B6A026F00FCB841 /* Frameworks */, + C34E491F2B6A026F00FCB841 /* CopyFiles */, + C34E492E2B6A028800FCB841 /* Embed Frameworks */, + ); + buildRules = ( + ); + dependencies = ( + C34E492D2B6A028800FCB841 /* PBXTargetDependency */, + ); + name = "mnist-tool"; + packageProductDependencies = ( + C34E49282B6A028100FCB841 /* ArgumentParser */, + C3FBCB2C2B8520E80007E490 /* MLXOptimizers */, + ); + productName = "mnist-tool"; + productReference = C34E49212B6A026F00FCB841 /* mnist-tool */; + productType = "com.apple.product-type.tool"; + }; + C38935C42B869C7A0037B833 /* LLM */ = { + isa = PBXNativeTarget; + buildConfigurationList = C38935C92B869C7A0037B833 /* Build configuration list for PBXNativeTarget "LLM" */; + buildPhases = ( + C38935C02B869C7A0037B833 /* Headers */, + C38935C12B869C7A0037B833 /* Sources */, + C38935C22B869C7A0037B833 /* Frameworks */, + C38935C32B869C7A0037B833 /* Resources */, + ); + buildRules = ( + ); + dependencies = ( + ); + name = LLM; + packageProductDependencies = ( + C38935CF2B869CC40037B833 /* MLX */, + C38935D12B869CC40037B833 /* MLXNN */, + C38935D32B869CC40037B833 /* MLXRandom */, + C38935D52B869CC40037B833 /* Transformers */, + C38935DC2B869CEC0037B833 /* AsyncAlgorithms */, + ); + productName = LLM; + productReference = C38935C52B869C7A0037B833 /* LLM.framework */; + productType = "com.apple.product-type.framework"; + }; + C39273732B606A0A00368D5D /* Tutorial */ = { + isa = PBXNativeTarget; + buildConfigurationList = C39273792B606A0A00368D5D /* Build configuration list for PBXNativeTarget "Tutorial" */; + buildPhases = ( + C39273702B606A0A00368D5D /* Sources */, + C39273712B606A0A00368D5D /* Frameworks */, + C39273722B606A0A00368D5D /* CopyFiles */, + ); + buildRules = ( + ); + dependencies = ( + ); + name = Tutorial; + packageProductDependencies = ( + C3FBCB202B8520B80007E490 /* MLX */, + ); + productName = Tutorial; + productReference = C39273742B606A0A00368D5D /* Tutorial */; + productType = "com.apple.product-type.tool"; + }; + C397C58A2B62C6A9004B084D /* llm-tool */ = { + isa = PBXNativeTarget; + buildConfigurationList = C397C58F2B62C6A9004B084D /* Build configuration list for PBXNativeTarget "llm-tool" */; + buildPhases = ( + C397C5872B62C6A9004B084D /* Sources */, + C397C5882B62C6A9004B084D /* Frameworks */, + C397C5892B62C6A9004B084D /* CopyFiles */, + C38935DB2B869CCE0037B833 /* Embed Frameworks */, + ); + buildRules = ( + ); + dependencies = ( + C38935DA2B869CCD0037B833 /* PBXTargetDependency */, + ); + name = "llm-tool"; + packageProductDependencies = ( + C397C59B2B62C6D0004B084D /* ArgumentParser */, + C382DE892B630889000F8F03 /* AsyncAlgorithms */, + ); + productName = "mistral-tool"; + productReference = C397C58B2B62C6A9004B084D /* llm-tool */; + productType = "com.apple.product-type.tool"; + }; +/* End PBXNativeTarget section */ + +/* Begin PBXProject section */ + C39273682B60697700368D5D /* Project object */ = { + isa = PBXProject; + attributes = { + BuildIndependentTargetsInParallel = 1; + LastSwiftUpdateCheck = 1500; + LastUpgradeCheck = 1500; + TargetAttributes = { + C3288D722B6D9313009FF608 = { + CreatedOnToolsVersion = 15.0.1; + }; + C34E490C2B69A92900FCB841 = { + CreatedOnToolsVersion = 15.0.1; + LastSwiftMigration = 1500; + }; + C34E49202B6A026F00FCB841 = { + CreatedOnToolsVersion = 15.0.1; + }; + C38935C42B869C7A0037B833 = { + CreatedOnToolsVersion = 15.2; + }; + C39273732B606A0A00368D5D = { + CreatedOnToolsVersion = 15.0.1; + }; + C397C58A2B62C6A9004B084D = { + CreatedOnToolsVersion = 15.0.1; + }; + }; + }; + buildConfigurationList = C392736B2B60697700368D5D /* Build configuration list for PBXProject "mlx-swift-examples" */; + compatibilityVersion = "Xcode 14.0"; + developmentRegion = en; + hasScannedForEncodings = 0; + knownRegions = ( + en, + Base, + ); + mainGroup = C39273672B60697700368D5D; + packageReferences = ( + C392736E2B60699100368D5D /* XCRemoteSwiftPackageReference "swift-argument-parser" */, + C382DE882B630889000F8F03 /* XCRemoteSwiftPackageReference "swift-async-algorithms" */, + C34E491A2B69C43600FCB841 /* XCRemoteSwiftPackageReference "GzipSwift" */, + C3FBCB1F2B8520B00007E490 /* XCRemoteSwiftPackageReference "mlx-swift" */, + C38935BB2B866BFA0037B833 /* XCRemoteSwiftPackageReference "swift-transformers" */, + ); + productRefGroup = C39273752B606A0A00368D5D /* Products */; + projectDirPath = ""; + projectRoot = ""; + targets = ( + C39273732B606A0A00368D5D /* Tutorial */, + C397C58A2B62C6A9004B084D /* llm-tool */, + C38935C42B869C7A0037B833 /* LLM */, + C34E49202B6A026F00FCB841 /* mnist-tool */, + C34E490C2B69A92900FCB841 /* MNIST */, + C3288D722B6D9313009FF608 /* LinearModelTraining */, + ); + }; +/* End PBXProject section */ + +/* Begin PBXResourcesBuildPhase section */ + C34E490B2B69A92900FCB841 /* Resources */ = { + isa = PBXResourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + runOnlyForDeploymentPostprocessing = 0; + }; + C38935C32B869C7A0037B833 /* Resources */ = { + isa = PBXResourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXResourcesBuildPhase section */ + +/* Begin PBXSourcesBuildPhase section */ + C3288D6F2B6D9313009FF608 /* Sources */ = { + isa = PBXSourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + C3288D762B6D9313009FF608 /* LinearModelTraining.swift in Sources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; + C34E49092B69A92900FCB841 /* Sources */ = { + isa = PBXSourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + C34E49152B69C1E300FCB841 /* Files.swift in Sources */, + C3932D572B6A060B00A81055 /* MNIST.swift in Sources */, + C3932D592B6A0BE400A81055 /* Random.swift in Sources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; + C34E491D2B6A026F00FCB841 /* Sources */ = { + isa = PBXSourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + C34E49242B6A026F00FCB841 /* MNISTTool.swift in Sources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; + C38935C12B869C7A0037B833 /* Sources */ = { + isa = PBXSourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + C38935E12B869F420037B833 /* LLMModel.swift in Sources */, + C38935E32B86C0FE0037B833 /* Gemma.swift in Sources */, + C38935CD2B869C870037B833 /* Configuration.swift in Sources */, + C38935DF2B869DD00037B833 /* Phi.swift in Sources */, + C38935CE2B869C870037B833 /* Util.swift in Sources */, + C38935CC2B869C870037B833 /* Llama.swift in Sources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; + C39273702B606A0A00368D5D /* Sources */ = { + isa = PBXSourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + C392737D2B606A1D00368D5D /* Tutorial.swift in Sources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; + C397C5872B62C6A9004B084D /* Sources */ = { + isa = PBXSourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + C34E48F52B696F0B00FCB841 /* LLMTool.swift in Sources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXSourcesBuildPhase section */ + +/* Begin PBXTargetDependency section */ + C34E492D2B6A028800FCB841 /* PBXTargetDependency */ = { + isa = PBXTargetDependency; + target = C34E490C2B69A92900FCB841 /* MNIST */; + targetProxy = C34E492C2B6A028800FCB841 /* PBXContainerItemProxy */; + }; + C38935DA2B869CCD0037B833 /* PBXTargetDependency */ = { + isa = PBXTargetDependency; + target = C38935C42B869C7A0037B833 /* LLM */; + targetProxy = C38935D92B869CCD0037B833 /* PBXContainerItemProxy */; + }; +/* End PBXTargetDependency section */ + +/* Begin XCBuildConfiguration section */ + C3288D772B6D9313009FF608 /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + CODE_SIGN_STYLE = Automatic; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = dwarf; + ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_TESTABILITY = YES; + ENABLE_USER_SCRIPT_SANDBOXING = YES; + GCC_C_LANGUAGE_STANDARD = gnu17; + GCC_DYNAMIC_NO_PIC = NO; + GCC_NO_COMMON_BLOCKS = YES; + GCC_OPTIMIZATION_LEVEL = 0; + GCC_PREPROCESSOR_DEFINITIONS = ( + "DEBUG=1", + "$(inherited)", + ); + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + LOCALIZATION_PREFERS_STRING_CATALOGS = YES; + MACOSX_DEPLOYMENT_TARGET = 14.0; + MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; + MTL_FAST_MATH = YES; + PRODUCT_NAME = "$(TARGET_NAME)"; + SDKROOT = macosx; + SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)"; + SWIFT_OPTIMIZATION_LEVEL = "-Onone"; + SWIFT_VERSION = 5.0; + }; + name = Debug; + }; + C3288D782B6D9313009FF608 /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + CODE_SIGN_STYLE = Automatic; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; + ENABLE_NS_ASSERTIONS = NO; + ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_USER_SCRIPT_SANDBOXING = YES; + GCC_C_LANGUAGE_STANDARD = gnu17; + GCC_NO_COMMON_BLOCKS = YES; + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + LOCALIZATION_PREFERS_STRING_CATALOGS = YES; + MACOSX_DEPLOYMENT_TARGET = 14.0; + MTL_ENABLE_DEBUG_INFO = NO; + MTL_FAST_MATH = YES; + PRODUCT_NAME = "$(TARGET_NAME)"; + SDKROOT = macosx; + SWIFT_COMPILATION_MODE = wholemodule; + SWIFT_VERSION = 5.0; + }; + name = Release; + }; + C34E49122B69A92900FCB841 /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + CODE_SIGN_STYLE = Automatic; + COMBINE_HIDPI_IMAGES = YES; + COPY_PHASE_STRIP = NO; + CURRENT_PROJECT_VERSION = 1; + DEBUG_INFORMATION_FORMAT = dwarf; + DEFINES_MODULE = YES; + DYLIB_COMPATIBILITY_VERSION = 1; + DYLIB_CURRENT_VERSION = 1; + DYLIB_INSTALL_NAME_BASE = "@rpath"; + ENABLE_MODULE_VERIFIER = YES; + ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_TESTABILITY = YES; + ENABLE_USER_SCRIPT_SANDBOXING = YES; + GCC_C_LANGUAGE_STANDARD = gnu17; + GCC_DYNAMIC_NO_PIC = NO; + GCC_NO_COMMON_BLOCKS = YES; + GCC_OPTIMIZATION_LEVEL = 0; + GCC_PREPROCESSOR_DEFINITIONS = ( + "DEBUG=1", + "$(inherited)", + ); + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + GENERATE_INFOPLIST_FILE = YES; + INFOPLIST_KEY_NSHumanReadableCopyright = ""; + INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks"; + LD_RUNPATH_SEARCH_PATHS = ( + "$(inherited)", + "@executable_path/../Frameworks", + "@loader_path/Frameworks", + ); + LOCALIZATION_PREFERS_STRING_CATALOGS = YES; + MACOSX_DEPLOYMENT_TARGET = 14.0; + MARKETING_VERSION = 1.0; + MODULE_VERIFIER_SUPPORTED_LANGUAGES = "objective-c objective-c++"; + MODULE_VERIFIER_SUPPORTED_LANGUAGE_STANDARDS = "gnu17 gnu++20"; + MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; + MTL_FAST_MATH = YES; + PRODUCT_BUNDLE_IDENTIFIER = mlx.MNIST; + PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)"; + SDKROOT = macosx; + SKIP_INSTALL = YES; + SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx"; + SUPPORTS_MACCATALYST = NO; + SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)"; + SWIFT_EMIT_LOC_STRINGS = YES; + SWIFT_OPTIMIZATION_LEVEL = "-Onone"; + SWIFT_VERSION = 5.0; + TARGETED_DEVICE_FAMILY = "1,2"; + VERSIONING_SYSTEM = "apple-generic"; + VERSION_INFO_PREFIX = ""; + }; + name = Debug; + }; + C34E49132B69A92900FCB841 /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + CODE_SIGN_STYLE = Automatic; + COMBINE_HIDPI_IMAGES = YES; + COPY_PHASE_STRIP = NO; + CURRENT_PROJECT_VERSION = 1; + DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; + DEFINES_MODULE = YES; + DYLIB_COMPATIBILITY_VERSION = 1; + DYLIB_CURRENT_VERSION = 1; + DYLIB_INSTALL_NAME_BASE = "@rpath"; + ENABLE_MODULE_VERIFIER = YES; + ENABLE_NS_ASSERTIONS = NO; + ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_USER_SCRIPT_SANDBOXING = YES; + GCC_C_LANGUAGE_STANDARD = gnu17; + GCC_NO_COMMON_BLOCKS = YES; + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + GENERATE_INFOPLIST_FILE = YES; + INFOPLIST_KEY_NSHumanReadableCopyright = ""; + INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks"; + LD_RUNPATH_SEARCH_PATHS = ( + "$(inherited)", + "@executable_path/../Frameworks", + "@loader_path/Frameworks", + ); + LOCALIZATION_PREFERS_STRING_CATALOGS = YES; + MACOSX_DEPLOYMENT_TARGET = 14.0; + MARKETING_VERSION = 1.0; + MODULE_VERIFIER_SUPPORTED_LANGUAGES = "objective-c objective-c++"; + MODULE_VERIFIER_SUPPORTED_LANGUAGE_STANDARDS = "gnu17 gnu++20"; + MTL_ENABLE_DEBUG_INFO = NO; + MTL_FAST_MATH = YES; + PRODUCT_BUNDLE_IDENTIFIER = mlx.MNIST; + PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)"; + SDKROOT = macosx; + SKIP_INSTALL = YES; + SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx"; + SUPPORTS_MACCATALYST = NO; + SWIFT_COMPILATION_MODE = wholemodule; + SWIFT_EMIT_LOC_STRINGS = YES; + SWIFT_VERSION = 5.0; + TARGETED_DEVICE_FAMILY = "1,2"; + VERSIONING_SYSTEM = "apple-generic"; + VERSION_INFO_PREFIX = ""; + }; + name = Release; + }; + C34E49262B6A026F00FCB841 /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + CODE_SIGN_STYLE = Automatic; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = dwarf; + ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_TESTABILITY = YES; + ENABLE_USER_SCRIPT_SANDBOXING = YES; + GCC_C_LANGUAGE_STANDARD = gnu17; + GCC_DYNAMIC_NO_PIC = NO; + GCC_NO_COMMON_BLOCKS = YES; + GCC_OPTIMIZATION_LEVEL = 0; + GCC_PREPROCESSOR_DEFINITIONS = ( + "DEBUG=1", + "$(inherited)", + ); + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + LOCALIZATION_PREFERS_STRING_CATALOGS = YES; + MACOSX_DEPLOYMENT_TARGET = 14.0; + MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; + MTL_FAST_MATH = YES; + PRODUCT_NAME = "$(TARGET_NAME)"; + SDKROOT = macosx; + SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)"; + SWIFT_OPTIMIZATION_LEVEL = "-Onone"; + SWIFT_VERSION = 5.0; + }; + name = Debug; + }; + C34E49272B6A026F00FCB841 /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + CODE_SIGN_STYLE = Automatic; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; + ENABLE_NS_ASSERTIONS = NO; + ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_USER_SCRIPT_SANDBOXING = YES; + GCC_C_LANGUAGE_STANDARD = gnu17; + GCC_NO_COMMON_BLOCKS = YES; + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + LOCALIZATION_PREFERS_STRING_CATALOGS = YES; + MACOSX_DEPLOYMENT_TARGET = 14.0; + MTL_ENABLE_DEBUG_INFO = NO; + MTL_FAST_MATH = YES; + PRODUCT_NAME = "$(TARGET_NAME)"; + SDKROOT = macosx; + SWIFT_COMPILATION_MODE = wholemodule; + SWIFT_VERSION = 5.0; + }; + name = Release; + }; + C38935CA2B869C7A0037B833 /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + CODE_SIGN_STYLE = Automatic; + COPY_PHASE_STRIP = NO; + CURRENT_PROJECT_VERSION = 1; + DEBUG_INFORMATION_FORMAT = dwarf; + DEFINES_MODULE = YES; + DYLIB_COMPATIBILITY_VERSION = 1; + DYLIB_CURRENT_VERSION = 1; + DYLIB_INSTALL_NAME_BASE = "@rpath"; + ENABLE_MODULE_VERIFIER = YES; + ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_TESTABILITY = YES; + ENABLE_USER_SCRIPT_SANDBOXING = YES; + GCC_C_LANGUAGE_STANDARD = gnu17; + GCC_DYNAMIC_NO_PIC = NO; + GCC_NO_COMMON_BLOCKS = YES; + GCC_OPTIMIZATION_LEVEL = 0; + GCC_PREPROCESSOR_DEFINITIONS = ( + "DEBUG=1", + "$(inherited)", + ); + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + GENERATE_INFOPLIST_FILE = YES; + INFOPLIST_KEY_NSHumanReadableCopyright = ""; + INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks"; + IPHONEOS_DEPLOYMENT_TARGET = 17.0; + LD_RUNPATH_SEARCH_PATHS = ( + "@executable_path/Frameworks", + "@loader_path/Frameworks", + ); + "LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = ( + "@executable_path/../Frameworks", + "@loader_path/Frameworks", + ); + LOCALIZATION_PREFERS_STRING_CATALOGS = YES; + MACOSX_DEPLOYMENT_TARGET = 14.0; + MARKETING_VERSION = 1.0; + MODULE_VERIFIER_SUPPORTED_LANGUAGES = "objective-c objective-c++"; + MODULE_VERIFIER_SUPPORTED_LANGUAGE_STANDARDS = "gnu17 gnu++20"; + MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; + MTL_FAST_MATH = YES; + PRODUCT_BUNDLE_IDENTIFIER = mlx.LLM; + PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)"; + SDKROOT = auto; + SKIP_INSTALL = YES; + SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx"; + SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)"; + SWIFT_EMIT_LOC_STRINGS = YES; + SWIFT_OPTIMIZATION_LEVEL = "-Onone"; + SWIFT_VERSION = 5.0; + TARGETED_DEVICE_FAMILY = "1,2"; + VERSIONING_SYSTEM = "apple-generic"; + VERSION_INFO_PREFIX = ""; + }; + name = Debug; + }; + C38935CB2B869C7A0037B833 /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + CODE_SIGN_STYLE = Automatic; + COPY_PHASE_STRIP = NO; + CURRENT_PROJECT_VERSION = 1; + DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; + DEFINES_MODULE = YES; + DYLIB_COMPATIBILITY_VERSION = 1; + DYLIB_CURRENT_VERSION = 1; + DYLIB_INSTALL_NAME_BASE = "@rpath"; + ENABLE_MODULE_VERIFIER = YES; + ENABLE_NS_ASSERTIONS = NO; + ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_USER_SCRIPT_SANDBOXING = YES; + GCC_C_LANGUAGE_STANDARD = gnu17; + GCC_NO_COMMON_BLOCKS = YES; + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + GENERATE_INFOPLIST_FILE = YES; + INFOPLIST_KEY_NSHumanReadableCopyright = ""; + INSTALL_PATH = "$(LOCAL_LIBRARY_DIR)/Frameworks"; + IPHONEOS_DEPLOYMENT_TARGET = 17.0; + LD_RUNPATH_SEARCH_PATHS = ( + "@executable_path/Frameworks", + "@loader_path/Frameworks", + ); + "LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = ( + "@executable_path/../Frameworks", + "@loader_path/Frameworks", + ); + LOCALIZATION_PREFERS_STRING_CATALOGS = YES; + MACOSX_DEPLOYMENT_TARGET = 14.0; + MARKETING_VERSION = 1.0; + MODULE_VERIFIER_SUPPORTED_LANGUAGES = "objective-c objective-c++"; + MODULE_VERIFIER_SUPPORTED_LANGUAGE_STANDARDS = "gnu17 gnu++20"; + MTL_ENABLE_DEBUG_INFO = NO; + MTL_FAST_MATH = YES; + PRODUCT_BUNDLE_IDENTIFIER = mlx.LLM; + PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)"; + SDKROOT = auto; + SKIP_INSTALL = YES; + SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx"; + SWIFT_COMPILATION_MODE = wholemodule; + SWIFT_EMIT_LOC_STRINGS = YES; + SWIFT_VERSION = 5.0; + TARGETED_DEVICE_FAMILY = "1,2"; + VERSIONING_SYSTEM = "apple-generic"; + VERSION_INFO_PREFIX = ""; + }; + name = Release; + }; + C392736C2B60697700368D5D /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ARCHS = "$(ARCHS_STANDARD)"; + EXCLUDED_ARCHS = x86_64; + ONLY_ACTIVE_ARCH = YES; + }; + name = Debug; + }; + C392736D2B60697700368D5D /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ARCHS = "$(ARCHS_STANDARD)"; + EXCLUDED_ARCHS = x86_64; + ONLY_ACTIVE_ARCH = YES; + }; + name = Release; + }; + C392737A2B606A0A00368D5D /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + CODE_SIGN_STYLE = Automatic; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = dwarf; + ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_TESTABILITY = YES; + ENABLE_USER_SCRIPT_SANDBOXING = YES; + GCC_C_LANGUAGE_STANDARD = gnu17; + GCC_DYNAMIC_NO_PIC = NO; + GCC_NO_COMMON_BLOCKS = YES; + GCC_OPTIMIZATION_LEVEL = 0; + GCC_PREPROCESSOR_DEFINITIONS = ( + "DEBUG=1", + "$(inherited)", + ); + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + LOCALIZATION_PREFERS_STRING_CATALOGS = YES; + MACOSX_DEPLOYMENT_TARGET = 14.0; + MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; + MTL_FAST_MATH = YES; + ONLY_ACTIVE_ARCH = YES; + PRODUCT_NAME = "$(TARGET_NAME)"; + SDKROOT = macosx; + SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)"; + SWIFT_OPTIMIZATION_LEVEL = "-Onone"; + SWIFT_VERSION = 5.0; + }; + name = Debug; + }; + C392737B2B606A0A00368D5D /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + CODE_SIGN_STYLE = Automatic; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; + ENABLE_NS_ASSERTIONS = NO; + ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_USER_SCRIPT_SANDBOXING = YES; + GCC_C_LANGUAGE_STANDARD = gnu17; + GCC_NO_COMMON_BLOCKS = YES; + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + LOCALIZATION_PREFERS_STRING_CATALOGS = YES; + MACOSX_DEPLOYMENT_TARGET = 14.0; + MTL_ENABLE_DEBUG_INFO = NO; + MTL_FAST_MATH = YES; + PRODUCT_NAME = "$(TARGET_NAME)"; + SDKROOT = macosx; + SWIFT_COMPILATION_MODE = wholemodule; + SWIFT_VERSION = 5.0; + }; + name = Release; + }; + C397C5902B62C6A9004B084D /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + CODE_SIGN_STYLE = Automatic; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = dwarf; + ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_TESTABILITY = YES; + ENABLE_USER_SCRIPT_SANDBOXING = YES; + GCC_C_LANGUAGE_STANDARD = gnu17; + GCC_DYNAMIC_NO_PIC = NO; + GCC_NO_COMMON_BLOCKS = YES; + GCC_OPTIMIZATION_LEVEL = 0; + GCC_PREPROCESSOR_DEFINITIONS = ( + "DEBUG=1", + "$(inherited)", + ); + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + LOCALIZATION_PREFERS_STRING_CATALOGS = YES; + MACOSX_DEPLOYMENT_TARGET = 14.0; + MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; + MTL_FAST_MATH = YES; + PRODUCT_NAME = "$(TARGET_NAME)"; + SDKROOT = macosx; + SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)"; + SWIFT_OPTIMIZATION_LEVEL = "-Onone"; + SWIFT_VERSION = 5.0; + }; + name = Debug; + }; + C397C5912B62C6A9004B084D /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + CODE_SIGN_STYLE = Automatic; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; + ENABLE_NS_ASSERTIONS = NO; + ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_USER_SCRIPT_SANDBOXING = YES; + GCC_C_LANGUAGE_STANDARD = gnu17; + GCC_NO_COMMON_BLOCKS = YES; + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + LOCALIZATION_PREFERS_STRING_CATALOGS = YES; + MACOSX_DEPLOYMENT_TARGET = 14.0; + MTL_ENABLE_DEBUG_INFO = NO; + MTL_FAST_MATH = YES; + PRODUCT_NAME = "$(TARGET_NAME)"; + SDKROOT = macosx; + SWIFT_COMPILATION_MODE = wholemodule; + SWIFT_VERSION = 5.0; + }; + name = Release; + }; +/* End XCBuildConfiguration section */ + +/* Begin XCConfigurationList section */ + C3288D792B6D9313009FF608 /* Build configuration list for PBXNativeTarget "LinearModelTraining" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + C3288D772B6D9313009FF608 /* Debug */, + C3288D782B6D9313009FF608 /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; + C34E49112B69A92900FCB841 /* Build configuration list for PBXNativeTarget "MNIST" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + C34E49122B69A92900FCB841 /* Debug */, + C34E49132B69A92900FCB841 /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; + C34E49252B6A026F00FCB841 /* Build configuration list for PBXNativeTarget "mnist-tool" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + C34E49262B6A026F00FCB841 /* Debug */, + C34E49272B6A026F00FCB841 /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; + C38935C92B869C7A0037B833 /* Build configuration list for PBXNativeTarget "LLM" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + C38935CA2B869C7A0037B833 /* Debug */, + C38935CB2B869C7A0037B833 /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; + C392736B2B60697700368D5D /* Build configuration list for PBXProject "mlx-swift-examples" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + C392736C2B60697700368D5D /* Debug */, + C392736D2B60697700368D5D /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; + C39273792B606A0A00368D5D /* Build configuration list for PBXNativeTarget "Tutorial" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + C392737A2B606A0A00368D5D /* Debug */, + C392737B2B606A0A00368D5D /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; + C397C58F2B62C6A9004B084D /* Build configuration list for PBXNativeTarget "llm-tool" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + C397C5902B62C6A9004B084D /* Debug */, + C397C5912B62C6A9004B084D /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; +/* End XCConfigurationList section */ + +/* Begin XCRemoteSwiftPackageReference section */ + C34E491A2B69C43600FCB841 /* XCRemoteSwiftPackageReference "GzipSwift" */ = { + isa = XCRemoteSwiftPackageReference; + repositoryURL = "https://github.com/1024jp/GzipSwift"; + requirement = { + kind = upToNextMajorVersion; + minimumVersion = 6.0.1; + }; + }; + C382DE882B630889000F8F03 /* XCRemoteSwiftPackageReference "swift-async-algorithms" */ = { + isa = XCRemoteSwiftPackageReference; + repositoryURL = "https://github.com/apple/swift-async-algorithms"; + requirement = { + kind = upToNextMajorVersion; + minimumVersion = 1.0.0; + }; + }; + C38935BB2B866BFA0037B833 /* XCRemoteSwiftPackageReference "swift-transformers" */ = { + isa = XCRemoteSwiftPackageReference; + repositoryURL = "https://github.com/huggingface/swift-transformers"; + requirement = { + kind = upToNextMajorVersion; + minimumVersion = 0.1.2; + }; + }; + C392736E2B60699100368D5D /* XCRemoteSwiftPackageReference "swift-argument-parser" */ = { + isa = XCRemoteSwiftPackageReference; + repositoryURL = "https://github.com/apple/swift-argument-parser.git"; + requirement = { + kind = upToNextMajorVersion; + minimumVersion = 1.3.0; + }; + }; + C3FBCB1F2B8520B00007E490 /* XCRemoteSwiftPackageReference "mlx-swift" */ = { + isa = XCRemoteSwiftPackageReference; + repositoryURL = "https://github.com/ml-explore/mlx-swift"; + requirement = { + branch = main; + kind = branch; + }; + }; +/* End XCRemoteSwiftPackageReference section */ + +/* Begin XCSwiftPackageProductDependency section */ + C3288D7A2B6D9339009FF608 /* ArgumentParser */ = { + isa = XCSwiftPackageProductDependency; + package = C392736E2B60699100368D5D /* XCRemoteSwiftPackageReference "swift-argument-parser" */; + productName = ArgumentParser; + }; + C34E491B2B69C43600FCB841 /* Gzip */ = { + isa = XCSwiftPackageProductDependency; + package = C34E491A2B69C43600FCB841 /* XCRemoteSwiftPackageReference "GzipSwift" */; + productName = Gzip; + }; + C34E49282B6A028100FCB841 /* ArgumentParser */ = { + isa = XCSwiftPackageProductDependency; + package = C392736E2B60699100368D5D /* XCRemoteSwiftPackageReference "swift-argument-parser" */; + productName = ArgumentParser; + }; + C382DE892B630889000F8F03 /* AsyncAlgorithms */ = { + isa = XCSwiftPackageProductDependency; + package = C382DE882B630889000F8F03 /* XCRemoteSwiftPackageReference "swift-async-algorithms" */; + productName = AsyncAlgorithms; + }; + C38935CF2B869CC40037B833 /* MLX */ = { + isa = XCSwiftPackageProductDependency; + package = C3FBCB1F2B8520B00007E490 /* XCRemoteSwiftPackageReference "mlx-swift" */; + productName = MLX; + }; + C38935D12B869CC40037B833 /* MLXNN */ = { + isa = XCSwiftPackageProductDependency; + package = C3FBCB1F2B8520B00007E490 /* XCRemoteSwiftPackageReference "mlx-swift" */; + productName = MLXNN; + }; + C38935D32B869CC40037B833 /* MLXRandom */ = { + isa = XCSwiftPackageProductDependency; + package = C3FBCB1F2B8520B00007E490 /* XCRemoteSwiftPackageReference "mlx-swift" */; + productName = MLXRandom; + }; + C38935D52B869CC40037B833 /* Transformers */ = { + isa = XCSwiftPackageProductDependency; + package = C38935BB2B866BFA0037B833 /* XCRemoteSwiftPackageReference "swift-transformers" */; + productName = Transformers; + }; + C38935DC2B869CEC0037B833 /* AsyncAlgorithms */ = { + isa = XCSwiftPackageProductDependency; + package = C382DE882B630889000F8F03 /* XCRemoteSwiftPackageReference "swift-async-algorithms" */; + productName = AsyncAlgorithms; + }; + C397C59B2B62C6D0004B084D /* ArgumentParser */ = { + isa = XCSwiftPackageProductDependency; + package = C392736E2B60699100368D5D /* XCRemoteSwiftPackageReference "swift-argument-parser" */; + productName = ArgumentParser; + }; + C3FBCB202B8520B80007E490 /* MLX */ = { + isa = XCSwiftPackageProductDependency; + package = C3FBCB1F2B8520B00007E490 /* XCRemoteSwiftPackageReference "mlx-swift" */; + productName = MLX; + }; + C3FBCB282B8520DA0007E490 /* MLX */ = { + isa = XCSwiftPackageProductDependency; + package = C3FBCB1F2B8520B00007E490 /* XCRemoteSwiftPackageReference "mlx-swift" */; + productName = MLX; + }; + C3FBCB2A2B8520DA0007E490 /* MLXNN */ = { + isa = XCSwiftPackageProductDependency; + package = C3FBCB1F2B8520B00007E490 /* XCRemoteSwiftPackageReference "mlx-swift" */; + productName = MLXNN; + }; + C3FBCB2C2B8520E80007E490 /* MLXOptimizers */ = { + isa = XCSwiftPackageProductDependency; + package = C3FBCB1F2B8520B00007E490 /* XCRemoteSwiftPackageReference "mlx-swift" */; + productName = MLXOptimizers; + }; + C3FBCB2E2B8520F20007E490 /* MLX */ = { + isa = XCSwiftPackageProductDependency; + package = C3FBCB1F2B8520B00007E490 /* XCRemoteSwiftPackageReference "mlx-swift" */; + productName = MLX; + }; + C3FBCB302B8520F20007E490 /* MLXNN */ = { + isa = XCSwiftPackageProductDependency; + package = C3FBCB1F2B8520B00007E490 /* XCRemoteSwiftPackageReference "mlx-swift" */; + productName = MLXNN; + }; + C3FBCB322B8520F20007E490 /* MLXOptimizers */ = { + isa = XCSwiftPackageProductDependency; + package = C3FBCB1F2B8520B00007E490 /* XCRemoteSwiftPackageReference "mlx-swift" */; + productName = MLXOptimizers; + }; + C3FBCB342B8520F20007E490 /* MLXRandom */ = { + isa = XCSwiftPackageProductDependency; + package = C3FBCB1F2B8520B00007E490 /* XCRemoteSwiftPackageReference "mlx-swift" */; + productName = MLXRandom; + }; +/* End XCSwiftPackageProductDependency section */ + }; + rootObject = C39273682B60697700368D5D /* Project object */; +} diff --git a/mlx-swift-examples.xcodeproj/project.xcworkspace/contents.xcworkspacedata b/mlx-swift-examples.xcodeproj/project.xcworkspace/contents.xcworkspacedata new file mode 100644 index 0000000..919434a --- /dev/null +++ b/mlx-swift-examples.xcodeproj/project.xcworkspace/contents.xcworkspacedata @@ -0,0 +1,7 @@ + + + + + diff --git a/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist b/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist new file mode 100644 index 0000000..18d9810 --- /dev/null +++ b/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist @@ -0,0 +1,8 @@ + + + + + IDEDidComputeMac32BitWarning + + + diff --git a/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved b/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved new file mode 100644 index 0000000..f436e10 --- /dev/null +++ b/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved @@ -0,0 +1,68 @@ +{ + "pins" : [ + { + "identity" : "gzipswift", + "kind" : "remoteSourceControl", + "location" : "https://github.com/1024jp/GzipSwift", + "state" : { + "revision" : "731037f6cc2be2ec01562f6597c1d0aa3fe6fd05", + "version" : "6.0.1" + } + }, + { + "identity" : "mlx-swift", + "kind" : "remoteSourceControl", + "location" : "https://github.com/ml-explore/mlx-swift", + "state" : { + "branch" : "main", + "revision" : "cadf5f8187ac0894e66cd288217e2eda9f2c933d" + } + }, + { + "identity" : "swift-argument-parser", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-argument-parser.git", + "state" : { + "revision" : "c8ed701b513cf5177118a175d85fbbbcd707ab41", + "version" : "1.3.0" + } + }, + { + "identity" : "swift-async-algorithms", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-async-algorithms", + "state" : { + "revision" : "da4e36f86544cdf733a40d59b3a2267e3a7bbf36", + "version" : "1.0.0" + } + }, + { + "identity" : "swift-collections", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-collections.git", + "state" : { + "revision" : "d029d9d39c87bed85b1c50adee7c41795261a192", + "version" : "1.0.6" + } + }, + { + "identity" : "swift-numerics", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-numerics", + "state" : { + "revision" : "0a5bc04095a675662cf24757cc0640aa2204253b", + "version" : "1.0.2" + } + }, + { + "identity" : "swift-transformers", + "kind" : "remoteSourceControl", + "location" : "https://github.com/huggingface/swift-transformers", + "state" : { + "revision" : "564442fba36b0b694d730a62d0593e5f54043b55", + "version" : "0.1.2" + } + } + ], + "version" : 2 +}