// Copyright © 2024 Apple Inc.

import LLM
import MLX
import MLXRandom
import MarkdownUI
import Metal
import SwiftUI
import Tokenizers

struct ContentView: View {

    @State var prompt = "compare python and swift"
    @State var llm = LLMEvaluator()
    @Environment(DeviceStat.self) private var deviceStat

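    /// rendering options for the generated text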
    enum displayStyle: String, CaseIterable, Identifiable {
        case plain, markdown
        var id: Self { self }
    }

    @State private var selectedDisplayStyle = displayStyle.markdown

    var body: some View {
        VStack(alignment: .leading) {
            VStack {
                HStack {
                    Text(llm.modelInfo)
                        .textFieldStyle(.roundedBorder)

                    Spacer()

                    Text(llm.stat)
                }
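                // busy indicator and the plain/markdown display style picker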
                HStack {
                    Spacer()
                    if llm.running {
                        ProgressView()
                            .frame(maxHeight: 20)
                        Spacer()
                    }
                    Picker("", selection: $selectedDisplayStyle) {
                        ForEach(displayStyle.allCases, id: \.self) { option in
                            Text(option.rawValue.capitalized)
                                .tag(option)
                        }
                    }
                    .pickerStyle(.segmented)
                    #if os(visionOS)
                        .frame(maxWidth: 250)
                    #else
                        .frame(maxWidth: 150)
                    #endif
                }
            }

            // show the model output
            ScrollView(.vertical) {
                ScrollViewReader { sp in
                    Group {
                        if selectedDisplayStyle == .plain {
                            Text(llm.output)
                                .textSelection(.enabled)
                        } else {
                            Markdown(llm.output)
                                .textSelection(.enabled)
                        }
                    }
                    .onChange(of: llm.output) { _, _ in
                        sp.scrollTo("bottom")
                    }

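                    // invisible anchor that onChange scrolls to so the latest output stays visible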
                    Spacer()
                        .frame(width: 1, height: 1)
                        .id("bottom")
                }
            }

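            // prompt entry -- pressing return or tapping generate starts a generation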
            HStack {
                TextField("prompt", text: $prompt)
                    .onSubmit(generate)
                    .disabled(llm.running)
                    #if os(visionOS)
                        .textFieldStyle(.roundedBorder)
                    #endif
                Button("generate", action: generate)
                    .disabled(llm.running)
            }
        }
        #if os(visionOS)
            .padding(40)
        #else
            .padding()
        #endif
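        // toolbar: current GPU memory usage plus a copy-output action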
        .toolbar {
            ToolbarItem {
                Label(
                    "Memory Usage: \(deviceStat.gpuUsage.activeMemory.formatted(.byteCount(style: .memory)))",
                    systemImage: "info.circle.fill"
                )
                .labelStyle(.titleAndIcon)
                .padding(.horizontal)
                .help(
                    Text(
                        """
                        Active Memory: \(deviceStat.gpuUsage.activeMemory.formatted(.byteCount(style: .memory)))/\(GPU.memoryLimit.formatted(.byteCount(style: .memory)))
                        Cache Memory: \(deviceStat.gpuUsage.cacheMemory.formatted(.byteCount(style: .memory)))/\(GPU.cacheLimit.formatted(.byteCount(style: .memory)))
                        Peak Memory: \(deviceStat.gpuUsage.peakMemory.formatted(.byteCount(style: .memory)))
                        """
                    )
                )
            }
            ToolbarItem(placement: .primaryAction) {
                Button {
                    Task {
                        copyToClipboard(llm.output)
                    }
                } label: {
                    Label("Copy Output", systemImage: "doc.on.doc.fill")
                }
                .disabled(llm.output == "")
                .labelStyle(.titleAndIcon)
            }
        }
        .task {
            // pre-load the weights on launch to speed up the first generation
            _ = try? await llm.load()
        }
    }

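    /// start a generation for the current prompt without blocking the UI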
    private func generate() {
        Task {
            await llm.generate(prompt: prompt)
        }
    }

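    /// copy text to the system pasteboard -- NSPasteboard on macOS, UIPasteboard elsewhere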
    private func copyToClipboard(_ string: String) {
        #if os(macOS)
            NSPasteboard.general.clearContents()
            NSPasteboard.general.setString(string, forType: .string)
        #else
            UIPasteboard.general.string = string
        #endif
    }
}

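/// Loads the model on demand and streams generated text into `output` for the view to display.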
@Observable
class LLMEvaluator {

    @MainActor
    var running = false

    var output = ""
    var modelInfo = ""
    var stat = ""

    /// this controls which model loads -- phi4bit is one of the smaller ones so this will fit on
    /// more devices
    let modelConfiguration = ModelConfiguration.phi4bit

    /// parameters controlling the output
    let generateParameters = GenerateParameters(temperature: 0.6)
    let maxTokens = 240

    /// update the display every N tokens -- 4 looks like it updates continuously
    /// and is low overhead. observed ~15% reduction in tokens/s when updating
    /// on every token
    let displayEveryNTokens = 4

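    /// whether the model and tokenizer have been loaded yet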
    enum LoadState {
        case idle
        case loaded(LLMModel, Tokenizers.Tokenizer)
    }

    var loadState = LoadState.idle

    /// load and return the model -- can be called multiple times, subsequent calls will
    /// just return the loaded model
    func load() async throws -> (LLMModel, Tokenizers.Tokenizer) {
        switch loadState {
        case .idle:
            // limit the buffer cache
            MLX.GPU.set(cacheLimit: 20 * 1024 * 1024)

            let (model, tokenizer) = try await LLM.load(configuration: modelConfiguration) {
                [modelConfiguration] progress in
                DispatchQueue.main.sync {
                    self.modelInfo =
                        "Downloading \(modelConfiguration.name): \(Int(progress.fractionCompleted * 100))%"
                }
            }
            self.modelInfo =
                "Loaded \(modelConfiguration.id). Weights: \(MLX.GPU.activeMemory / 1024 / 1024)M"
            loadState = .loaded(model, tokenizer)
            return (model, tokenizer)

        case .loaded(let model, let tokenizer):
            return (model, tokenizer)
        }
    }

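    /// run one generation: load the model if needed, tokenize the prompt, and stream
    /// decoded text into `output` as tokens are produced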
    func generate(prompt: String) async {
        let canGenerate = await MainActor.run {
            if running {
                return false
            } else {
                running = true
                self.output = ""
                return true
            }
        }

        guard canGenerate else { return }

        do {
            let (model, tokenizer) = try await load()
            // augment the prompt as needed
            let prompt = modelConfiguration.prepare(prompt: prompt)
            let promptTokens = tokenizer.encode(text: prompt)

            // each time you generate you will get something new
            MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))

            let result = await LLM.generate(
                promptTokens: promptTokens, parameters: generateParameters, model: model,
                tokenizer: tokenizer
            ) { tokens in
                // update the output -- this will make the view show the text as it generates
                if tokens.count % displayEveryNTokens == 0 {
                    let text = tokenizer.decode(tokens: tokens)
                    await MainActor.run {
                        self.output = text
                    }
                }

                if tokens.count >= maxTokens {
                    return .stop
                } else {
                    return .more
                }
            }

            // update the text if needed, e.g. we haven't displayed because of displayEveryNTokens
            await MainActor.run {
                if result.output != self.output {
                    self.output = result.output
                }
                running = false
                self.stat = " Tokens/second: \(String(format: "%.3f", result.tokensPerSecond))"
            }

        } catch {
            await MainActor.run {
                running = false
                output = "Failed: \(error)"
            }
        }
    }
}