Some fixes for gemma2 (#99)
* some fixes for gemma2 * format * fixes * format
This commit is contained in:
@@ -58,7 +58,7 @@ public enum ModelType: String, Codable {
|
|||||||
return GemmaModel(configuration)
|
return GemmaModel(configuration)
|
||||||
case .gemma2:
|
case .gemma2:
|
||||||
let configuration = try JSONDecoder().decode(
|
let configuration = try JSONDecoder().decode(
|
||||||
GemmaConfiguration.self, from: Data(contentsOf: configuration))
|
Gemma2Configuration.self, from: Data(contentsOf: configuration))
|
||||||
return Gemma2Model(configuration)
|
return Gemma2Model(configuration)
|
||||||
case .qwen2:
|
case .qwen2:
|
||||||
let configuration = try JSONDecoder().decode(
|
let configuration = try JSONDecoder().decode(
|
||||||
|
|||||||
@@ -262,111 +262,3 @@ extension GemmaModel: LoRAModel {
|
|||||||
model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
|
model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Gemma 2
|
|
||||||
|
|
||||||
// Port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/gemma2.py
|
|
||||||
|
|
||||||
// Minimal changes from Gemma TransformerBlock
|
|
||||||
private class Gemma2TransformerBlock: Module {
|
|
||||||
|
|
||||||
@ModuleInfo(key: "self_attn") var attention: Attention
|
|
||||||
let mlp: MLP
|
|
||||||
|
|
||||||
@ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
|
|
||||||
@ModuleInfo(key: "pre_feedforward_layernorm") var preFeedforwardLayerNorm: RMSNorm
|
|
||||||
@ModuleInfo(key: "post_feedforward_layernorm") var postFeedforwardLayerNorm: RMSNorm
|
|
||||||
@ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm
|
|
||||||
|
|
||||||
public init(_ args: GemmaConfiguration) {
|
|
||||||
self._attention.wrappedValue = Attention(args)
|
|
||||||
self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize)
|
|
||||||
self._inputLayerNorm.wrappedValue = RMSNorm(
|
|
||||||
dimensions: args.hiddenSize, eps: args.rmsNormEps)
|
|
||||||
self._preFeedforwardLayerNorm.wrappedValue = RMSNorm(
|
|
||||||
dimensions: args.hiddenSize, eps: args.rmsNormEps)
|
|
||||||
self._postFeedforwardLayerNorm.wrappedValue = RMSNorm(
|
|
||||||
dimensions: args.hiddenSize, eps: args.rmsNormEps)
|
|
||||||
self._postAttentionLayerNorm.wrappedValue = RMSNorm(
|
|
||||||
dimensions: args.hiddenSize, eps: args.rmsNormEps)
|
|
||||||
}
|
|
||||||
|
|
||||||
public func callAsFunction(
|
|
||||||
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
|
|
||||||
) -> (MLXArray, (MLXArray, MLXArray)) {
|
|
||||||
var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache)
|
|
||||||
let h = x + postAttentionLayerNorm(r)
|
|
||||||
r = mlp(preFeedforwardLayerNorm(h))
|
|
||||||
let out = h + postFeedforwardLayerNorm(r)
|
|
||||||
return (out, cache)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Uses Gemma2TransformerBlock, otherwise same as GemmaModelInner
|
|
||||||
public class Gemma2ModelInner: Module {
|
|
||||||
|
|
||||||
@ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
|
|
||||||
|
|
||||||
fileprivate let layers: [Gemma2TransformerBlock]
|
|
||||||
fileprivate let norm: RMSNorm
|
|
||||||
|
|
||||||
let hiddenScale: Float
|
|
||||||
|
|
||||||
public init(_ args: GemmaConfiguration) {
|
|
||||||
precondition(args.vocabularySize > 0)
|
|
||||||
|
|
||||||
self._embedTokens.wrappedValue = Embedding(
|
|
||||||
embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)
|
|
||||||
|
|
||||||
self.hiddenScale = pow(Float(args.hiddenSize), 0.5)
|
|
||||||
|
|
||||||
self.layers = (0 ..< args.hiddenLayers)
|
|
||||||
.map { _ in
|
|
||||||
Gemma2TransformerBlock(args)
|
|
||||||
}
|
|
||||||
self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
|
|
||||||
}
|
|
||||||
|
|
||||||
public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
|
|
||||||
MLXArray, [(MLXArray, MLXArray)]
|
|
||||||
) {
|
|
||||||
var h = embedTokens(inputs)
|
|
||||||
h = h * hiddenScale
|
|
||||||
|
|
||||||
var mask: MLXArray? = nil
|
|
||||||
if h.dim(1) > 1 {
|
|
||||||
mask = MultiHeadAttention.createAdditiveCausalMask(h.dim(1))
|
|
||||||
mask = mask?.asType(h.dtype)
|
|
||||||
}
|
|
||||||
|
|
||||||
var newCache = [(MLXArray, MLXArray)]()
|
|
||||||
|
|
||||||
for (i, layer) in layers.enumerated() {
|
|
||||||
var cacheUpdate: (MLXArray, MLXArray)
|
|
||||||
(h, cacheUpdate) = layer(h, mask: mask, cache: cache?[i])
|
|
||||||
newCache.append(cacheUpdate)
|
|
||||||
}
|
|
||||||
|
|
||||||
return (norm(h), newCache)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Uses Gemma2ModelInner, otherwise same as GemmaModel
|
|
||||||
public class Gemma2Model: Module, LLMModel {
|
|
||||||
|
|
||||||
public let vocabularySize: Int
|
|
||||||
let model: Gemma2ModelInner
|
|
||||||
|
|
||||||
public init(_ args: GemmaConfiguration) {
|
|
||||||
self.vocabularySize = args.vocabularySize
|
|
||||||
self.model = Gemma2ModelInner(args)
|
|
||||||
}
|
|
||||||
|
|
||||||
public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
|
|
||||||
MLXArray, [(MLXArray, MLXArray)]
|
|
||||||
) {
|
|
||||||
var (out, cache) = model(inputs, cache: cache)
|
|
||||||
out = model.embedTokens.asLinear(out)
|
|
||||||
return (out, cache)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
309
Libraries/LLM/Gemma2.swift
Normal file
309
Libraries/LLM/Gemma2.swift
Normal file
@@ -0,0 +1,309 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
import MLX
|
||||||
|
import MLXFast
|
||||||
|
import MLXNN
|
||||||
|
|
||||||
|
// Port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/gemma2.py
|
||||||
|
|
||||||
|
// specialized norm for gemma
|
||||||
|
private class RMSNorm: Module, UnaryLayer {
|
||||||
|
let weight: MLXArray
|
||||||
|
let eps: Float
|
||||||
|
|
||||||
|
public init(dimensions: Int, eps: Float = 1e-5) {
|
||||||
|
self.weight = MLXArray.ones([dimensions])
|
||||||
|
self.eps = eps
|
||||||
|
super.init()
|
||||||
|
}
|
||||||
|
|
||||||
|
public func callAsFunction(_ x: MLXArray) -> MLXArray {
|
||||||
|
return MLXFast.rmsNorm(x, weight: 1.0 + self.weight, eps: self.eps)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private class Attention: Module {
|
||||||
|
|
||||||
|
let args: Gemma2Configuration
|
||||||
|
let scale: Float
|
||||||
|
let logitSoftCap: Float
|
||||||
|
let headDim: Int
|
||||||
|
|
||||||
|
@ModuleInfo(key: "q_proj") var wq: Linear
|
||||||
|
@ModuleInfo(key: "k_proj") var wk: Linear
|
||||||
|
@ModuleInfo(key: "v_proj") var wv: Linear
|
||||||
|
@ModuleInfo(key: "o_proj") var wo: Linear
|
||||||
|
|
||||||
|
let rope: RoPE
|
||||||
|
|
||||||
|
public init(_ args: Gemma2Configuration) {
|
||||||
|
self.args = args
|
||||||
|
|
||||||
|
let dim = args.hiddenSize
|
||||||
|
let heads = args.attentionHeads
|
||||||
|
let kvHeads = args.kvHeads
|
||||||
|
|
||||||
|
let headDim = args.headDimensions
|
||||||
|
self.headDim = headDim
|
||||||
|
self.scale = pow(Float(args.queryPreAttnScalar), -0.5)
|
||||||
|
self.logitSoftCap = args.attnLogitSoftcapping
|
||||||
|
|
||||||
|
self._wq.wrappedValue = Linear(dim, heads * headDim, bias: false)
|
||||||
|
self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: false)
|
||||||
|
self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: false)
|
||||||
|
self._wo.wrappedValue = Linear(heads * headDim, dim, bias: false)
|
||||||
|
|
||||||
|
self.rope = RoPE(
|
||||||
|
dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta)
|
||||||
|
}
|
||||||
|
|
||||||
|
public func callAsFunction(
|
||||||
|
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
|
||||||
|
) -> (MLXArray, (MLXArray, MLXArray)) {
|
||||||
|
let (B, L) = (x.dim(0), x.dim(1))
|
||||||
|
|
||||||
|
var queries = wq(x)
|
||||||
|
var keys = wk(x)
|
||||||
|
var values = wv(x)
|
||||||
|
|
||||||
|
// prepare the queries, keys and values for the attention computation
|
||||||
|
queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3)
|
||||||
|
keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
|
||||||
|
values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
|
||||||
|
|
||||||
|
if let (keyCache, valueCache) = cache {
|
||||||
|
queries = rope(queries, offset: keyCache.dim(2))
|
||||||
|
keys = rope(keys, offset: keyCache.dim(2))
|
||||||
|
keys = concatenated([keyCache, keys], axis: 2)
|
||||||
|
values = concatenated([valueCache, values], axis: 2)
|
||||||
|
} else {
|
||||||
|
queries = rope(queries)
|
||||||
|
keys = rope(keys)
|
||||||
|
}
|
||||||
|
|
||||||
|
let newCache = (keys, values)
|
||||||
|
|
||||||
|
let repeats = self.args.attentionHeads / self.args.kvHeads
|
||||||
|
if repeats > 1 {
|
||||||
|
queries = queries.reshaped(
|
||||||
|
[B, self.args.kvHeads, repeats, L, self.headDim]
|
||||||
|
)
|
||||||
|
keys = expandedDimensions(keys, axes: [2])
|
||||||
|
values = expandedDimensions(values, axes: [2])
|
||||||
|
}
|
||||||
|
|
||||||
|
var scores = matmul(queries, keys.swappedAxes(-1, -2))
|
||||||
|
scores = tanh(scores / self.logitSoftCap) * self.logitSoftCap
|
||||||
|
|
||||||
|
if mask != nil {
|
||||||
|
scores = scores + mask!
|
||||||
|
}
|
||||||
|
scores = softmax(scores, axis: -1, precise: true)
|
||||||
|
var output = matmul(scores, values)
|
||||||
|
if repeats > 1 {
|
||||||
|
output = output.reshaped([B, self.args.attentionHeads, L, self.headDim])
|
||||||
|
}
|
||||||
|
output = output.transposed(0, 2, 1, 3).reshaped(B, L, -1)
|
||||||
|
return (wo(output), newCache)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private class MLP: Module, UnaryLayer {
|
||||||
|
|
||||||
|
@ModuleInfo(key: "gate_proj") var gate: Linear
|
||||||
|
@ModuleInfo(key: "down_proj") var down: Linear
|
||||||
|
@ModuleInfo(key: "up_proj") var up: Linear
|
||||||
|
|
||||||
|
public init(dimensions: Int, hiddenDimensions: Int) {
|
||||||
|
self._gate.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
|
||||||
|
self._down.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false)
|
||||||
|
self._up.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
|
||||||
|
}
|
||||||
|
|
||||||
|
public func callAsFunction(_ x: MLXArray) -> MLXArray {
|
||||||
|
down(gelu(gate(x)) * up(x))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Minimal changes from Gemma TransformerBlock
|
||||||
|
private class TransformerBlock: Module {
|
||||||
|
|
||||||
|
@ModuleInfo(key: "self_attn") var attention: Attention
|
||||||
|
let mlp: MLP
|
||||||
|
|
||||||
|
@ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
|
||||||
|
@ModuleInfo(key: "pre_feedforward_layernorm") var preFeedforwardLayerNorm: RMSNorm
|
||||||
|
@ModuleInfo(key: "post_feedforward_layernorm") var postFeedforwardLayerNorm: RMSNorm
|
||||||
|
@ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm
|
||||||
|
|
||||||
|
public init(_ args: Gemma2Configuration) {
|
||||||
|
self._attention.wrappedValue = Attention(args)
|
||||||
|
self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize)
|
||||||
|
self._inputLayerNorm.wrappedValue = RMSNorm(
|
||||||
|
dimensions: args.hiddenSize, eps: args.rmsNormEps)
|
||||||
|
self._preFeedforwardLayerNorm.wrappedValue = RMSNorm(
|
||||||
|
dimensions: args.hiddenSize, eps: args.rmsNormEps)
|
||||||
|
self._postFeedforwardLayerNorm.wrappedValue = RMSNorm(
|
||||||
|
dimensions: args.hiddenSize, eps: args.rmsNormEps)
|
||||||
|
self._postAttentionLayerNorm.wrappedValue = RMSNorm(
|
||||||
|
dimensions: args.hiddenSize, eps: args.rmsNormEps)
|
||||||
|
}
|
||||||
|
|
||||||
|
public func callAsFunction(
|
||||||
|
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
|
||||||
|
) -> (MLXArray, (MLXArray, MLXArray)) {
|
||||||
|
var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache)
|
||||||
|
let h = x + postAttentionLayerNorm(r)
|
||||||
|
r = mlp(preFeedforwardLayerNorm(h))
|
||||||
|
let out = h + postFeedforwardLayerNorm(r)
|
||||||
|
return (out, cache)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Uses Gemma2TransformerBlock, otherwise same as GemmaModelInner
|
||||||
|
public class ModelInner: Module {
|
||||||
|
|
||||||
|
@ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
|
||||||
|
|
||||||
|
fileprivate let layers: [TransformerBlock]
|
||||||
|
fileprivate let norm: RMSNorm
|
||||||
|
|
||||||
|
let hiddenScale: Float
|
||||||
|
|
||||||
|
public init(_ args: Gemma2Configuration) {
|
||||||
|
precondition(args.vocabularySize > 0)
|
||||||
|
|
||||||
|
self._embedTokens.wrappedValue = Embedding(
|
||||||
|
embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)
|
||||||
|
|
||||||
|
self.hiddenScale = pow(Float(args.hiddenSize), 0.5)
|
||||||
|
|
||||||
|
self.layers = (0 ..< args.hiddenLayers)
|
||||||
|
.map { _ in
|
||||||
|
TransformerBlock(args)
|
||||||
|
}
|
||||||
|
self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
|
||||||
|
}
|
||||||
|
|
||||||
|
public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
|
||||||
|
MLXArray, [(MLXArray, MLXArray)]
|
||||||
|
) {
|
||||||
|
var h = embedTokens(inputs)
|
||||||
|
h = h * hiddenScale
|
||||||
|
|
||||||
|
var mask: MLXArray? = nil
|
||||||
|
if h.dim(1) > 1 {
|
||||||
|
mask = MultiHeadAttention.createAdditiveCausalMask(h.dim(1))
|
||||||
|
mask = mask?.asType(h.dtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
var newCache = [(MLXArray, MLXArray)]()
|
||||||
|
|
||||||
|
for (i, layer) in layers.enumerated() {
|
||||||
|
var cacheUpdate: (MLXArray, MLXArray)
|
||||||
|
(h, cacheUpdate) = layer(h, mask: mask, cache: cache?[i])
|
||||||
|
newCache.append(cacheUpdate)
|
||||||
|
}
|
||||||
|
|
||||||
|
return (norm(h), newCache)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Uses Gemma2ModelInner, otherwise same as GemmaModel
|
||||||
|
public class Gemma2Model: Module, LLMModel {
|
||||||
|
|
||||||
|
public let vocabularySize: Int
|
||||||
|
let model: ModelInner
|
||||||
|
let logitSoftCap: Float
|
||||||
|
|
||||||
|
public init(_ args: Gemma2Configuration) {
|
||||||
|
self.vocabularySize = args.vocabularySize
|
||||||
|
self.model = ModelInner(args)
|
||||||
|
self.logitSoftCap = args.finalLogitSoftcapping
|
||||||
|
}
|
||||||
|
|
||||||
|
public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
|
||||||
|
MLXArray, [(MLXArray, MLXArray)]
|
||||||
|
) {
|
||||||
|
var (out, cache) = model(inputs, cache: cache)
|
||||||
|
out = model.embedTokens.asLinear(out)
|
||||||
|
out = tanh(out / self.logitSoftCap) * self.logitSoftCap
|
||||||
|
return (out, cache)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public struct Gemma2Configuration: Codable {
|
||||||
|
|
||||||
|
var hiddenSize: Int
|
||||||
|
var hiddenLayers: Int
|
||||||
|
var intermediateSize: Int
|
||||||
|
var attentionHeads: Int
|
||||||
|
var headDimensions: Int
|
||||||
|
var rmsNormEps: Float
|
||||||
|
var vocabularySize: Int
|
||||||
|
var kvHeads: Int
|
||||||
|
var ropeTheta: Float = 10_000
|
||||||
|
var ropeTraditional: Bool = false
|
||||||
|
var attnLogitSoftcapping: Float = 50.0
|
||||||
|
var finalLogitSoftcapping: Float = 30.0
|
||||||
|
var queryPreAttnScalar: Int = 256
|
||||||
|
|
||||||
|
enum CodingKeys: String, CodingKey {
|
||||||
|
case hiddenSize = "hidden_size"
|
||||||
|
case hiddenLayers = "num_hidden_layers"
|
||||||
|
case intermediateSize = "intermediate_size"
|
||||||
|
case attentionHeads = "num_attention_heads"
|
||||||
|
case headDimensions = "head_dim"
|
||||||
|
case rmsNormEps = "rms_norm_eps"
|
||||||
|
case vocabularySize = "vocab_size"
|
||||||
|
case kvHeads = "num_key_value_heads"
|
||||||
|
case ropeTheta = "rope_theta"
|
||||||
|
case ropeTraditional = "rope_traditional"
|
||||||
|
case attnLogitSoftcapping = "attn_logit_softcapping"
|
||||||
|
case finalLogitSoftcapping = "final_logit_softcapping"
|
||||||
|
case queryPreAttnScalar = "query_pre_attn_scalar"
|
||||||
|
}
|
||||||
|
|
||||||
|
public init(from decoder: Decoder) throws {
|
||||||
|
// custom implementation to handle optional keys with required values
|
||||||
|
let container: KeyedDecodingContainer<CodingKeys> = try decoder.container(
|
||||||
|
keyedBy: CodingKeys.self)
|
||||||
|
|
||||||
|
self.hiddenSize = try container.decode(
|
||||||
|
Int.self, forKey: CodingKeys.hiddenSize)
|
||||||
|
self.hiddenLayers = try container.decode(
|
||||||
|
Int.self, forKey: CodingKeys.hiddenLayers)
|
||||||
|
self.intermediateSize = try container.decode(
|
||||||
|
Int.self, forKey: CodingKeys.intermediateSize)
|
||||||
|
self.attentionHeads = try container.decode(
|
||||||
|
Int.self, forKey: CodingKeys.attentionHeads)
|
||||||
|
self.headDimensions = try container.decode(
|
||||||
|
Int.self, forKey: CodingKeys.headDimensions)
|
||||||
|
self.rmsNormEps = try container.decode(
|
||||||
|
Float.self, forKey: CodingKeys.rmsNormEps)
|
||||||
|
self.vocabularySize = try container.decode(
|
||||||
|
Int.self, forKey: CodingKeys.vocabularySize)
|
||||||
|
self.kvHeads = try container.decode(Int.self, forKey: CodingKeys.kvHeads)
|
||||||
|
self.ropeTheta =
|
||||||
|
try container.decodeIfPresent(Float.self, forKey: CodingKeys.ropeTheta)
|
||||||
|
?? 10_000
|
||||||
|
self.ropeTraditional =
|
||||||
|
try container.decodeIfPresent(
|
||||||
|
Bool.self, forKey: CodingKeys.ropeTraditional) ?? false
|
||||||
|
self.attnLogitSoftcapping = try container.decode(
|
||||||
|
Float.self, forKey: CodingKeys.attnLogitSoftcapping)
|
||||||
|
self.finalLogitSoftcapping = try container.decode(
|
||||||
|
Float.self, forKey: CodingKeys.finalLogitSoftcapping)
|
||||||
|
self.queryPreAttnScalar = try container.decode(
|
||||||
|
Int.self, forKey: CodingKeys.queryPreAttnScalar)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - LoRA
|
||||||
|
|
||||||
|
extension Gemma2Model: LoRAModel {
|
||||||
|
public func loraLinearLayers() -> LoRALinearLayers {
|
||||||
|
model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -168,7 +168,7 @@ extension ModelConfiguration {
|
|||||||
defaultPrompt: "what is the difference between lettuce and cabbage?"
|
defaultPrompt: "what is the difference between lettuce and cabbage?"
|
||||||
|
|
||||||
) { prompt in
|
) { prompt in
|
||||||
"<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
|
"<start_of_turn>user\n\(prompt)<end_of_turn>\n<start_of_turn>model\n"
|
||||||
}
|
}
|
||||||
|
|
||||||
public static let gemma_2_9b_it_4bit = ModelConfiguration(
|
public static let gemma_2_9b_it_4bit = ModelConfiguration(
|
||||||
@@ -178,6 +178,17 @@ extension ModelConfiguration {
|
|||||||
// https://www.promptingguide.ai/models/gemma
|
// https://www.promptingguide.ai/models/gemma
|
||||||
defaultPrompt: "What is the difference between lettuce and cabbage?"
|
defaultPrompt: "What is the difference between lettuce and cabbage?"
|
||||||
|
|
||||||
|
) { prompt in
|
||||||
|
"<start_of_turn>user\n\(prompt)<end_of_turn>\n<start_of_turn>model\n"
|
||||||
|
}
|
||||||
|
|
||||||
|
public static let gemma_2_2b_it_4bit = ModelConfiguration(
|
||||||
|
id: "mlx-community/gemma-2-2b-it-4bit",
|
||||||
|
overrideTokenizer: "PreTrainedTokenizer",
|
||||||
|
|
||||||
|
// https://www.promptingguide.ai/models/gemma
|
||||||
|
defaultPrompt: "What is the difference between lettuce and cabbage?"
|
||||||
|
|
||||||
) { prompt in
|
) { prompt in
|
||||||
"<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
|
"<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@
|
|||||||
|
|
||||||
/* Begin PBXBuildFile section */
|
/* Begin PBXBuildFile section */
|
||||||
12305EAF2B9D864400C92FEE /* PredictionView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 12305EAE2B9D864400C92FEE /* PredictionView.swift */; };
|
12305EAF2B9D864400C92FEE /* PredictionView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 12305EAE2B9D864400C92FEE /* PredictionView.swift */; };
|
||||||
|
1C55317A2C5AAB4E00B07ECD /* Gemma2.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1C5531792C5AAB4E00B07ECD /* Gemma2.swift */; };
|
||||||
1CD79C702BD80DE100B6C06F /* Phi3.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1CD79C6F2BD80DE100B6C06F /* Phi3.swift */; };
|
1CD79C702BD80DE100B6C06F /* Phi3.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1CD79C6F2BD80DE100B6C06F /* Phi3.swift */; };
|
||||||
525C1E9D2B9A011000B5C356 /* Starcoder2.swift in Sources */ = {isa = PBXBuildFile; fileRef = 525C1E9C2B9A010F00B5C356 /* Starcoder2.swift */; };
|
525C1E9D2B9A011000B5C356 /* Starcoder2.swift in Sources */ = {isa = PBXBuildFile; fileRef = 525C1E9C2B9A010F00B5C356 /* Starcoder2.swift */; };
|
||||||
52A776182B94B5EE00AA6E80 /* Qwen2.swift in Sources */ = {isa = PBXBuildFile; fileRef = 52A776172B94B5EE00AA6E80 /* Qwen2.swift */; };
|
52A776182B94B5EE00AA6E80 /* Qwen2.swift in Sources */ = {isa = PBXBuildFile; fileRef = 52A776172B94B5EE00AA6E80 /* Qwen2.swift */; };
|
||||||
@@ -218,6 +219,7 @@
|
|||||||
|
|
||||||
/* Begin PBXFileReference section */
|
/* Begin PBXFileReference section */
|
||||||
12305EAE2B9D864400C92FEE /* PredictionView.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = PredictionView.swift; sourceTree = "<group>"; };
|
12305EAE2B9D864400C92FEE /* PredictionView.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = PredictionView.swift; sourceTree = "<group>"; };
|
||||||
|
1C5531792C5AAB4E00B07ECD /* Gemma2.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Gemma2.swift; sourceTree = "<group>"; };
|
||||||
1CD79C6F2BD80DE100B6C06F /* Phi3.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Phi3.swift; sourceTree = "<group>"; };
|
1CD79C6F2BD80DE100B6C06F /* Phi3.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Phi3.swift; sourceTree = "<group>"; };
|
||||||
525C1E9C2B9A010F00B5C356 /* Starcoder2.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Starcoder2.swift; sourceTree = "<group>"; };
|
525C1E9C2B9A010F00B5C356 /* Starcoder2.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Starcoder2.swift; sourceTree = "<group>"; };
|
||||||
52A776172B94B5EE00AA6E80 /* Qwen2.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Qwen2.swift; sourceTree = "<group>"; };
|
52A776172B94B5EE00AA6E80 /* Qwen2.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Qwen2.swift; sourceTree = "<group>"; };
|
||||||
@@ -480,6 +482,7 @@
|
|||||||
C3A8B3AB2B9283150002EFB8 /* Models.swift */,
|
C3A8B3AB2B9283150002EFB8 /* Models.swift */,
|
||||||
C34E48EE2B696E6500FCB841 /* Llama.swift */,
|
C34E48EE2B696E6500FCB841 /* Llama.swift */,
|
||||||
C38935E22B86C0FE0037B833 /* Gemma.swift */,
|
C38935E22B86C0FE0037B833 /* Gemma.swift */,
|
||||||
|
1C5531792C5AAB4E00B07ECD /* Gemma2.swift */,
|
||||||
C38935C72B869C7A0037B833 /* LLM.h */,
|
C38935C72B869C7A0037B833 /* LLM.h */,
|
||||||
C38935E02B869F420037B833 /* LLMModel.swift */,
|
C38935E02B869F420037B833 /* LLMModel.swift */,
|
||||||
C38935DE2B869DD00037B833 /* Phi.swift */,
|
C38935DE2B869DD00037B833 /* Phi.swift */,
|
||||||
@@ -1014,6 +1017,7 @@
|
|||||||
C38935CE2B869C870037B833 /* Load.swift in Sources */,
|
C38935CE2B869C870037B833 /* Load.swift in Sources */,
|
||||||
C3E786AD2B8D4AF50004D037 /* Tokenizer.swift in Sources */,
|
C3E786AD2B8D4AF50004D037 /* Tokenizer.swift in Sources */,
|
||||||
C3A8B3AC2B9283150002EFB8 /* Models.swift in Sources */,
|
C3A8B3AC2B9283150002EFB8 /* Models.swift in Sources */,
|
||||||
|
1C55317A2C5AAB4E00B07ECD /* Gemma2.swift in Sources */,
|
||||||
C3E786AB2B8D1AEC0004D037 /* Evaluate.swift in Sources */,
|
C3E786AB2B8D1AEC0004D037 /* Evaluate.swift in Sources */,
|
||||||
C38935CC2B869C870037B833 /* Llama.swift in Sources */,
|
C38935CC2B869C870037B833 /* Llama.swift in Sources */,
|
||||||
52A776182B94B5EE00AA6E80 /* Qwen2.swift in Sources */,
|
52A776182B94B5EE00AA6E80 /* Qwen2.swift in Sources */,
|
||||||
|
|||||||
Reference in New Issue
Block a user