committed by
GitHub
parent
a2e8d7e469
commit
c4fda0e036
@@ -26,7 +26,7 @@ private class Attention: Module {
|
||||
let heads = args.attentionHeads
|
||||
let kvHeads = args.kvHeads
|
||||
|
||||
let headDim = args.hiddenSize / heads
|
||||
let headDim = args.headDimensions ?? (args.hiddenSize / heads)
|
||||
self.scale = pow(Float(headDim), -0.5)
|
||||
|
||||
self._wq.wrappedValue = Linear(dim, heads * headDim, bias: false)
|
||||
@@ -215,6 +215,7 @@ public struct LlamaConfiguration: Codable {
|
||||
var hiddenLayers: Int
|
||||
var intermediateSize: Int
|
||||
var attentionHeads: Int
|
||||
var headDimensions: Int? = nil
|
||||
var rmsNormEps: Float
|
||||
var vocabularySize: Int
|
||||
var kvHeads: Int
|
||||
@@ -228,6 +229,7 @@ public struct LlamaConfiguration: Codable {
|
||||
case hiddenLayers = "num_hidden_layers"
|
||||
case intermediateSize = "intermediate_size"
|
||||
case attentionHeads = "num_attention_heads"
|
||||
case headDimensions = "head_dim"
|
||||
case rmsNormEps = "rms_norm_eps"
|
||||
case vocabularySize = "vocab_size"
|
||||
case kvHeads = "num_key_value_heads"
|
||||
@@ -251,6 +253,8 @@ public struct LlamaConfiguration: Codable {
|
||||
Int.self, forKey: LlamaConfiguration.CodingKeys.intermediateSize)
|
||||
self.attentionHeads = try container.decode(
|
||||
Int.self, forKey: LlamaConfiguration.CodingKeys.attentionHeads)
|
||||
self.headDimensions = try container.decodeIfPresent(
|
||||
Int.self, forKey: LlamaConfiguration.CodingKeys.headDimensions)
|
||||
self.rmsNormEps = try container.decode(
|
||||
Float.self, forKey: LlamaConfiguration.CodingKeys.rmsNormEps)
|
||||
self.vocabularySize = try container.decode(
|
||||
|
||||
@@ -118,12 +118,19 @@ extension ModelConfiguration {
|
||||
"<|im_start|>user\n\(prompt)<|im_end|>\n<|im_start|>assistant\n"
|
||||
}
|
||||
|
||||
public static let mistral7B4bit = ModelConfiguration(
|
||||
id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx",
|
||||
public static let mistralNeMo4bit = ModelConfiguration(
|
||||
id: "mlx-community/Mistral-Nemo-Instruct-2407-4bit",
|
||||
defaultPrompt: "Explain quaternions."
|
||||
) { prompt in
|
||||
"<s>[INST] \(prompt) [/INST] "
|
||||
}
|
||||
|
||||
// https://www.promptingguide.ai/models/mistral-7b
|
||||
defaultPrompt: "describe the swift language"
|
||||
)
|
||||
public static let mistral7B4bit = ModelConfiguration(
|
||||
id: "mlx-community/Mistral-7B-Instruct-v0.3-4bit",
|
||||
defaultPrompt: "Describe the Swift language."
|
||||
) { prompt in
|
||||
"<s>[INST] \(prompt) [/INST] "
|
||||
}
|
||||
|
||||
public static let codeLlama13b4bit = ModelConfiguration(
|
||||
id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX",
|
||||
@@ -213,6 +220,8 @@ extension ModelConfiguration {
|
||||
case .idle:
|
||||
bootstrapState = .bootstrapping
|
||||
register(configurations: [
|
||||
mistralNeMo4bit,
|
||||
smolLM_135M_4bit,
|
||||
mistral7B4bit,
|
||||
codeLlama13b4bit,
|
||||
phi4bit,
|
||||
|
||||
Reference in New Issue
Block a user