handle partially quantized models (#76)

* handle partially quantized models

- fix for #53 #71 #69 #74
- in order to test the models
	- I added a default prompt of an appropriate form
	- while working on the model configuration also added additional stop tokens (#74)
- fixed the repetitionPenalty code (#71)
This commit is contained in:
David Koski
2024-05-28 16:35:11 -07:00
committed by GitHub
parent 65f4968e5f
commit 9d74afd119
12 changed files with 139 additions and 67 deletions

View File

@@ -10,7 +10,7 @@ import Tokenizers
struct ContentView: View {
@State var prompt = "compare python and swift"
@State var prompt = ""
@State var llm = LLMEvaluator()
@Environment(DeviceStat.self) private var deviceStat
@@ -125,6 +125,8 @@ struct ContentView: View {
}
.task {
self.prompt = llm.modelConfiguration.defaultPrompt
// pre-load the weights on launch to speed up the first generation
_ = try? await llm.load()
}
@@ -224,7 +226,7 @@ class LLMEvaluator {
let result = await LLM.generate(
promptTokens: promptTokens, parameters: generateParameters, model: model,
tokenizer: tokenizer
tokenizer: tokenizer, extraEOSTokens: modelConfiguration.extraEOSTokens
) { tokens in
// update the output -- this will make the view show the text as it generates
if tokens.count % displayEveryNTokens == 0 {

View File

@@ -266,6 +266,7 @@ class LoRAEvaluator {
let result = await LLM.generate(
promptTokens: promptTokens, parameters: generateParameters, model: model,
tokenizer: tokenizer,
extraEOSTokens: modelConfiguration.extraEOSTokens,
didGenerate: { tokens in
if tokens.count % evaluateShowEvery == 0 {
let fullOutput = tokenizer.decode(tokens: tokens)

View File

@@ -12,7 +12,7 @@ private func topPSampling(logits: MLXArray, topP: Float, temp: Float) -> MLXArra
logits = logits.asType(.float32)
}
let probs = softMax(logits / temp, axis: -1)
let probs = softmax(logits / temp, axis: -1)
let sortedIndices = argSort(probs, axis: -1)
// probs shape is [B,V] and after take it will be [1, B, V], so we squeeze it back to [B, V]
@@ -31,7 +31,7 @@ private func applyRepetitionPenalty(
) -> MLXArray {
if repetitionContext.shape[0] > 0 {
let indices = repetitionContext
var selectedLogits = take(logits, indices, axis: -1).squeezed(axis: 0)
var selectedLogits = logits[0..., indices]
selectedLogits = MLX.where(
selectedLogits .< 0, selectedLogits * penalty, selectedLogits / penalty)
@@ -100,7 +100,7 @@ public struct TokenIterator: Sequence, IteratorProtocol {
if prompt.shape[0] <= parameters.repetitionContextSize {
self.repetitionContext = prompt
} else {
self.repetitionContext = prompt[-parameters.repetitionContextSize ... -1]
self.repetitionContext = prompt[(-parameters.repetitionContextSize)...]
}
} else {
self.repetitionContext = []
@@ -120,9 +120,8 @@ public struct TokenIterator: Sequence, IteratorProtocol {
y = sample(logits: logits, temp: parameters.temperature, topP: parameters.topP)
// append the current token to the context and check repetitionPenalty context see if need to remove the first token
if parameters.repetitionContextSize > 1 {
repetitionContext = concatenated([repetitionContext, y], axis: 0)
if repetitionContext.shape[0] > parameters.repetitionContextSize {
repetitionContext = repetitionContext[1...]
repetitionContext = repetitionContext[(-parameters.repetitionContextSize)...]
}
}
@@ -174,14 +173,31 @@ public enum GenerateDisposition {
/// - parameters: generation parameters
/// - model: model to evaluate
/// - tokenizer: tokenizer to convert tokens back into strings and recognizer special tokens
/// - configuration: the model configuration
/// - didGenerate: visitor for the tokens as they are generated
public func generate(
promptTokens: [Int], parameters: GenerateParameters, model: LLMModel, tokenizer: Tokenizer,
extraEOSTokens: Set<String>? = nil,
didGenerate: ([Int]) async -> GenerateDisposition
) async -> GenerateResult {
var start = Date.timeIntervalSinceReferenceDate
var promptTime: TimeInterval = 0
// build a set of additional stop tokens
let additionalEOSTokenIds = Set(
(extraEOSTokens ?? [])
.map {
tokenizer.encode(text: $0)
}
.filter {
// discard anything that is not a single token. sometimes
// the tokenizer will insert a <s> token, so accept that too
$0.count == 1 || ($0.count == 2 && $0[0] == 1)
}
.map {
$0.last!
})
var tokens = [Int]()
for token in TokenIterator(
@@ -196,7 +212,9 @@ public func generate(
}
let t = token.item(Int.self)
if t == tokenizer.unknownTokenId || t == tokenizer.eosTokenId {
if t == tokenizer.unknownTokenId || t == tokenizer.eosTokenId
|| additionalEOSTokenIds.contains(t)
{
break
}

View File

@@ -12,4 +12,15 @@ public protocol LLMModel: Module {
func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
MLXArray, [(MLXArray, MLXArray)]
)
/// Optionally preprocess the weights and modify / remove values as needed.
func sanitize(weights: [String: MLXArray]) -> [String: MLXArray]
}
extension LLMModel {
public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
weights
}
}

View File

@@ -194,6 +194,13 @@ public class LlamaModel: Module, LLMModel {
let (out, cache) = model(inputs, cache: cache)
return (lmHead(out), cache)
}
public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
// Remove unused precomputed rotary freqs
weights.filter {
!$0.key.contains("self_attn.rotary_emb.inv_freq")
}
}
}
public struct LlamaConfiguration: Codable {

View File

@@ -54,9 +54,15 @@ public func load(
}
}
// per-model cleanup
weights = model.sanitize(weights: weights)
// quantize if needed
if let quantization = baseConfig.quantization {
quantizeIfNeeded(model: model, weights: weights, quantization: quantization)
quantize(model: model, groupSize: quantization.groupSize, bits: quantization.bits) {
path, module in
weights["\(path).scales"] != nil
}
}
// apply the loaded weights
@@ -76,38 +82,3 @@ public func load(
hub: hub, configuration: newConfiguration, progressHandler: progressHandler)
}
}
// MARK: - Quantization
private func quantizeIfNeeded(
model: LLMModel, weights: [String: MLXArray], quantization: BaseConfiguration.Quantization
) {
func linearPredicate(layer: Module) -> Bool {
if let layer = layer as? Linear {
// avoid quantizing gate layers, otherwise we have to re-quant and upload all the mixtral models
return layer.weight.dim(0) != 8
}
return false
}
var predicate = linearPredicate(layer:)
// for legacy models that don't have lm_head quant due to non-32 dims
if weights["lm_head.scales"] == nil {
let vocabularySize = model.vocabularySize
func vocabularySizePredicate(layer: Module) -> Bool {
if let layer = layer as? Linear {
return layer.weight.dim(0) != 8 && layer.weight.dim(0) != vocabularySize
}
return false
}
predicate = vocabularySizePredicate(layer:)
}
QuantizedLinear.quantize(
model: model, groupSize: quantization.groupSize, bits: quantization.bits,
predicate: predicate)
}

View File

@@ -377,7 +377,7 @@ public enum LoRATrain {
/// - training with ``train(model:train:validate:optimizer:loss:tokenizer:parameters:progress:)``
/// - loss evaluation with ``evaluate(model:dataset:loss:tokenizer:batchSize:batchCount:)``
/// - fusing with ``fuse(model:layers:deQuantize:)``
/// - text generation with ``generate(promptTokens:parameters:model:tokenizer:didGenerate:)``
/// - text generation with ``generate(promptTokens:parameters:model:tokenizer:additionalEOSTokens:didGenerate:)``
/// - note that this is just using normal model text generation
///
/// - Parameters:

View File

@@ -33,6 +33,12 @@ public struct ModelConfiguration {
/// overrides for TokenizerModel/knownTokenizers -- useful before swift-transformers is updated
public let overrideTokenizer: String?
/// A reasonable default prompt for the model
public let defaultPrompt: String
/// Additional tokens to use for end of string
public let extraEOSTokens: Set<String>
/// custom preparation logic for the prompt. custom tokenizers provide more capability, but this
/// allows some minor formtting changes, e.g. wrapping the user input in the expected prompt
/// format
@@ -40,21 +46,29 @@ public struct ModelConfiguration {
public init(
id: String, tokenizerId: String? = nil, overrideTokenizer: String? = nil,
defaultPrompt: String = "hello",
extraEOSTokens: Set<String> = [],
preparePrompt: ((String) -> String)? = nil
) {
self.id = .id(id)
self.tokenizerId = tokenizerId
self.overrideTokenizer = overrideTokenizer
self.defaultPrompt = defaultPrompt
self.extraEOSTokens = extraEOSTokens
self.preparePrompt = preparePrompt
}
public init(
directory: URL, tokenizerId: String? = nil, overrideTokenizer: String? = nil,
defaultPrompt: String = "hello",
extraEOSTokens: Set<String> = [],
preparePrompt: ((String) -> String)? = nil
) {
self.id = .directory(directory)
self.tokenizerId = tokenizerId
self.overrideTokenizer = overrideTokenizer
self.defaultPrompt = defaultPrompt
self.extraEOSTokens = extraEOSTokens
self.preparePrompt = preparePrompt
}
@@ -98,11 +112,16 @@ public struct ModelConfiguration {
extension ModelConfiguration {
public static let mistral7B4bit = ModelConfiguration(
id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx")
id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx",
// https://www.promptingguide.ai/models/mistral-7b
defaultPrompt: "describe the swift language"
)
public static let codeLlama13b4bit = ModelConfiguration(
id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX",
overrideTokenizer: "PreTrainedTokenizer"
overrideTokenizer: "PreTrainedTokenizer",
defaultPrompt: "func sortArray(_ array: [Int]) -> String { <FILL_ME> }"
) { prompt in
// given the prompt: func sortArray(_ array: [Int]) -> String { <FILL_ME> }
// the python code produces this (via its custom tokenizer):
@@ -111,13 +130,17 @@ extension ModelConfiguration {
"<PRE> " + prompt.replacingOccurrences(of: "<FILL_ME>", with: "<SUF>") + " <MID>"
}
public static let phi4bit = ModelConfiguration(id: "mlx-community/phi-2-hf-4bit-mlx") {
prompt in
"Instruct: \(prompt)\nOutput: "
}
public static let phi4bit = ModelConfiguration(
id: "mlx-community/phi-2-hf-4bit-mlx",
// https://www.promptingguide.ai/models/phi-2
defaultPrompt: "Why is the sky blue?"
)
public static let phi34bit = ModelConfiguration(
id: "mlx-community/Phi-3-mini-4k-instruct-4bit-no-q-embed"
id: "mlx-community/Phi-3-mini-4k-instruct-4bit-no-q-embed",
defaultPrompt: "what is the gravity on mars and the moon?",
extraEOSTokens: ["<|end|>"]
) {
prompt in
"<s><|user|>\n\(prompt)<|end|>\n<|assistant|>\n"
@@ -125,26 +148,35 @@ extension ModelConfiguration {
public static let gemma2bQuantized = ModelConfiguration(
id: "mlx-community/quantized-gemma-2b-it",
overrideTokenizer: "PreTrainedTokenizer"
overrideTokenizer: "PreTrainedTokenizer",
// https://www.promptingguide.ai/models/gemma
defaultPrompt: "what is the difference between lettuce and cabbage?"
) { prompt in
"<start_of_turn>user \(prompt)<end_of_turn><start_of_turn>model"
}
public static let qwen205b4bit = ModelConfiguration(
id: "mlx-community/Qwen1.5-0.5B-Chat-4bit",
overrideTokenizer: "PreTrainedTokenizer"
overrideTokenizer: "PreTrainedTokenizer",
defaultPrompt: "why is the sky blue?"
) { prompt in
"<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n\(prompt)<|im_end|>\n<|im_start|>assistant"
}
public static let openelm270m4bit = ModelConfiguration(
id: "mlx-community/OpenELM-270M-Instruct"
id: "mlx-community/OpenELM-270M-Instruct",
// https://huggingface.co/apple/OpenELM
defaultPrompt: "Once upon a time there was"
) { prompt in
"\(prompt)"
}
public static let llama38B4bit = ModelConfiguration(
id: "mlx-community/Meta-Llama-3-8B-Instruct-4bit"
id: "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
defaultPrompt: "what is the difference between a fruit and a vegetable?"
) {
prompt in
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\(prompt)<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>"

View File

@@ -179,10 +179,12 @@ public class Qwen2ModelInner: Module {
public class Qwen2Model: Module, LLMModel {
public let vocabularySize: Int
let model: Qwen2ModelInner
let configuration: Qwen2Configuration
@ModuleInfo(key: "lm_head") var lmHead: Linear
public init(_ args: Qwen2Configuration) {
self.configuration = args
self.vocabularySize = args.vocabularySize
self.model = Qwen2ModelInner(args)
_lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
@@ -191,8 +193,26 @@ public class Qwen2Model: Module, LLMModel {
public func callAsFunction(_ inputs: MLXArray, cache: [(MLXArray, MLXArray)]?) -> (
MLXArray, [(MLXArray, MLXArray)]
) {
let (out, cache) = model(inputs, cache: cache)
return (lmHead(out), cache)
var (out, cache) = model(inputs, cache: cache)
if configuration.tieWordEmbeddings {
out = model.embedTokens.asLinear(out)
} else {
out = lmHead(out)
}
return (out, cache)
}
public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
var weights = weights
if configuration.tieWordEmbeddings {
weights["lm_head.weight"] = nil
}
// Remove unused precomputed rotary freqs
return weights.filter {
!$0.key.contains("self_attn.rotary_emb.inv_freq")
}
}
}
@@ -207,6 +227,7 @@ public struct Qwen2Configuration: Codable {
var ropeTheta: Float = 1_000_000
var ropeTraditional: Bool = false
var ropeScaling: [String: StringOrNumber]? = nil
var tieWordEmbeddings = false
enum CodingKeys: String, CodingKey {
case hiddenSize = "hidden_size"
@@ -219,6 +240,7 @@ public struct Qwen2Configuration: Codable {
case ropeTheta = "rope_theta"
case ropeTraditional = "rope_traditional"
case ropeScaling = "rope_scaling"
case tieWordEmbeddings = "tie_word_embeddings"
}
public init(from decoder: Decoder) throws {
@@ -249,6 +271,8 @@ public struct Qwen2Configuration: Codable {
Bool.self, forKey: Qwen2Configuration.CodingKeys.ropeTraditional) ?? false
self.ropeScaling = try container.decodeIfPresent(
[String: StringOrNumber].self, forKey: Qwen2Configuration.CodingKeys.ropeScaling)
self.tieWordEmbeddings =
try container.decodeIfPresent(Bool.self, forKey: .tieWordEmbeddings) ?? false
}
}

View File

@@ -44,7 +44,7 @@ struct GenerateArguments: ParsableArguments {
help:
"The message to be processed by the model. Use @path,@path to load from files, e.g. @/tmp/prompt.txt"
)
var prompt = "compare python and swift"
var prompt: String?
@Option(name: .shortAndLong, help: "Maximum number of tokens to generate")
var maxTokens = 100
@@ -73,7 +73,8 @@ struct GenerateArguments: ParsableArguments {
repetitionContextSize: repetitionContextSize)
}
func resolvePrompt() throws -> String {
func resolvePrompt(configuration: ModelConfiguration) throws -> String {
let prompt = self.prompt ?? configuration.defaultPrompt
if prompt.hasPrefix("@") {
let names = prompt.split(separator: ",").map { String($0.dropFirst()) }
return try names.map { try String(contentsOfFile: $0) }.joined(separator: "\n")
@@ -87,14 +88,17 @@ struct GenerateArguments: ParsableArguments {
) {
MLXRandom.seed(seed)
let prompt = try resolvePrompt()
let prompt = try resolvePrompt(configuration: configuration)
let preparedPrompt = configuration.prepare(prompt: prompt)
let promptTokens = tokenizer.encode(text: preparedPrompt)
return (prompt, promptTokens)
}
func generate(promptTokens: [Int], model: LLMModel, tokenizer: Tokenizer) async
func generate(
promptTokens: [Int], model: LLMModel, tokenizer: Tokenizer,
extraEOSTokens: Set<String>? = nil
) async
-> GenerateResult
{
// track how much we have printed
@@ -102,7 +106,7 @@ struct GenerateArguments: ParsableArguments {
return await LLM.generate(
promptTokens: promptTokens, parameters: generateParameters,
model: model, tokenizer: tokenizer
model: model, tokenizer: tokenizer, extraEOSTokens: extraEOSTokens
) { tokens in
// print any new parts of the string
@@ -226,7 +230,8 @@ struct EvaluateCommand: AsyncParsableCommand {
}
let result = await generate.generate(
promptTokens: promptTokens, model: model, tokenizer: tokenizer)
promptTokens: promptTokens, model: model, tokenizer: tokenizer,
extraEOSTokens: modelConfiguration.extraEOSTokens)
print()
if !generate.quiet {

View File

@@ -275,7 +275,8 @@ struct LoRAEvalCommand: AsyncParsableCommand {
// generate and print the result
let _ = await generate.generate(
promptTokens: promptTokens, model: model, tokenizer: tokenizer)
promptTokens: promptTokens, model: model, tokenizer: tokenizer,
extraEOSTokens: modelConfiguration.extraEOSTokens)
print()
}
}

View File

@@ -16,7 +16,7 @@
"location" : "https://github.com/ml-explore/mlx-swift",
"state" : {
"branch" : "main",
"revision" : "3c802c808d281c191d5f26f37a4f93135d8ca119"
"revision" : "d6d9472da5bf7ec2654e8914bd1d15622f45b6a9"
}
},
{
@@ -61,7 +61,7 @@
"location" : "https://github.com/gonzalezreal/swift-markdown-ui",
"state" : {
"branch" : "main",
"revision" : "723249a1ba361042812cf785244de94f11f7c8fd"
"revision" : "c0daf6eb79d97964180f3113868c990bd1c4a007"
}
},
{