feat: add command r model support (#35)
* feat: add command r model support
This commit is contained in:
238
Libraries/LLM/Cohere.swift
Normal file
238
Libraries/LLM/Cohere.swift
Normal file
@@ -0,0 +1,238 @@
|
||||
import Foundation
|
||||
import MLX
|
||||
import MLXFast
|
||||
import MLXNN
|
||||
|
||||
// port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/cohere.py
|
||||
|
||||
private class Attention: Module {
|
||||
|
||||
let args: CohereConfiguration
|
||||
let scale: Float
|
||||
|
||||
@ModuleInfo(key: "q_proj") var wq: Linear
|
||||
@ModuleInfo(key: "k_proj") var wk: Linear
|
||||
@ModuleInfo(key: "v_proj") var wv: Linear
|
||||
@ModuleInfo(key: "o_proj") var wo: Linear
|
||||
|
||||
let rope: RoPE
|
||||
|
||||
public init(_ args: CohereConfiguration) {
|
||||
self.args = args
|
||||
|
||||
let dim = args.hiddenSize
|
||||
let heads = args.attentionHeads
|
||||
let kvHeads = args.kvHeads
|
||||
|
||||
let headDim = args.hiddenSize / heads
|
||||
self.scale = pow(Float(headDim), -0.5)
|
||||
|
||||
self._wq.wrappedValue = Linear(dim, heads * headDim, bias: false)
|
||||
self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: false)
|
||||
self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: false)
|
||||
self._wo.wrappedValue = Linear(heads * headDim, dim, bias: false)
|
||||
|
||||
self.rope = RoPE(
|
||||
dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta)
|
||||
}
|
||||
|
||||
public func callAsFunction(
|
||||
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
|
||||
) -> (MLXArray, (MLXArray, MLXArray)) {
|
||||
let (B, L) = (x.dim(0), x.dim(1))
|
||||
|
||||
var queries = wq(x)
|
||||
var keys = wk(x)
|
||||
var values = wv(x)
|
||||
|
||||
// prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3)
|
||||
keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
|
||||
values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
|
||||
|
||||
if let (keyCache, valueCache) = cache {
|
||||
queries = rope(queries, offset: keyCache.dim(2))
|
||||
keys = rope(keys, offset: keyCache.dim(2))
|
||||
keys = concatenated([keyCache, keys], axis: 2)
|
||||
values = concatenated([valueCache, values], axis: 2)
|
||||
} else {
|
||||
queries = rope(queries)
|
||||
keys = rope(keys)
|
||||
}
|
||||
|
||||
let output = MLXFast.scaledDotProductAttention(
|
||||
queries: queries, keys: keys, values: values, scale: scale, mask: mask
|
||||
)
|
||||
.transposed(0, 2, 1, 3)
|
||||
.reshaped(B, L, -1)
|
||||
|
||||
return (wo(output), (keys, values))
|
||||
}
|
||||
}
|
||||
|
||||
private class MLP: Module, UnaryLayer {
|
||||
|
||||
@ModuleInfo(key: "gate_proj") var gate: Linear
|
||||
@ModuleInfo(key: "down_proj") var down: Linear
|
||||
@ModuleInfo(key: "up_proj") var up: Linear
|
||||
|
||||
public init(dimensions: Int, hiddenDimensions: Int) {
|
||||
self._gate.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
|
||||
self._up.wrappedValue = Linear(dimensions, hiddenDimensions, bias: false)
|
||||
self._down.wrappedValue = Linear(hiddenDimensions, dimensions, bias: false)
|
||||
|
||||
}
|
||||
|
||||
public func callAsFunction(_ x: MLXArray) -> MLXArray {
|
||||
down(silu(gate(x)) * up(x))
|
||||
}
|
||||
}
|
||||
|
||||
private class TransformerBlock: Module {
|
||||
|
||||
@ModuleInfo(key: "self_attn") var attention: Attention
|
||||
let mlp: MLP
|
||||
|
||||
@ModuleInfo(key: "input_layernorm") var inputLayerNorm: LayerNorm
|
||||
|
||||
public init(_ args: CohereConfiguration) {
|
||||
self._attention.wrappedValue = Attention(args)
|
||||
self.mlp = MLP(dimensions: args.hiddenSize, hiddenDimensions: args.intermediateSize)
|
||||
self._inputLayerNorm.wrappedValue = LayerNorm(
|
||||
dimensions: args.hiddenSize, eps: args.layerNormEps)
|
||||
|
||||
}
|
||||
|
||||
public func callAsFunction(
|
||||
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
|
||||
) -> (MLXArray, (MLXArray, MLXArray)) {
|
||||
let h = inputLayerNorm(x)
|
||||
let (attnH, cache) = attention(h, mask: mask, cache: cache)
|
||||
let ffH = mlp(h)
|
||||
return (attnH + ffH + x, cache)
|
||||
}
|
||||
}
|
||||
|
||||
public class CohereModelInner: Module {
|
||||
|
||||
@ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
|
||||
|
||||
fileprivate let layers: [TransformerBlock]
|
||||
let norm: LayerNorm
|
||||
|
||||
public init(_ args: CohereConfiguration) {
|
||||
precondition(args.vocabularySize > 0)
|
||||
|
||||
self._embedTokens.wrappedValue = Embedding(
|
||||
embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)
|
||||
|
||||
self.layers = (0 ..< args.hiddenLayers)
|
||||
.map { _ in
|
||||
TransformerBlock(args)
|
||||
}
|
||||
self.norm = LayerNorm(dimensions: args.hiddenSize, eps: args.layerNormEps)
|
||||
}
|
||||
|
||||
public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
|
||||
MLXArray, [(MLXArray, MLXArray)]
|
||||
) {
|
||||
var h = embedTokens(inputs)
|
||||
|
||||
var mask: MLXArray? = nil
|
||||
if h.dim(1) > 1 {
|
||||
mask = MultiHeadAttention.createAdditiveCausalMask(h.dim(1))
|
||||
mask = mask?.asType(h.dtype)
|
||||
}
|
||||
|
||||
var newCache = [(MLXArray, MLXArray)]()
|
||||
|
||||
for (i, layer) in layers.enumerated() {
|
||||
var cacheUpdate: (MLXArray, MLXArray)
|
||||
(h, cacheUpdate) = layer(h, mask: mask, cache: cache?[i])
|
||||
newCache.append(cacheUpdate)
|
||||
}
|
||||
|
||||
return (norm(h), newCache)
|
||||
}
|
||||
}
|
||||
|
||||
public class CohereModel: Module, LLMModel {
|
||||
|
||||
public let vocabularySize: Int
|
||||
let model: CohereModelInner
|
||||
let logitScale: Float
|
||||
|
||||
public init(_ args: CohereConfiguration) {
|
||||
self.vocabularySize = args.vocabularySize
|
||||
self.model = CohereModelInner(args)
|
||||
self.logitScale = args.logitScale
|
||||
}
|
||||
|
||||
public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
|
||||
MLXArray, [(MLXArray, MLXArray)]
|
||||
) {
|
||||
var (out, cache) = model(inputs, cache: cache)
|
||||
out = matmul(out, model.embedTokens.weight.T)
|
||||
out = out * self.logitScale
|
||||
return (out, cache)
|
||||
}
|
||||
}
|
||||
|
||||
public struct CohereConfiguration: Codable {
|
||||
|
||||
var hiddenSize: Int
|
||||
var hiddenLayers: Int
|
||||
var intermediateSize: Int
|
||||
var attentionHeads: Int
|
||||
var layerNormEps: Float
|
||||
var vocabularySize: Int
|
||||
var kvHeads: Int
|
||||
var ropeTheta: Float = 8000000.0
|
||||
var ropeTraditional: Bool = true
|
||||
var ropeScaling: [String: StringOrNumber]? = nil
|
||||
var logitScale: Float
|
||||
|
||||
enum CodingKeys: String, CodingKey {
|
||||
case hiddenSize = "hidden_size"
|
||||
case hiddenLayers = "num_hidden_layers"
|
||||
case intermediateSize = "intermediate_size"
|
||||
case attentionHeads = "num_attention_heads"
|
||||
case kvHeads = "num_key_value_heads"
|
||||
case ropeTheta = "rope_theta"
|
||||
case vocabularySize = "vocab_size"
|
||||
case layerNormEps = "layer_norm_eps"
|
||||
case logitScale = "logit_scale"
|
||||
case ropeTraditional = "rope_traditional"
|
||||
case ropeScaling = "rope_scaling"
|
||||
}
|
||||
|
||||
public init(from decoder: Decoder) throws {
|
||||
// custom implementation to handle optional keys with required values
|
||||
let container: KeyedDecodingContainer<CohereConfiguration.CodingKeys> =
|
||||
try decoder.container(
|
||||
keyedBy: CohereConfiguration.CodingKeys.self)
|
||||
|
||||
self.hiddenSize = try container.decode(
|
||||
Int.self, forKey: CohereConfiguration.CodingKeys.hiddenSize)
|
||||
self.hiddenLayers = try container.decode(
|
||||
Int.self, forKey: CohereConfiguration.CodingKeys.hiddenLayers)
|
||||
self.intermediateSize = try container.decode(
|
||||
Int.self, forKey: CohereConfiguration.CodingKeys.intermediateSize)
|
||||
self.attentionHeads = try container.decode(
|
||||
Int.self, forKey: CohereConfiguration.CodingKeys.attentionHeads)
|
||||
self.layerNormEps = try container.decode(
|
||||
Float.self, forKey: CohereConfiguration.CodingKeys.layerNormEps)
|
||||
self.vocabularySize = try container.decode(
|
||||
Int.self, forKey: CohereConfiguration.CodingKeys.vocabularySize)
|
||||
self.kvHeads = try container.decode(
|
||||
Int.self, forKey: CohereConfiguration.CodingKeys.kvHeads)
|
||||
self.ropeTheta =
|
||||
try container.decodeIfPresent(
|
||||
Float.self, forKey: CohereConfiguration.CodingKeys.ropeTheta)
|
||||
?? 8000000.0
|
||||
self.ropeScaling = try container.decodeIfPresent(
|
||||
[String: StringOrNumber].self, forKey: CohereConfiguration.CodingKeys.ropeScaling)
|
||||
self.logitScale = try container.decode(
|
||||
Float.self, forKey: CohereConfiguration.CodingKeys.logitScale)
|
||||
}
|
||||
}
|
||||
@@ -33,6 +33,7 @@ public enum ModelType: String, Codable {
|
||||
case gemma
|
||||
case qwen2
|
||||
case starcoder2
|
||||
case cohere
|
||||
|
||||
func createModel(configuration: URL) throws -> LLMModel {
|
||||
switch self {
|
||||
@@ -56,6 +57,10 @@ public enum ModelType: String, Codable {
|
||||
let configuration = try JSONDecoder().decode(
|
||||
Starcoder2Configuration.self, from: Data(contentsOf: configuration))
|
||||
return Starcoder2Model(configuration)
|
||||
case .cohere:
|
||||
let configuration = try JSONDecoder().decode(
|
||||
CohereConfiguration.self, from: Data(contentsOf: configuration))
|
||||
return CohereModel(configuration)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,5 +28,6 @@ public func loadTokenizer(configuration: ModelConfiguration) async throws -> Tok
|
||||
|
||||
/// overrides for TokenizerModel/knownTokenizers
|
||||
let replacementTokenizers = [
|
||||
"Qwen2Tokenizer": "PreTrainedTokenizer"
|
||||
"Qwen2Tokenizer": "PreTrainedTokenizer",
|
||||
"CohereTokenizer": "PreTrainedTokenizer",
|
||||
]
|
||||
|
||||
@@ -68,6 +68,7 @@
|
||||
C3FBCB312B8520F20007E490 /* MLXNN in Frameworks */ = {isa = PBXBuildFile; productRef = C3FBCB302B8520F20007E490 /* MLXNN */; };
|
||||
C3FBCB332B8520F20007E490 /* MLXOptimizers in Frameworks */ = {isa = PBXBuildFile; productRef = C3FBCB322B8520F20007E490 /* MLXOptimizers */; };
|
||||
C3FBCB352B8520F20007E490 /* MLXRandom in Frameworks */ = {isa = PBXBuildFile; productRef = C3FBCB342B8520F20007E490 /* MLXRandom */; };
|
||||
F24B083A2BAF1A65008C8D19 /* Cohere.swift in Sources */ = {isa = PBXBuildFile; fileRef = F24B08392BAF1A65008C8D19 /* Cohere.swift */; };
|
||||
/* End PBXBuildFile section */
|
||||
|
||||
/* Begin PBXContainerItemProxy section */
|
||||
@@ -234,6 +235,7 @@
|
||||
C3C3240C2B6CA792007D2D9A /* README.md */ = {isa = PBXFileReference; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = "<group>"; };
|
||||
C3E786AA2B8D1AEC0004D037 /* Evaluate.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Evaluate.swift; sourceTree = "<group>"; };
|
||||
C3E786AC2B8D4AF50004D037 /* Tokenizer.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Tokenizer.swift; sourceTree = "<group>"; };
|
||||
F24B08392BAF1A65008C8D19 /* Cohere.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Cohere.swift; sourceTree = "<group>"; };
|
||||
/* End PBXFileReference section */
|
||||
|
||||
/* Begin PBXFrameworksBuildPhase section */
|
||||
@@ -383,6 +385,7 @@
|
||||
C3E786AA2B8D1AEC0004D037 /* Evaluate.swift */,
|
||||
C3E786AC2B8D4AF50004D037 /* Tokenizer.swift */,
|
||||
52A776172B94B5EE00AA6E80 /* Qwen2.swift */,
|
||||
F24B08392BAF1A65008C8D19 /* Cohere.swift */,
|
||||
);
|
||||
path = LLM;
|
||||
sourceTree = "<group>";
|
||||
@@ -847,6 +850,7 @@
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
C38935E12B869F420037B833 /* LLMModel.swift in Sources */,
|
||||
F24B083A2BAF1A65008C8D19 /* Cohere.swift in Sources */,
|
||||
C38935E32B86C0FE0037B833 /* Gemma.swift in Sources */,
|
||||
C38935CD2B869C870037B833 /* Configuration.swift in Sources */,
|
||||
525C1E9D2B9A011000B5C356 /* Starcoder2.swift in Sources */,
|
||||
|
||||
@@ -56,9 +56,13 @@
|
||||
isEnabled = "NO">
|
||||
</CommandLineArgument>
|
||||
<CommandLineArgument
|
||||
argument = "--model mlx-community/starcoder2-3b-4bit"
|
||||
argument = "--model mlx-community/c4ai-command-r-v01-4bit"
|
||||
isEnabled = "YES">
|
||||
</CommandLineArgument>
|
||||
<CommandLineArgument
|
||||
argument = "--model mlx-community/starcoder2-3b-4bit"
|
||||
isEnabled = "NO">
|
||||
</CommandLineArgument>
|
||||
<CommandLineArgument
|
||||
argument = "--model mlx-community/Qwen1.5-0.5B-Chat-4bit"
|
||||
isEnabled = "NO">
|
||||
|
||||
Reference in New Issue
Block a user