Files
Anthony ac6bdfccec Add Llama 3.1 (#98)
* Update Mistral 7B config

* Add Mistral NeMo

* Update for Llama 3.1

* Align LlamaConfiguration with Python implementation

* Fix model configuration names

* Refine DynamicNTKScalingRoPE

* compute base only once

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-07-26 13:05:42 -07:00

392 lines
14 KiB
Swift

// Copyright © 2024 Apple Inc.
import Foundation
import MLX
import MLXFast
import MLXNN
// port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/llama.py
func computeBaseFrequency(
base: Float, dims: Int, ropeType: String, ropeScaling: [String: StringOrNumber]?
)
-> Float
{
if ropeType != "llama3" {
return base
}
guard let ropeScaling = ropeScaling else {
return base
}
guard case .float(let factor) = ropeScaling["factor"],
case .float(let lowFreqFactor) = ropeScaling["low_freq_factor"] ?? .float(1.0),
case .float(let highFreqFactor) = ropeScaling["high_freq_factor"] ?? .float(4.0),
case .float(let oldContextLen) = ropeScaling["original_max_position_embeddings"]
?? .float(8192)
else {
return base
}
let lowFreqWavelen = oldContextLen / lowFreqFactor
let highFreqWavelen = oldContextLen / highFreqFactor
let freqs = (0 ..< dims).compactMap { index -> Float? in
if index % 2 == 0 {
return pow(base, Float(index) / Float(dims))
}
return nil
}
let newBaseFreqs = freqs.map { freq -> Float in
let wavelen = 2 * .pi / freq
let smooth = max(
0, min(1, (wavelen - highFreqWavelen) / (lowFreqWavelen - highFreqWavelen)))
return freq * ((1 - smooth) * factor + smooth)
}
return newBaseFreqs.reduce(0, +) / Float(newBaseFreqs.count)
}
private class DynamicNTKScalingRoPE: Module {
let dims: Int
let maxPositionEmbeddings: Int?
let traditional: Bool
let base: Float
var scale: Float
let ropeType: String
let ropeScaling: [String: StringOrNumber]?
init(
dims: Int, maxPositionEmbeddings: Int?, traditional: Bool = false,
base: Float = 10000, scale: Float = 1.0, ropeType: String = "default",
ropeScaling: [String: StringOrNumber]? = nil
) {
self.dims = dims
self.maxPositionEmbeddings = maxPositionEmbeddings
self.traditional = traditional
self.base = computeBaseFrequency(
base: base, dims: dims, ropeType: ropeType, ropeScaling: ropeScaling)
self.scale = scale
self.ropeType = ropeType
self.ropeScaling = ropeScaling
}
func callAsFunction(_ x: MLXArray, offset: Int = 0) -> MLXArray {
let seqLen = x.dim(1) + offset
var base = self.base
if let maxPositionEmbeddings, seqLen > maxPositionEmbeddings {
let factorAdjustment = Float(seqLen) / Float(maxPositionEmbeddings) - 1
let dimensionRatio = Float(dims) / Float(Float(dims) - 2)
let adjustedScale = scale * pow(1 + factorAdjustment, dimensionRatio)
base *= adjustedScale
}
return MLXFast.RoPE(
x, dimensions: dims, traditional: traditional, base: base, scale: scale, offset: offset)
}
}
private class Attention: Module {
let args: LlamaConfiguration
let scale: Float
@ModuleInfo(key: "q_proj") var wq: Linear
@ModuleInfo(key: "k_proj") var wk: Linear
@ModuleInfo(key: "v_proj") var wv: Linear
@ModuleInfo(key: "o_proj") var wo: Linear
let rope: DynamicNTKScalingRoPE
init(_ args: LlamaConfiguration) {
self.args = args
let dim = args.hiddenSize
let heads = args.attentionHeads
let kvHeads = args.kvHeads
let headDim = args.headDimensions ?? (args.hiddenSize / heads)
self.scale = pow(Float(headDim), -0.5)
self._wq.wrappedValue = Linear(dim, heads * headDim, bias: args.attentionBias)
self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: args.attentionBias)
self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: args.attentionBias)
self._wo.wrappedValue = Linear(heads * headDim, dim, bias: args.attentionBias)
self.rope = DynamicNTKScalingRoPE(
dims: headDim,
maxPositionEmbeddings: args.maxPositionEmbeddings,
traditional: args.ropeTraditional,
base: args.ropeTheta,
scale: 1.0,
ropeType: {
if case .string(let value) = args.ropeScaling?["type"] {
return value
} else {
return "default"
}
}(),
ropeScaling: args.ropeScaling)
}
func callAsFunction(
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
) -> (MLXArray, (MLXArray, MLXArray)) {
let (B, L) = (x.dim(0), x.dim(1))
var queries = wq(x)
var keys = wk(x)
var values = wv(x)
// Prepare the queries, keys and values for the attention computation
queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3)
keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
if let (keyCache, valueCache) = cache {
queries = rope(queries, offset: keyCache.dim(2))
keys = rope(keys, offset: keyCache.dim(2))
keys = concatenated([keyCache, keys], axis: 2)
values = concatenated([valueCache, values], axis: 2)
} else {
queries = rope(queries)
keys = rope(keys)
}
let output = MLXFast.scaledDotProductAttention(
queries: queries, keys: keys, values: values, scale: scale, mask: mask
)
.transposed(0, 2, 1, 3)
.reshaped(B, L, -1)
return (wo(output), (keys, values))
}
}
private class MLP: Module, UnaryLayer {
@ModuleInfo(key: "gate_proj") var gate: Linear
@ModuleInfo(key: "down_proj") var down: Linear
@ModuleInfo(key: "up_proj") var up: Linear
init(_ args: LlamaConfiguration) {
self._gate.wrappedValue = Linear(args.hiddenSize, args.intermediateSize, bias: args.mlpBias)
self._down.wrappedValue = Linear(args.intermediateSize, args.hiddenSize, bias: args.mlpBias)
self._up.wrappedValue = Linear(args.hiddenSize, args.intermediateSize, bias: args.mlpBias)
}
func callAsFunction(_ x: MLXArray) -> MLXArray {
let activation = silu(gate(x))
return down(activation * up(x))
}
}
private class TransformerBlock: Module {
@ModuleInfo(key: "self_attn") var attention: Attention
@ModuleInfo(key: "mlp") var mlp: MLP
@ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm
@ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm
init(_ args: LlamaConfiguration) {
self._attention.wrappedValue = Attention(args)
self._mlp.wrappedValue = MLP(args)
self._inputLayerNorm.wrappedValue = RMSNorm(
dimensions: args.hiddenSize, eps: args.rmsNormEps)
self._postAttentionLayerNorm.wrappedValue = RMSNorm(
dimensions: args.hiddenSize, eps: args.rmsNormEps)
}
func callAsFunction(
_ x: MLXArray, mask: MLXArray? = nil, cache: (MLXArray, MLXArray)? = nil
) -> (MLXArray, (MLXArray, MLXArray)) {
var (r, cache) = attention(inputLayerNorm(x), mask: mask, cache: cache)
let h = x + r
r = mlp(postAttentionLayerNorm(h))
let out = h + r
return (out, cache)
}
}
private class LlamaModelInner: Module {
@ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
let layers: [TransformerBlock]
let norm: RMSNorm
init(_ args: LlamaConfiguration) {
precondition(args.vocabularySize > 0)
self._embedTokens.wrappedValue = Embedding(
embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)
self.layers = (0 ..< args.hiddenLayers).map { _ in TransformerBlock(args) }
self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
}
func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]? = nil) -> (
MLXArray, [(MLXArray, MLXArray)]
) {
var h = embedTokens(inputs)
var mask: MLXArray? = nil
if h.dim(1) > 1 {
mask = MultiHeadAttention.createAdditiveCausalMask(h.dim(1))
mask = mask?.asType(h.dtype)
}
var newCache = [(MLXArray, MLXArray)]()
for (i, layer) in layers.enumerated() {
var cacheUpdate: (MLXArray, MLXArray)
(h, cacheUpdate) = layer(h, mask: mask, cache: cache?[i])
newCache.append(cacheUpdate)
}
return (norm(h), newCache)
}
}
public class LlamaModel: Module, LLMModel {
public let vocabularySize: Int
fileprivate let model: LlamaModelInner
@ModuleInfo(key: "lm_head") var lmHead: Linear?
public init(_ args: LlamaConfiguration) {
self.vocabularySize = args.vocabularySize
self.model = LlamaModelInner(args)
if !args.tieWordEmbeddings {
self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
}
}
public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
MLXArray, [(MLXArray, MLXArray)]
) {
let (out, cache) = model(inputs, cache: cache)
if let lmHead {
return (lmHead(out), cache)
} else {
return (model.embedTokens.asLinear(out), cache)
}
}
public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
// Remove unused precomputed rotary frequencies
weights.filter {
!$0.key.contains("self_attn.rotary_emb.inv_freq")
}
}
}
public struct LlamaConfiguration: Codable {
var hiddenSize: Int
var hiddenLayers: Int
var intermediateSize: Int
var attentionHeads: Int
var headDimensions: Int?
var rmsNormEps: Float
var vocabularySize: Int
var kvHeads: Int
var maxPositionEmbeddings: Int?
var ropeTheta: Float = 10_000
var ropeTraditional: Bool = false
var ropeScaling: [String: StringOrNumber]?
var tieWordEmbeddings: Bool = true
var attentionBias: Bool = false
var mlpBias: Bool = false
enum CodingKeys: String, CodingKey {
case hiddenSize = "hidden_size"
case hiddenLayers = "num_hidden_layers"
case intermediateSize = "intermediate_size"
case attentionHeads = "num_attention_heads"
case headDimensions = "head_dim"
case rmsNormEps = "rms_norm_eps"
case vocabularySize = "vocab_size"
case kvHeads = "num_key_value_heads"
case maxPositionEmbeddings = "max_position_embeddings"
case ropeTheta = "rope_theta"
case ropeTraditional = "rope_traditional"
case ropeScaling = "rope_scaling"
case tieWordEmbeddings = "tie_word_embeddings"
case attentionBias = "attention_bias"
case mlpBias = "mlp_bias"
}
public init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
hiddenSize = try container.decode(Int.self, forKey: .hiddenSize)
hiddenLayers = try container.decode(Int.self, forKey: .hiddenLayers)
intermediateSize = try container.decode(Int.self, forKey: .intermediateSize)
attentionHeads = try container.decode(Int.self, forKey: .attentionHeads)
headDimensions = try container.decodeIfPresent(Int.self, forKey: .headDimensions)
rmsNormEps = try container.decode(Float.self, forKey: .rmsNormEps)
vocabularySize = try container.decode(Int.self, forKey: .vocabularySize)
kvHeads = try container.decodeIfPresent(Int.self, forKey: .kvHeads) ?? attentionHeads
maxPositionEmbeddings = try container.decodeIfPresent(
Int.self, forKey: .maxPositionEmbeddings)
if let ropeTheta = try container.decodeIfPresent(Float.self, forKey: .ropeTheta) {
self.ropeTheta = ropeTheta
}
if let ropeTraditional = try container.decodeIfPresent(Bool.self, forKey: .ropeTraditional)
{
self.ropeTraditional = ropeTraditional
}
ropeScaling = try container.decodeIfPresent(
[String: StringOrNumber].self, forKey: .ropeScaling)
if let tieWordEmbeddings = try container.decodeIfPresent(
Bool.self, forKey: .tieWordEmbeddings)
{
self.tieWordEmbeddings = tieWordEmbeddings
}
if let attentionBias = try container.decodeIfPresent(Bool.self, forKey: .attentionBias) {
self.attentionBias = attentionBias
}
if let mlpBias = try container.decodeIfPresent(Bool.self, forKey: .mlpBias) {
self.mlpBias = mlpBias
}
if let ropeScaling {
if ropeScaling["factor"] == nil {
throw DecodingError.dataCorruptedError(
forKey: .ropeScaling, in: container,
debugDescription: "rope_scaling must contain 'factor'")
}
if let ropeType = ropeScaling["type"] ?? ropeScaling["rope_type"] {
if case .string = ropeType {
let options = [
StringOrNumber.string("linear"), StringOrNumber.string("dynamic"),
StringOrNumber.string("llama3"),
]
if !options.contains(ropeType) {
throw DecodingError.dataCorruptedError(
forKey: .ropeScaling, in: container,
debugDescription:
"rope_scaling 'type' currently only supports 'linear', 'dynamic', or 'llama3'"
)
}
}
} else {
throw DecodingError.dataCorruptedError(
forKey: .ropeScaling, in: container,
debugDescription: "rope_scaling must contain either 'type' or 'rope_type'")
}
}
}
}
// MARK: - LoRA
extension LlamaModel: LoRAModel {
public func loraLinearLayers() -> LoRALinearLayers {
model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
}
}