diff --git a/Libraries/LLM/Configuration.swift b/Libraries/LLM/Configuration.swift
index dd0519e..258cb6d 100644
--- a/Libraries/LLM/Configuration.swift
+++ b/Libraries/LLM/Configuration.swift
@@ -58,7 +58,7 @@ public enum ModelType: String, Codable {
             return GemmaModel(configuration)
         case .gemma2:
             let configuration = try JSONDecoder().decode(
-                GemmaConfiguration.self, from: Data(contentsOf: configuration))
+                Gemma2Configuration.self, from: Data(contentsOf: configuration))
             return Gemma2Model(configuration)
         case .qwen2:
             let configuration = try JSONDecoder().decode(
diff --git a/Libraries/LLM/Gemma.swift b/Libraries/LLM/Gemma.swift
index 8d91bb2..14a96b1 100644
--- a/Libraries/LLM/Gemma.swift
+++ b/Libraries/LLM/Gemma.swift
@@ -262,111 +262,3 @@ extension GemmaModel: LoRAModel {
         model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
     }
 }
-
-// Gemma 2
-
-// Port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/gemma2.py
-
-// Minimal changes from Gemma TransformerBlock
-private class Gemma2TransformerBlock: Module {
-
-    @ModuleInfo(key: "self_attn") var attention: Attention
-    let mlp: MLP
-
-    @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
-    @ModuleInfo(key: "pre_feedforward_layernorm") var preFeedforwardLayerNorm: RMSNorm
-    @ModuleInfo(key: "post_feedforward_layernorm") var postFeedforwardLayerNorm: RMSNorm
-    @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm
-
-    public init(_ args: GemmaConfiguration) {
-        self._attention.wrappedValue = Attention(args)
-        self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize)
-        self._inputLayerNorm.wrappedValue = RMSNorm(
-            dimensions: args.hiddenSize, eps: args.rmsNormEps)
-        self._preFeedforwardLayerNorm.wrappedValue = RMSNorm(
-            dimensions: args.hiddenSize, eps: args.rmsNormEps)
-        self._postFeedforwardLayerNorm.wrappedValue = RMSNorm(
-            dimensions: args.hiddenSize, eps: args.rmsNormEps)
-        self._postAttentionLayerNorm.wrappedValue = RMSNorm(
-            dimensions: args.hiddenSize, eps: args.rmsNormEps)
-    }
-
-    public func callAsFunction(
-        _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
-    ) -> (MLXArray, (MLXArray, MLXArray)) {
-        var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache)
-        let h = x + postAttentionLayerNorm(r)
-        r = mlp(preFeedforwardLayerNorm(h))
-        let out = h + postFeedforwardLayerNorm(r)
-        return (out, cache)
-    }
-}
-
-// Uses Gemma2TransformerBlock, otherwise same as GemmaModelInner
-public class Gemma2ModelInner: Module {
-
-    @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
-
-    fileprivate let layers: [Gemma2TransformerBlock]
-    fileprivate let norm: RMSNorm
-
-    let hiddenScale: Float
-
-    public init(_ args: GemmaConfiguration) {
-        precondition(args.vocabularySize > 0)
-
-        self._embedTokens.wrappedValue = Embedding(
-            embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)
-
-        self.hiddenScale = pow(Float(args.hiddenSize), 0.5)
-
-        self.layers = (0 ..< args.hiddenLayers)
-            .map { _ in
-                Gemma2TransformerBlock(args)
-            }
-        self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
-    }
-
-    public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
-        MLXArray, [(MLXArray, MLXArray)]
-    ) {
-        var h = embedTokens(inputs)
-        h = h * hiddenScale
-
-        var mask: MLXArray? = nil
-        if h.dim(1) > 1 {
-            mask = MultiHeadAttention.createAdditiveCausalMask(h.dim(1))
-            mask = mask?.asType(h.dtype)
-        }
-
-        var newCache = [(MLXArray, MLXArray)]()
-
-        for (i, layer) in layers.enumerated() {
-            var cacheUpdate: (MLXArray, MLXArray)
-            (h, cacheUpdate) = layer(h, mask: mask, cache: cache?[i])
-            newCache.append(cacheUpdate)
-        }
-
-        return (norm(h), newCache)
-    }
-}
-
-// Uses Gemma2ModelInner, otherwise same as GemmaModel
-public class Gemma2Model: Module, LLMModel {
-
-    public let vocabularySize: Int
-    let model: Gemma2ModelInner
-
-    public init(_ args: GemmaConfiguration) {
-        self.vocabularySize = args.vocabularySize
-        self.model = Gemma2ModelInner(args)
-    }
-
-    public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
-        MLXArray, [(MLXArray, MLXArray)]
-    ) {
-        var (out, cache) = model(inputs, cache: cache)
-        out = model.embedTokens.asLinear(out)
-        return (out, cache)
-    }
-}
diff --git a/Libraries/LLM/Gemma2.swift b/Libraries/LLM/Gemma2.swift
new file mode 100644
index 0000000..080cb99
--- /dev/null
+++ b/Libraries/LLM/Gemma2.swift
@@ -0,0 +1,309 @@
+// Copyright © 2024 Apple Inc.
+
+import Foundation
+import MLX
+import MLXFast
+import MLXNN
+
+// Port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/gemma2.py
+
+// specialized norm for gemma
+private class RMSNorm: Module, UnaryLayer {
+    let weight: MLXArray
+    let eps: Float
+
+    public init(dimensions: Int, eps: Float = 1e-5) {
+        self.weight = MLXArray.ones([dimensions])
+        self.eps = eps
+        super.init()
+    }
+
+    public func callAsFunction(_ x: MLXArray) -> MLXArray {
+        return MLXFast.rmsNorm(x, weight: 1.0 + self.weight, eps: self.eps)
+    }
+}
+
+private class Attention: Module {
+
+    let args: Gemma2Configuration
+    let scale: Float
+    let logitSoftCap: Float
+    let headDim: Int
+
+    @ModuleInfo(key: "q_proj") var wq: Linear
+    @ModuleInfo(key: "k_proj") var wk: Linear
+    @ModuleInfo(key: "v_proj") var wv: Linear
+    @ModuleInfo(key: "o_proj") var wo: Linear
+
+    let rope: RoPE
+
+    public init(_ args: Gemma2Configuration) {
+        self.args = args
+
+        let dim = args.hiddenSize
+        let heads = args.attentionHeads
+        let kvHeads = args.kvHeads
+
+        let headDim = args.headDimensions
+        self.headDim = headDim
+        self.scale = pow(Float(args.queryPreAttnScalar), -0.5)
+        self.logitSoftCap = args.attnLogitSoftcapping
+
+        self._wq.wrappedValue = Linear(dim, heads * headDim, bias: false)
+        self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: false)
+        self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: false)
+        self._wo.wrappedValue = Linear(heads * headDim, dim, bias: false)
+
+        self.rope = RoPE(
+            dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta)
+    }
+
+    public func callAsFunction(
+        _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
+    ) -> (MLXArray, (MLXArray, MLXArray)) {
+        let (B, L) = (x.dim(0), x.dim(1))
+
+        var queries = wq(x)
+        var keys = wk(x)
+        var values = wv(x)
+
+        // prepare the queries, keys and values for the attention computation
+        queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3)
+        keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
+        values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
+
+        if let (keyCache, valueCache) = cache {
+            queries = rope(queries, offset: keyCache.dim(2))
+            keys = rope(keys, offset: keyCache.dim(2))
+            keys = concatenated([keyCache, keys], axis: 2)
+            values = concatenated([valueCache, values], axis: 2)
+        } else {
+            queries = rope(queries)
+            keys = rope(keys)
+        }
+
+        let newCache = (keys, values)
+
+        let repeats = self.args.attentionHeads / self.args.kvHeads
+        if repeats > 1 {
+            queries = queries.reshaped(
+                [B, self.args.kvHeads, repeats, L, self.headDim]
+            )
+            keys = expandedDimensions(keys, axes: [2])
+            values = expandedDimensions(values, axes: [2])
+        }
+
+        // scale the queries (queryPreAttnScalar^-0.5) before the scores, as in the Python port
+        var scores = matmul(queries * self.scale, keys.swappedAxes(-1, -2))
+        scores = tanh(scores / self.logitSoftCap) * self.logitSoftCap
+
+        if mask != nil {
+            scores = scores + mask!
+        }
+        scores = softmax(scores, axis: -1, precise: true)
+        var output = matmul(scores, values)
+        if repeats > 1 {
+            output = output.reshaped([B, self.args.attentionHeads, L, self.headDim])
+        }
+        output = output.transposed(0, 2, 1, 3).reshaped(B, L, -1)
+        return (wo(output), newCache)
+    }
+}
+
+private class MLP: Module, UnaryLayer {
+
+    @ModuleInfo(key: "gate_proj") var gate: Linear
+    @ModuleInfo(key: "down_proj") var down: Linear
+    @ModuleInfo(key: "up_proj") var up: Linear
+
+    public init(dimensions: Int, hiddenDimensions: Int) {
+        self._gate.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
+        self._down.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false)
+        self._up.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
+    }
+
+    public func callAsFunction(_ x: MLXArray) -> MLXArray {
+        down(gelu(gate(x)) * up(x))
+    }
+}
+
+// Minimal changes from Gemma TransformerBlock
+private class TransformerBlock: Module {
+
+    @ModuleInfo(key: "self_attn") var attention: Attention
+    let mlp: MLP
+
+    @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
+    @ModuleInfo(key: "pre_feedforward_layernorm") var preFeedforwardLayerNorm: RMSNorm
+    @ModuleInfo(key: "post_feedforward_layernorm") var postFeedforwardLayerNorm: RMSNorm
+    @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm
+
+    public init(_ args: Gemma2Configuration) {
+        self._attention.wrappedValue = Attention(args)
+        self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize)
+        self._inputLayerNorm.wrappedValue = RMSNorm(
+            dimensions: args.hiddenSize, eps: args.rmsNormEps)
+        self._preFeedforwardLayerNorm.wrappedValue = RMSNorm(
+            dimensions: args.hiddenSize, eps: args.rmsNormEps)
+        self._postFeedforwardLayerNorm.wrappedValue = RMSNorm(
+            dimensions: args.hiddenSize, eps: args.rmsNormEps)
+        self._postAttentionLayerNorm.wrappedValue = RMSNorm(
+            dimensions: args.hiddenSize, eps: args.rmsNormEps)
+    }
+
+    public func callAsFunction(
+        _ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
+    ) -> (MLXArray, (MLXArray, MLXArray)) {
+        var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache)
+        let h = x + postAttentionLayerNorm(r)
+        r = mlp(preFeedforwardLayerNorm(h))
+        let out = h + postFeedforwardLayerNorm(r)
+        return (out, cache)
+    }
+}
+
+// Uses the Gemma 2 TransformerBlock above, otherwise same as GemmaModelInner
+public class ModelInner: Module {
+
+    @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
+
+    fileprivate let layers: [TransformerBlock]
+    fileprivate let norm: RMSNorm
+
+    let hiddenScale: Float
+
+    public init(_ args: Gemma2Configuration) {
+        precondition(args.vocabularySize > 0)
+
+        self._embedTokens.wrappedValue = Embedding(
+            embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)
+
+        self.hiddenScale = pow(Float(args.hiddenSize), 0.5)
+
+        self.layers = (0 ..< args.hiddenLayers)
+            .map { _ in
+                TransformerBlock(args)
+            }
+        self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
+    }
+
+    public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
+        MLXArray, [(MLXArray, MLXArray)]
+    ) {
+        var h = embedTokens(inputs)
+        h = h * hiddenScale
+
+        var mask: MLXArray? = nil
+        if h.dim(1) > 1 {
+            mask = MultiHeadAttention.createAdditiveCausalMask(h.dim(1))
+            mask = mask?.asType(h.dtype)
+        }
+
+        var newCache = [(MLXArray, MLXArray)]()
+
+        for (i, layer) in layers.enumerated() {
+            var cacheUpdate: (MLXArray, MLXArray)
+            (h, cacheUpdate) = layer(h, mask: mask, cache: cache?[i])
+            newCache.append(cacheUpdate)
+        }
+
+        return (norm(h), newCache)
+    }
+}
+
+// Uses ModelInner, otherwise same as GemmaModel plus final logit soft-capping
+public class Gemma2Model: Module, LLMModel {
+
+    public let vocabularySize: Int
+    let model: ModelInner
+    let logitSoftCap: Float
+
+    public init(_ args: Gemma2Configuration) {
+        self.vocabularySize = args.vocabularySize
+        self.model = ModelInner(args)
+        self.logitSoftCap = args.finalLogitSoftcapping
+    }
+
+    public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
+        MLXArray, [(MLXArray, MLXArray)]
+    ) {
+        var (out, cache) = model(inputs, cache: cache)
+        out = model.embedTokens.asLinear(out)
+        out = tanh(out / self.logitSoftCap) * self.logitSoftCap
+        return (out, cache)
+    }
+}
+
+public struct Gemma2Configuration: Codable {
+
+    var hiddenSize: Int
+    var hiddenLayers: Int
+    var intermediateSize: Int
+    var attentionHeads: Int
+    var headDimensions: Int
+    var rmsNormEps: Float
+    var vocabularySize: Int
+    var kvHeads: Int
+    var ropeTheta: Float = 10_000
+    var ropeTraditional: Bool = false
+    var attnLogitSoftcapping: Float = 50.0
+    var finalLogitSoftcapping: Float = 30.0
+    var queryPreAttnScalar: Int = 256
+
+    enum CodingKeys: String, CodingKey {
+        case hiddenSize = "hidden_size"
+        case hiddenLayers = "num_hidden_layers"
+        case intermediateSize = "intermediate_size"
+        case attentionHeads = "num_attention_heads"
+        case headDimensions = "head_dim"
+        case rmsNormEps = "rms_norm_eps"
+        case vocabularySize = "vocab_size"
+        case kvHeads = "num_key_value_heads"
+        case ropeTheta = "rope_theta"
+        case ropeTraditional = "rope_traditional"
+        case attnLogitSoftcapping = "attn_logit_softcapping"
+        case finalLogitSoftcapping = "final_logit_softcapping"
+        case queryPreAttnScalar = "query_pre_attn_scalar"
+    }
+
+    public init(from decoder: Decoder) throws {
+        // custom implementation to handle optional keys with required values
+        let container: KeyedDecodingContainer<CodingKeys> = try decoder.container(
+            keyedBy: CodingKeys.self)
+
+        self.hiddenSize = try container.decode(
+            Int.self, forKey: CodingKeys.hiddenSize)
+        self.hiddenLayers = try container.decode(
+            Int.self, forKey: CodingKeys.hiddenLayers)
+        self.intermediateSize = try container.decode(
+            Int.self, forKey: CodingKeys.intermediateSize)
+        self.attentionHeads = try container.decode(
+            Int.self, forKey: CodingKeys.attentionHeads)
+        self.headDimensions = try container.decode(
+            Int.self, forKey: CodingKeys.headDimensions)
+        self.rmsNormEps = try container.decode(
+            Float.self, forKey: CodingKeys.rmsNormEps)
+        self.vocabularySize = try container.decode(
+            Int.self, forKey: CodingKeys.vocabularySize)
+        self.kvHeads = try container.decode(Int.self, forKey: CodingKeys.kvHeads)
+        self.ropeTheta =
+            try container.decodeIfPresent(Float.self, forKey: CodingKeys.ropeTheta)
+            ?? 10_000
+        self.ropeTraditional =
+            try container.decodeIfPresent(
+                Bool.self, forKey: CodingKeys.ropeTraditional) ?? false
+        self.attnLogitSoftcapping = try container.decode(
+            Float.self, forKey: CodingKeys.attnLogitSoftcapping)
+        self.finalLogitSoftcapping = try container.decode(
+            Float.self, forKey: CodingKeys.finalLogitSoftcapping)
+        self.queryPreAttnScalar = try container.decode(
+            Int.self, forKey: CodingKeys.queryPreAttnScalar)
+    }
+}
+
+// MARK: - LoRA
+
+extension Gemma2Model: LoRAModel {
+    public func loraLinearLayers() -> LoRALinearLayers {
+        model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
+    }
+}
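Usage note (outside the patch): a minimal sketch of how the new types in `Gemma2.swift` fit together — decoding a `Gemma2Configuration` from a Hugging Face style `config.json` and running a prompt pass plus one cached decode step through `Gemma2Model`. Weight loading, tokenization, and sampling are omitted, the module name in the import is assumed, and the token ids are placeholders; this is not the library's actual loading path.

```swift
import Foundation
import MLX
// import LLM  // assumed name of the library target that contains Gemma2.swift

func gemma2ForwardSketch(configURL: URL) throws {
    // Decode the hyperparameters; rope_theta / rope_traditional fall back to defaults if absent.
    let config = try JSONDecoder().decode(
        Gemma2Configuration.self, from: Data(contentsOf: configURL))

    // Freshly initialized (random) weights; real use would load and apply checkpoint tensors.
    let model = Gemma2Model(config)

    // Prompt evaluation: cache is nil, so ModelInner builds an additive causal mask internally.
    let prompt = MLXArray(0 ..< 8, [1, 8])  // placeholder token ids, shape [batch, length]
    var (logits, cache) = model(prompt, cache: nil)

    // One greedy decode step: a single token plus the per-layer (keys, values) cache.
    let next = argMax(logits[0, -1], axis: -1)
    (logits, cache) = model(next.reshaped(1, 1), cache: cache)

    print("vocab:", config.vocabularySize, "logits:", logits.shape, "cached layers:", cache.count)
}
```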
diff --git a/Libraries/LLM/Models.swift b/Libraries/LLM/Models.swift
index 9a9b7ba..62581d4 100644
--- a/Libraries/LLM/Models.swift
+++ b/Libraries/LLM/Models.swift
@@ -168,7 +168,7 @@ extension ModelConfiguration {
         defaultPrompt: "what is the difference between lettuce and cabbage?"
 
     ) { prompt in
-        "<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
+        "<start_of_turn>user\n\(prompt)<end_of_turn>\n<start_of_turn>model\n"
     }
 
     public static let gemma_2_9b_it_4bit = ModelConfiguration(
@@ -178,6 +178,17 @@ extension ModelConfiguration {
         id: "mlx-community/gemma-2-9b-it-4bit",
         overrideTokenizer: "PreTrainedTokenizer",
 
         // https://www.promptingguide.ai/models/gemma
         defaultPrompt: "What is the difference between lettuce and cabbage?"
+    ) { prompt in
+        "<start_of_turn>user\n\(prompt)<end_of_turn>\n<start_of_turn>model\n"
+    }
+
+    public static let gemma_2_2b_it_4bit = ModelConfiguration(
+        id: "mlx-community/gemma-2-2b-it-4bit",
+        overrideTokenizer: "PreTrainedTokenizer",
+
+        // https://www.promptingguide.ai/models/gemma
+        defaultPrompt: "What is the difference between lettuce and cabbage?"
+    ) { prompt in
         "<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
     }
diff --git a/mlx-swift-examples.xcodeproj/project.pbxproj b/mlx-swift-examples.xcodeproj/project.pbxproj
index 8517ff5..12e79d9 100644
--- a/mlx-swift-examples.xcodeproj/project.pbxproj
+++ b/mlx-swift-examples.xcodeproj/project.pbxproj
@@ -8,6 +8,7 @@
 
 /* Begin PBXBuildFile section */
 		12305EAF2B9D864400C92FEE /* PredictionView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 12305EAE2B9D864400C92FEE /* PredictionView.swift */; };
+		1C55317A2C5AAB4E00B07ECD /* Gemma2.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1C5531792C5AAB4E00B07ECD /* Gemma2.swift */; };
 		1CD79C702BD80DE100B6C06F /* Phi3.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1CD79C6F2BD80DE100B6C06F /* Phi3.swift */; };
 		525C1E9D2B9A011000B5C356 /* Starcoder2.swift in Sources */ = {isa = PBXBuildFile; fileRef = 525C1E9C2B9A010F00B5C356 /* Starcoder2.swift */; };
 		52A776182B94B5EE00AA6E80 /* Qwen2.swift in Sources */ = {isa = PBXBuildFile; fileRef = 52A776172B94B5EE00AA6E80 /* Qwen2.swift */; };
@@ -218,6 +219,7 @@
 
 /* Begin PBXFileReference section */
 		12305EAE2B9D864400C92FEE /* PredictionView.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = PredictionView.swift; sourceTree = "<group>"; };
+		1C5531792C5AAB4E00B07ECD /* Gemma2.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Gemma2.swift; sourceTree = "<group>"; };
 		1CD79C6F2BD80DE100B6C06F /* Phi3.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Phi3.swift; sourceTree = "<group>"; };
 		525C1E9C2B9A010F00B5C356 /* Starcoder2.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Starcoder2.swift; sourceTree = "<group>"; };
 		52A776172B94B5EE00AA6E80 /* Qwen2.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Qwen2.swift; sourceTree = "<group>"; };
@@ -480,6 +482,7 @@
 				C3A8B3AB2B9283150002EFB8 /* Models.swift */,
 				C34E48EE2B696E6500FCB841 /* Llama.swift */,
 				C38935E22B86C0FE0037B833 /* Gemma.swift */,
+				1C5531792C5AAB4E00B07ECD /* Gemma2.swift */,
 				C38935C72B869C7A0037B833 /* LLM.h */,
 				C38935E02B869F420037B833 /* LLMModel.swift */,
 				C38935DE2B869DD00037B833 /* Phi.swift */,
@@ -1014,6 +1017,7 @@
 				C38935CE2B869C870037B833 /* Load.swift in Sources */,
 				C3E786AD2B8D4AF50004D037 /* Tokenizer.swift in Sources */,
 				C3A8B3AC2B9283150002EFB8 /* Models.swift in Sources */,
+				1C55317A2C5AAB4E00B07ECD /* Gemma2.swift in Sources */,
 				C3E786AB2B8D1AEC0004D037 /* Evaluate.swift in Sources */,
 				C38935CC2B869C870037B833 /* Llama.swift in Sources */,
 				52A776182B94B5EE00AA6E80 /* Qwen2.swift in Sources */,
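Aside (outside the patch): the two soft-capping steps this port implements — `attnLogitSoftcapping` (default 50) on attention scores and `finalLogitSoftcapping` (default 30) on the output logits — squash values smoothly into (-cap, cap) rather than hard-clipping them. A quick numeric illustration of the final-logit cap:

```swift
import Foundation

let cap = 30.0  // finalLogitSoftcapping default in Gemma2Configuration
for x in [-100.0, -5.0, 0.0, 5.0, 100.0] {
    // tanh(x / cap) * cap is ~linear near zero and saturates toward ±cap (100 maps to ~29.93)
    print(x, "->", tanh(x / cap) * cap)
}
```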