implement LoRA / QLoRA (#46)
* implement LoRA / QLoRA - example of using MLX to fine-tune an LLM with low rank adaptation (LoRA) for a target task
  - see also https://arxiv.org/abs/2106.09685
  - based on https://github.com/ml-explore/mlx-examples/tree/main/lora
* add some command line flags I found useful during use
  - --quiet -- don't print decorator text, just the generated text
  - --prompt @/tmp/file.txt -- load prompt from file
* user can specify path to model OR model identifier in huggingface
* update mlx-swift reference

Co-authored-by: Ashraful Islam <ashraful.meche@gmail.com>
Co-authored-by: JustinMeans <46542161+JustinMeans@users.noreply.github.com>
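A minimal sketch of the `--prompt @/tmp/file.txt` convention described above, assuming a hypothetical `resolvePrompt` helper (the actual flag handling lives in Tools/llm-tool and is not part of this diff):

```swift
import Foundation

// Hypothetical helper: an argument of the form "@<path>" is read from the
// file at <path>; any other argument is used as the prompt text itself.
func resolvePrompt(_ argument: String) throws -> String {
    guard argument.hasPrefix("@") else { return argument }
    let path = String(argument.dropFirst())
    return try String(contentsOfFile: path, encoding: .utf8)
}

// resolvePrompt("@/tmp/file.txt") loads the prompt from /tmp/file.txt;
// resolvePrompt("hello world") returns "hello world" unchanged.
```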
@@ -187,7 +187,7 @@ class LLMEvaluator {
     [modelConfiguration] progress in
     DispatchQueue.main.sync {
         self.modelInfo =
-            "Downloading \(modelConfiguration.id): \(Int(progress.fractionCompleted * 100))%"
+            "Downloading \(modelConfiguration.name): \(Int(progress.fractionCompleted * 100))%"
     }
 }
 self.modelInfo =
@@ -0,0 +1,11 @@
{
  "colors" : [
    {
      "idiom" : "universal"
    }
  ],
  "info" : {
    "author" : "xcode",
    "version" : 1
  }
}
@@ -0,0 +1,63 @@
{
  "images" : [
    {
      "idiom" : "universal",
      "platform" : "ios",
      "size" : "1024x1024"
    },
    {
      "idiom" : "mac",
      "scale" : "1x",
      "size" : "16x16"
    },
    {
      "idiom" : "mac",
      "scale" : "2x",
      "size" : "16x16"
    },
    {
      "idiom" : "mac",
      "scale" : "1x",
      "size" : "32x32"
    },
    {
      "idiom" : "mac",
      "scale" : "2x",
      "size" : "32x32"
    },
    {
      "idiom" : "mac",
      "scale" : "1x",
      "size" : "128x128"
    },
    {
      "idiom" : "mac",
      "scale" : "2x",
      "size" : "128x128"
    },
    {
      "idiom" : "mac",
      "scale" : "1x",
      "size" : "256x256"
    },
    {
      "idiom" : "mac",
      "scale" : "2x",
      "size" : "256x256"
    },
    {
      "idiom" : "mac",
      "scale" : "1x",
      "size" : "512x512"
    },
    {
      "idiom" : "mac",
      "scale" : "2x",
      "size" : "512x512"
    }
  ],
  "info" : {
    "author" : "xcode",
    "version" : 1
  }
}
@@ -0,0 +1,6 @@
{
  "info" : {
    "author" : "xcode",
    "version" : 1
  }
}
284
Applications/LoRATrainingExample/ContentView.swift
Normal file
@@ -0,0 +1,284 @@
// Copyright © 2024 Apple Inc.

import LLM
import MLX
import MLXOptimizers
import MLXRandom
import SwiftUI
import Tokenizers

struct ContentView: View {

    @State var evaluator = LoRAEvaluator()

    @State var prompt = """
        table: 1-10015132-16
        columns: Player, No., Nationality, Position, Years in Toronto, School/Club Team
        Q: What is terrence ross' nationality
        A:
        """

    var body: some View {
        VStack {
            HStack {
                if let progress = evaluator.progress {
                    if let current = progress.current, let limit = progress.limit {
                        ProgressView(progress.title, value: current, total: limit)
                    } else {
                        ProgressView(progress.title)
                    }
                }
            }
            .frame(maxWidth: .infinity, minHeight: 25)

            VStack {
                ScrollView(.vertical) {
                    ScrollViewReader { sp in
                        Group {
                            Text(evaluator.output)
                                .textSelection(.enabled)
                                .frame(maxWidth: .infinity)
                        }
                        .onChange(of: evaluator.output) { _, _ in
                            sp.scrollTo("bottom")
                        }
                        .padding()

                        Spacer()
                            .frame(width: 1, height: 1)
                            .id("bottom")
                    }
                }

                // controls for each of the different states
                VStack {
                    switch evaluator.state {
                    case .idle:
                        Button("Start", action: start)

                    case .training:
                        EmptyView()

                    case .evaluate:
                        Group {
                            TextEditor(text: $prompt)
                                .frame(minHeight: 60)
                            Button("Evaluate", action: evaluate)
                        }
                        .disabled(evaluator.progress != nil)

                    case .failed(let message):
                        Text("Failed: \(message)")
                            .bold()
                            .foregroundStyle(.red)
                    }
                }
                .padding()
                .frame(maxWidth: .infinity)
            }
        }
        .padding()
    }

    func start() {
        Task {
            await evaluator.start()
        }
    }

    func evaluate() {
        Task {
            await evaluator.evaluate(prompt: prompt)
        }
    }
}

/// Progress reporting with a title.
struct Progress: Equatable {
    let title: String
    let current: Double?
    let limit: Double?
}

@Observable
class LoRAEvaluator {

    enum State {
        case idle
        case training
        case evaluate
        case failed(String)
    }

    enum ModelState {
        case idle
        case loaded(LLMModel, Tokenizer)
    }

    var state = State.idle
    var progress: Progress?

    var output = ""

    private let modelConfiguration = ModelConfiguration.mistral7B4bit
    private var model: ModelState = .idle

    private let loraLayers = 4
    private let learningRate: Float = 1e-5
    private let parameters = LoRATrain.Parameters(batchSize: 1, iterations: 200)

    private let generateParameters = GenerateParameters(temperature: 0.6, topP: 0.9)
    private let evaluateShowEvery = 8
    private let maxTokens = 200

    private func loadModel() async throws -> (LLMModel, Tokenizer) {
        switch self.model {
        case .idle:
            let name = modelConfiguration.name
            await MainActor.run {
                progress = .init(title: "Loading \(name)", current: 0, limit: 1)
            }

            let (model, tokenizer) = try await LLM.load(configuration: modelConfiguration) {
                progress in
                if progress.fractionCompleted < 1.0 {
                    DispatchQueue.main.sync {
                        self.progress = .init(
                            title: "Download \(name)", current: progress.fractionCompleted,
                            limit: 1.0)
                    }
                }
            }
            eval(model)
            self.model = .loaded(model, tokenizer)
            return (model, tokenizer)

        case .loaded(let model, let tokenizer):
            return (model, tokenizer)
        }
    }

    private func loadLoRAData(name: String) throws -> [String]? {
        if let url = Bundle.main.url(forResource: name, withExtension: "jsonl") {
            return try LLM.loadLoRAData(url: url)
        }
        return nil
    }

    func start() async {
        do {
            try await startInner()
        } catch {
            self.state = .failed("Failed: \(error)")
        }
    }

    private func startInner() async throws {
        // setup
        GPU.set(cacheLimit: 32 * 1024 * 1024)
        await MainActor.run {
            output = ""
            state = .training
        }

        // load the model
        let (model, tokenizer) = try await loadModel()

        // apply LoRA adapters and train
        guard let layerProvider = model as? LoRAModel else {
            state = .failed("Model must implement the LoRAModel protocol")
            return
        }
        LoRATrain.convert(
            model: model, layers: Array(layerProvider.loraLinearLayers().suffix(loraLayers)))

        let train = try loadLoRAData(name: "train")
        let valid = try loadLoRAData(name: "valid")
        guard let train, let valid else {
            state = .failed("Failed to load train/validation data")
            return
        }

        let optimizer = Adam(learningRate: learningRate)
        try await LoRATrain.train(
            model: model, train: train, validate: valid, optimizer: optimizer, tokenizer: tokenizer,
            parameters: parameters
        ) { progress in
            await MainActor.run {
                switch progress {
                case .train(let i, _, _, _):
                    self.progress = .init(
                        title: "Train", current: Double(i), limit: Double(parameters.iterations))
                case .validation:
                    output += "\n"
                default:
                    break
                }

                output += progress.description + "\n"
            }

            return .more
        }

        // done training, test
        await MainActor.run {
            self.progress = .init(title: "Testing", current: nil, limit: nil)
        }
        guard let test = try loadLoRAData(name: "test") else {
            state = .failed("Failed to load test data")
            return
        }

        let loss = LoRATrain.evaluate(
            model: model, dataset: test, tokenizer: tokenizer, batchSize: 1, batchCount: 0)
        await MainActor.run {
            self.progress = nil
            self.output += "\n"
            self.output += "Test loss \(loss.formatted()), ppl \(exp(loss).formatted())\n"
            self.state = .evaluate
        }
    }

    func evaluate(prompt: String) async {
        do {
            try await evaluateInner(prompt: prompt)
        } catch {
            self.state = .failed("Failed: \(error)")
        }
    }

    func evaluateInner(prompt: String) async throws {
        await MainActor.run {
            self.progress = .init(title: "Evaluating", current: nil, limit: nil)
            self.output = ""
        }

        MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))

        let (model, tokenizer) = try await loadModel()

        // prepare the prompt
        let preparedPrompt = modelConfiguration.prepare(prompt: prompt)
        let promptTokens = tokenizer.encode(text: preparedPrompt)

        // evaluate
        let result = await LLM.generate(
            promptTokens: promptTokens, parameters: generateParameters, model: model,
            tokenizer: tokenizer,
            didGenerate: { tokens in
                if tokens.count % evaluateShowEvery == 0 {
                    let fullOutput = tokenizer.decode(tokens: tokens)
                    await MainActor.run {
                        self.output = fullOutput
                    }
                }
                return tokens.count >= maxTokens ? .stop : .more
            })

        await MainActor.run {
            self.output = result.output
            self.progress = nil
        }
    }
}
@@ -0,0 +1,14 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
    <key>com.apple.developer.kernel.increased-memory-limit</key>
    <true/>
    <key>com.apple.security.app-sandbox</key>
    <true/>
    <key>com.apple.security.files.user-selected.read-only</key>
    <true/>
    <key>com.apple.security.network.client</key>
    <true/>
</dict>
</plist>
@@ -0,0 +1,12 @@
// Copyright © 2024 Apple Inc.

import SwiftUI

@main
struct LoRATrainingExampleApp: App {
    var body: some Scene {
        WindowGroup {
            ContentView()
        }
    }
}
@@ -0,0 +1,6 @@
{
  "info" : {
    "author" : "xcode",
    "version" : 1
  }
}
21
Applications/LoRATrainingExample/README.md
Normal file
@@ -0,0 +1,21 @@
# LoRATrainingExample

Example application that:

- downloads the `mlx-community/Mistral-7B-v0.1-hf-4bit-mlx` model from Hugging Face
- loads the train/valid/test data from `$SRCROOT/Data/lora` (this is copied into the build but you can imagine how it might be downloaded)
- adds LoRA adapters and trains the model
- lets you evaluate a prompt against the model

This roughly equates to the command line example in [Tools/llm-tool](../../Tools/llm-tool) and
you can read more about LoRA there.

This evaluates the LoRA-adapted model rather than a fused model. It doesn't persist
the LoRA weights or the fused model -- the adapters are retrained each time the program is launched.

### Troubleshooting

The `mlx-community/Mistral-7B-v0.1-hf-4bit-mlx` model requires a little over 4G of
memory to load and train -- this may require ~6G of physical RAM.
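For reference, a condensed sketch of the train-then-test flow that ContentView.swift above implements, using only the LoRATrain / LLM calls that appear in this diff. `trainData`, `validData`, and `testData` stand in for the arrays returned by `loadLoRAData`; error handling and UI progress reporting are omitted:

```swift
// Condensed from ContentView.swift above; a sketch, not a drop-in replacement.
let (model, tokenizer) = try await LLM.load(
    configuration: ModelConfiguration.mistral7B4bit) { _ in }

// Wrap the last 4 linear layers of the model in LoRA adapters.
if let layerProvider = model as? LoRAModel {
    LoRATrain.convert(
        model: model, layers: Array(layerProvider.loraLinearLayers().suffix(4)))
}

// Train against the JSONL data, printing progress as it goes.
try await LoRATrain.train(
    model: model, train: trainData, validate: validData,
    optimizer: Adam(learningRate: 1e-5), tokenizer: tokenizer,
    parameters: LoRATrain.Parameters(batchSize: 1, iterations: 200)
) { progress in
    print(progress.description)
    return .more
}

// Report loss / perplexity on the held-out test set.
let loss = LoRATrain.evaluate(
    model: model, dataset: testData, tokenizer: tokenizer, batchSize: 1, batchCount: 0)
print("Test loss \(loss.formatted()), ppl \(exp(loss).formatted())")
```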
100
Data/lora/test.jsonl
Normal file
@@ -0,0 +1,100 @@
{"text": "table: 1-10015132-16\ncolumns: Player, No., Nationality, Position, Years in Toronto, School/Club Team\nQ: What is terrence ross' nationality\nA: SELECT Nationality FROM 1-10015132-16 WHERE Player = 'Terrence Ross'"}
{"text": "table: 1-10015132-16\ncolumns: Player, No., Nationality, Position, Years in Toronto, School/Club Team\nQ: What clu was in toronto 1995-96\nA: SELECT School/Club Team FROM 1-10015132-16 WHERE Years in Toronto = '1995-96'"}
{"text": "table: 1-10015132-16\ncolumns: Player, No., Nationality, Position, Years in Toronto, School/Club Team\nQ: which club was in toronto 2003-06\nA: SELECT School/Club Team FROM 1-10015132-16 WHERE Years in Toronto = '2003-06'"}
{"text": "table: 1-10015132-16\ncolumns: Player, No., Nationality, Position, Years in Toronto, School/Club Team\nQ: how many schools or teams had jalen rose\nA: SELECT COUNT School/Club Team FROM 1-10015132-16 WHERE Player = 'Jalen Rose'"}
{"text": "table: 1-10083598-1\ncolumns: No, Date, Round, Circuit, Pole Position, Fastest Lap, Race winner, Report\nQ: Where was Assen held?\nA: SELECT Round FROM 1-10083598-1 WHERE Circuit = 'Assen'"}
{"text": "table: 1-10083598-1\ncolumns: No, Date, Round, Circuit, Pole Position, Fastest Lap, Race winner, Report\nQ: What was the number of race that Kevin Curtain won?\nA: SELECT COUNT No FROM 1-10083598-1 WHERE Pole Position = 'Kevin Curtain'"}
{"text": "table: 1-10083598-1\ncolumns: No, Date, Round, Circuit, Pole Position, Fastest Lap, Race winner, Report\nQ: What was the date of the race in Misano?\nA: SELECT Date FROM 1-10083598-1 WHERE Circuit = 'Misano'"}
{"text": "table: 1-1013129-2\ncolumns: Pick, Player, Position, Nationality, NHL team, College/junior/club team\nQ: How many different positions did Sherbrooke Faucons (qmjhl) provide in the draft?\nA: SELECT COUNT Position FROM 1-1013129-2 WHERE College/junior/club team = 'Sherbrooke Faucons (QMJHL)'"}
{"text": "table: 1-1013129-2\ncolumns: Pick, Player, Position, Nationality, NHL team, College/junior/club team\nQ: What are the nationalities of the player picked from Thunder Bay Flyers (ushl)\nA: SELECT Nationality FROM 1-1013129-2 WHERE College/junior/club team = 'Thunder Bay Flyers (USHL)'"}
{"text": "table: 1-1013129-2\ncolumns: Pick, Player, Position, Nationality, NHL team, College/junior/club team\nQ: How many different college/junior/club teams provided a player to the Washington Capitals NHL Team?\nA: SELECT COUNT College/junior/club team FROM 1-1013129-2 WHERE NHL team = 'Washington Capitals'"}
{"text": "table: 1-1013129-3\ncolumns: Pick, Player, Position, Nationality, NHL team, College/junior/club team\nQ: How many different nationalities do the players of New Jersey Devils come from?\nA: SELECT COUNT Nationality FROM 1-1013129-3 WHERE NHL team = 'New Jersey Devils'"}
{"text": "table: 1-1013129-3\ncolumns: Pick, Player, Position, Nationality, NHL team, College/junior/club team\nQ: What's Dorain Anneck's pick number?\nA: SELECT Pick FROM 1-1013129-3 WHERE Player = 'Dorain Anneck'"}
{"text": "table: 1-1013129-3\ncolumns: Pick, Player, Position, Nationality, NHL team, College/junior/club team\nQ: What is the nationality of the player from Vancouver Canucks?\nA: SELECT Nationality FROM 1-1013129-3 WHERE NHL team = 'Vancouver Canucks'"}
{"text": "table: 1-1013129-3\ncolumns: Pick, Player, Position, Nationality, NHL team, College/junior/club team\nQ: What's the pick number of the player from Springfield Olympics (Nejhl)?\nA: SELECT Pick FROM 1-1013129-3 WHERE College/junior/club team = 'Springfield Olympics (NEJHL)'"}
{"text": "table: 1-1014206-2\ncolumns: #, Shipyard, Laid down, Launched, Commissioned, Fleet, Status\nQ: When were the ships launched that were laid down on september 1, 1964?\nA: SELECT Launched FROM 1-1014206-2 WHERE Laid down = 'September 1, 1964'"}
{"text": "table: 1-1014206-2\ncolumns: #, Shipyard, Laid down, Launched, Commissioned, Fleet, Status\nQ: List the # for ships commissioned on december 18, 1965.\nA: SELECT # FROM 1-1014206-2 WHERE Commissioned = 'December 18, 1965'"}
{"text": "table: 1-1014206-2\ncolumns: #, Shipyard, Laid down, Launched, Commissioned, Fleet, Status\nQ: List the # for ships commissioned on september 30, 1967.\nA: SELECT # FROM 1-1014206-2 WHERE Commissioned = 'September 30, 1967'"}
{"text": "table: 1-1014206-2\ncolumns: #, Shipyard, Laid down, Launched, Commissioned, Fleet, Status\nQ: When were ships laid down that were commissioned on october 29, 1965?\nA: SELECT Laid down FROM 1-1014206-2 WHERE Commissioned = 'October 29, 1965'"}
{"text": "table: 1-1015521-2\ncolumns: Equivalent NATO Rank Code, Rank in Spanish, Rank in English, Commonwealth equivalent, US Air Force equivalent\nQ: What could a spanish coronel be addressed as in the commonwealth military?\nA: SELECT Commonwealth equivalent FROM 1-1015521-2 WHERE Rank in Spanish = 'Coronel'"}
{"text": "table: 1-1015521-2\ncolumns: Equivalent NATO Rank Code, Rank in Spanish, Rank in English, Commonwealth equivalent, US Air Force equivalent\nQ: Give me a list of all spanish officer titles that could receive recognition as group captain in english\nA: SELECT Rank in English FROM 1-1015521-2 WHERE Commonwealth equivalent = 'Group Captain'"}
{"text": "table: 1-1015521-2\ncolumns: Equivalent NATO Rank Code, Rank in Spanish, Rank in English, Commonwealth equivalent, US Air Force equivalent\nQ: If you are a pilot officer in the commonwealth then what will you called as in the US air force?\nA: SELECT US Air Force equivalent FROM 1-1015521-2 WHERE Commonwealth equivalent = 'Pilot Officer'"}
{"text": "table: 1-1015521-2\ncolumns: Equivalent NATO Rank Code, Rank in Spanish, Rank in English, Commonwealth equivalent, US Air Force equivalent\nQ: If you're a major general in the US air force then what ranking will you receive in the commonwealth's air force?\nA: SELECT Commonwealth equivalent FROM 1-1015521-2 WHERE US Air Force equivalent = 'Major General'"}
{"text": "table: 1-1015521-2\ncolumns: Equivalent NATO Rank Code, Rank in Spanish, Rank in English, Commonwealth equivalent, US Air Force equivalent\nQ: If you get a ranking as major in the english military then what would the spanish military address you as? \nA: SELECT Rank in Spanish FROM 1-1015521-2 WHERE Rank in English = 'Major'"}
{"text": "table: 1-10182508-5\ncolumns: Rank Each wrestlers total number of days as champion are ranked highest to lowest; wrestlers with the same number mean that they are tied for that certain rank., Wrestler, # of reigns, Combined defenses, Combined days\nQ: Which wrestlers have had 2 reigns?\nA: SELECT Wrestler FROM 1-10182508-5 WHERE # of reigns = 2"}
{"text": "table: 1-10182508-5\ncolumns: Rank Each wrestlers total number of days as champion are ranked highest to lowest; wrestlers with the same number mean that they are tied for that certain rank., Wrestler, # of reigns, Combined defenses, Combined days\nQ: In terms of reigns, what is the lowest number listed?\nA: SELECT MIN # of reigns FROM 1-10182508-5"}
{"text": "table: 1-10182508-5\ncolumns: Rank Each wrestlers total number of days as champion are ranked highest to lowest; wrestlers with the same number mean that they are tied for that certain rank., Wrestler, # of reigns, Combined defenses, Combined days\nQ: What rank was Bryan Danielson in this chart?\nA: SELECT Rank Each wrestlers total number of days as champion are ranked highest to lowest; wrestlers with the same number mean that they are tied for that certain rank. FROM 1-10182508-5 WHERE Wrestler = 'Bryan Danielson'"}
{"text": "table: 1-10182508-5\ncolumns: Rank Each wrestlers total number of days as champion are ranked highest to lowest; wrestlers with the same number mean that they are tied for that certain rank., Wrestler, # of reigns, Combined defenses, Combined days\nQ: How many combined days did Go Shiozaki have?\nA: SELECT Combined days FROM 1-10182508-5 WHERE Wrestler = 'Go Shiozaki'"}
{"text": "table: 1-10182508-5\ncolumns: Rank Each wrestlers total number of days as champion are ranked highest to lowest; wrestlers with the same number mean that they are tied for that certain rank., Wrestler, # of reigns, Combined defenses, Combined days\nQ: What was Go Shiozaki's rank?\nA: SELECT MIN Rank Each wrestlers total number of days as champion are ranked highest to lowest; wrestlers with the same number mean that they are tied for that certain rank. FROM 1-10182508-5 WHERE Wrestler = 'Go Shiozaki'"}
{"text": "table: 1-1024710-2\ncolumns: Member, Electorate, Province, MPs term, Election date\nQ: Which province is grey and bell electorate in\nA: SELECT Province FROM 1-1024710-2 WHERE Electorate = 'Grey and Bell'"}
{"text": "table: 1-1024710-2\ncolumns: Member, Electorate, Province, MPs term, Election date\nQ: Which province is bay of islands in\nA: SELECT Province FROM 1-1024710-2 WHERE Electorate = 'Bay of Islands'"}
{"text": "table: 1-10294071-1\ncolumns: Player, Total W\u2013L, Singles W\u2013L, Doubles W\u2013L, Ties played, Debut, Years played\nQ: what is the total number of\u00a0total w\u2013l\u00a0where\u00a0doubles w\u2013l\u00a0is 11\u201311\nA: SELECT COUNT Total W\u2013L FROM 1-10294071-1 WHERE Doubles W\u2013L = '11\u201311'"}
{"text": "table: 1-10294071-1\ncolumns: Player, Total W\u2013L, Singles W\u2013L, Doubles W\u2013L, Ties played, Debut, Years played\nQ: what is the total number of\u00a0singles w\u2013l\u00a0where\u00a0doubles w\u2013l\u00a0is 11\u201314\nA: SELECT COUNT Singles W\u2013L FROM 1-10294071-1 WHERE Doubles W\u2013L = '11\u201314'"}
{"text": "table: 1-10294071-1\ncolumns: Player, Total W\u2013L, Singles W\u2013L, Doubles W\u2013L, Ties played, Debut, Years played\nQ: what's the\u00a0total w\u2013l\u00a0where\u00a0player\u00a0is boro jovanovi\u0107 category:articles with hcards\nA: SELECT Total W\u2013L FROM 1-10294071-1 WHERE Player = 'Boro Jovanovi\u0107 Category:Articles with hCards'"}
{"text": "table: 1-10294071-1\ncolumns: Player, Total W\u2013L, Singles W\u2013L, Doubles W\u2013L, Ties played, Debut, Years played\nQ: what is the maximum\u00a0ties played\u00a0where\u00a0player\u00a0is josip palada category:articles with hcards\nA: SELECT MAX Ties played FROM 1-10294071-1 WHERE Player = 'Josip Palada Category:Articles with hCards'"}
{"text": "table: 1-10294071-1\ncolumns: Player, Total W\u2013L, Singles W\u2013L, Doubles W\u2013L, Ties played, Debut, Years played\nQ: what is the total number of\u00a0ties played\u00a0where\u00a0total w\u2013l\u00a0is 38\u201324\nA: SELECT COUNT Ties played FROM 1-10294071-1 WHERE Total W\u2013L = '38\u201324'"}
{"text": "table: 1-10333757-1\ncolumns: Calls, Frequency, Branding, Format, Market/Rank, Timeslot, Group owner\nQ: What is the Frequency at the Market/Rank of Burlington - Plattsburgh , Vermont - New York /143?\nA: SELECT COUNT Frequency FROM 1-10333757-1 WHERE Market/Rank = 'Burlington - Plattsburgh , Vermont - New York /143'"}
{"text": "table: 1-10333757-1\ncolumns: Calls, Frequency, Branding, Format, Market/Rank, Timeslot, Group owner\nQ: What is the Branding for Group Owner Qantam of Cape Cod, LLC?\nA: SELECT Branding FROM 1-10333757-1 WHERE Group owner = 'Qantam of Cape Cod, LLC'"}
{"text": "table: 1-10333757-1\ncolumns: Calls, Frequency, Branding, Format, Market/Rank, Timeslot, Group owner\nQ: What Branding does WRKO calls use?\nA: SELECT Branding FROM 1-10333757-1 WHERE Calls = 'WRKO'"}
{"text": "table: 1-10333757-1\ncolumns: Calls, Frequency, Branding, Format, Market/Rank, Timeslot, Group owner\nQ: What is the Format for Branding of 1290 wkbk w281au 104.1?\nA: SELECT Format FROM 1-10333757-1 WHERE Branding = '1290 WKBK W281AU 104.1'"}
{"text": "table: 1-10333757-1\ncolumns: Calls, Frequency, Branding, Format, Market/Rank, Timeslot, Group owner\nQ: Which Market/Rank is associated with WCRN calls?\nA: SELECT Market/Rank FROM 1-10333757-1 WHERE Calls = 'WCRN'"}
{"text": "table: 1-10333757-1\ncolumns: Calls, Frequency, Branding, Format, Market/Rank, Timeslot, Group owner\nQ: Which Frequency is used for WEGP calls?\nA: SELECT Frequency FROM 1-10333757-1 WHERE Calls = 'WEGP'"}
{"text": "table: 1-10408617-5\ncolumns: Scheme, Tariff code, BTs retail price (regulated), Approx premium, Prefixes\nQ: What is the regulated retail price for the tariff code ff0 prs?\nA: SELECT BTs retail price (regulated) FROM 1-10408617-5 WHERE Tariff code = 'ff0 PRS'"}
{"text": "table: 1-10408617-5\ncolumns: Scheme, Tariff code, BTs retail price (regulated), Approx premium, Prefixes\nQ: What is the premium associated with tariff code g9?\nA: SELECT Approx premium FROM 1-10408617-5 WHERE Tariff code = 'g9'"}
{"text": "table: 1-10408617-5\ncolumns: Scheme, Tariff code, BTs retail price (regulated), Approx premium, Prefixes\nQ: How many tariff codes have a bts retail price of 2p/min or inclusive?\nA: SELECT COUNT Tariff code FROM 1-10408617-5 WHERE BTs retail price (regulated) = '2p/min or inclusive'"}
{"text": "table: 1-10408617-5\ncolumns: Scheme, Tariff code, BTs retail price (regulated), Approx premium, Prefixes\nQ: How many tariff codes have a bts retail price of 2.553p/min?\nA: SELECT COUNT Tariff code FROM 1-10408617-5 WHERE BTs retail price (regulated) = '2.553p/min'"}
{"text": "table: 1-10408617-5\ncolumns: Scheme, Tariff code, BTs retail price (regulated), Approx premium, Prefixes\nQ: What prefixes are priced at pence per minute, fixed at all times with a premium of 3p/min?\nA: SELECT Prefixes FROM 1-10408617-5 WHERE Scheme = 'Pence per minute, fixed at all times' AND Approx premium = '3p/min'"}
{"text": "table: 1-10408617-5\ncolumns: Scheme, Tariff code, BTs retail price (regulated), Approx premium, Prefixes\nQ: What is the bts retail price (regulated) for tariff code g10?\nA: SELECT BTs retail price (regulated) FROM 1-10408617-5 WHERE Tariff code = 'g10'"}
{"text": "table: 1-10409754-5\ncolumns: Nominative, Old orthography, New orthography, /e/ or /\u00e6/ (IPA), Tone (Latvian notation: /~/ - level, /^/ - broken), Translation\nQ: What is the tone for gen.sing. plague?\nA: SELECT Tone (Latvian notation: /~/ - level, /^/ - broken) FROM 1-10409754-5 WHERE Translation = 'Gen.Sing. plague'"}
{"text": "table: 1-10432351-1\ncolumns: Star (Pismis24-#), Spectral type, Magnitude (M bol ), Temperature (K), Radius (R \u2609 ), Mass (M \u2609 )\nQ: What is the smallest possible radius?\nA: SELECT MIN Radius (R \u2609 ) FROM 1-10432351-1"}
{"text": "table: 1-10432351-1\ncolumns: Star (Pismis24-#), Spectral type, Magnitude (M bol ), Temperature (K), Radius (R \u2609 ), Mass (M \u2609 )\nQ: What are all the spectral types for star mismis24-# is 1sw?\nA: SELECT Spectral type FROM 1-10432351-1 WHERE Star (Pismis24-#) = '1SW'"}
{"text": "table: 1-10432351-1\ncolumns: Star (Pismis24-#), Spectral type, Magnitude (M bol ), Temperature (K), Radius (R \u2609 ), Mass (M \u2609 )\nQ: If a radius is 10, what is the lowest possible mass?\nA: SELECT MIN Mass (M \u2609 ) FROM 1-10432351-1 WHERE Radius (R \u2609 ) = 10"}
{"text": "table: 1-105344-2\ncolumns: Year, Aircraft kilometers, Departures, Flying hours, Passengers, Seat factor, Employees, Profit/loss\nQ: What percentage of seats were filled in 2006?\nA: SELECT Seat factor FROM 1-105344-2 WHERE Year = 2006"}
{"text": "table: 1-105344-2\ncolumns: Year, Aircraft kilometers, Departures, Flying hours, Passengers, Seat factor, Employees, Profit/loss\nQ: How many hours were flown in each of the years where more than 64379058.0 kilometers were flown?\nA: SELECT Flying hours FROM 1-105344-2 WHERE Aircraft kilometers > 64379058.0"}
{"text": "table: 1-105344-2\ncolumns: Year, Aircraft kilometers, Departures, Flying hours, Passengers, Seat factor, Employees, Profit/loss\nQ: Of the years that had exactly 17096 departures, what is the greatest number of aircraft kilometers flown?\nA: SELECT MAX Aircraft kilometers FROM 1-105344-2 WHERE Departures = 17096"}
{"text": "table: 1-10548224-1\ncolumns: Year, Game or event, Date contested, League or governing body, Sport, Winning team, Losing team, Final score\nQ: Which winning team beat the New York Yankees?\nA: SELECT Winning team FROM 1-10548224-1 WHERE Losing team = 'New York Yankees'"}
{"text": "table: 1-10548224-1\ncolumns: Year, Game or event, Date contested, League or governing body, Sport, Winning team, Losing team, Final score\nQ: What was the final score for the game that was contested on February 1, 2009?\nA: SELECT Final score FROM 1-10548224-1 WHERE Date contested = 'February 1, 2009'"}
{"text": "table: 1-10548224-1\ncolumns: Year, Game or event, Date contested, League or governing body, Sport, Winning team, Losing team, Final score\nQ: What sport had a final score of 3-2?\nA: SELECT Sport FROM 1-10548224-1 WHERE Final score = '3-2'"}
{"text": "table: 1-10548224-1\ncolumns: Year, Game or event, Date contested, League or governing body, Sport, Winning team, Losing team, Final score\nQ: Who was the winning team of the game that was contested on February 1, 2009?\nA: SELECT Winning team FROM 1-10548224-1 WHERE Date contested = 'February 1, 2009'"}
{"text": "table: 1-10548224-1\ncolumns: Year, Game or event, Date contested, League or governing body, Sport, Winning team, Losing team, Final score\nQ: Who was the losing team of the game that was contested on February 1, 2004?\nA: SELECT Losing team FROM 1-10548224-1 WHERE Date contested = 'February 1, 2004'"}
{"text": "table: 1-1057262-2\ncolumns: Crop (kilotonnes), New South Wales, Victoria, Queensland, Western Australia, South Australia, Tasmania, Total\nQ: what's the minimum\u00a0total\u00a0with\u00a0crop (kilotonnes)\u00a0being s lupin\nA: SELECT MIN Total FROM 1-1057262-2 WHERE Crop (kilotonnes) = 's Lupin'"}
{"text": "table: 1-1057262-2\ncolumns: Crop (kilotonnes), New South Wales, Victoria, Queensland, Western Australia, South Australia, Tasmania, Total\nQ: what's the\u00a0new south wales\u00a0with\u00a0crop (kilotonnes)\u00a0being canola\nA: SELECT New South Wales FROM 1-1057262-2 WHERE Crop (kilotonnes) = 'Canola'"}
{"text": "table: 1-1057262-2\ncolumns: Crop (kilotonnes), New South Wales, Victoria, Queensland, Western Australia, South Australia, Tasmania, Total\nQ: what's the total number of\u00a0south australia\u00a0with\u00a0victoria\u00a0value of 2173\nA: SELECT COUNT South Australia FROM 1-1057262-2 WHERE Victoria = 2173"}
{"text": "table: 1-1057262-2\ncolumns: Crop (kilotonnes), New South Wales, Victoria, Queensland, Western Australia, South Australia, Tasmania, Total\nQ: what's the minimum\u00a0tasmania value\nA: SELECT MIN Tasmania FROM 1-1057262-2"}
{"text": "table: 1-1057262-2\ncolumns: Crop (kilotonnes), New South Wales, Victoria, Queensland, Western Australia, South Australia, Tasmania, Total\nQ: what's the total number of\u00a0tasmania\u00a0with\u00a0new south wales\u00a0crop of 190 kilotonnes\nA: SELECT COUNT Tasmania FROM 1-1057262-2 WHERE New South Wales = 190"}
{"text": "table: 1-1058787-1\ncolumns: Approximate Age, Virtues, Psycho Social Crisis, Significant Relationship, Existential Question [ not in citation given ], Examples\nQ: How many significant relationships list Will as a virtue?\nA: SELECT COUNT Significant Relationship FROM 1-1058787-1 WHERE Virtues = 'Will'"}
{"text": "table: 1-1058787-1\ncolumns: Approximate Age, Virtues, Psycho Social Crisis, Significant Relationship, Existential Question [ not in citation given ], Examples\nQ: Which examples ask the existential question \"Can I Love?\"\nA: SELECT Examples FROM 1-1058787-1 WHERE Existential Question [ not in citation given ] = 'Can I Love?'"}
{"text": "table: 1-1059743-2\ncolumns: Rank, Member Association, Points, Group stage, Play-off, AFC Cup\nQ: How many countries got 796.7 points?\nA: SELECT COUNT Rank FROM 1-1059743-2 WHERE Points = '796.7'"}
{"text": "table: 1-1059743-2\ncolumns: Rank, Member Association, Points, Group stage, Play-off, AFC Cup\nQ: In what group stage were 177.2 points awarded?\nA: SELECT COUNT Group stage FROM 1-1059743-2 WHERE Points = '177.2'"}
{"text": "table: 1-1059743-2\ncolumns: Rank, Member Association, Points, Group stage, Play-off, AFC Cup\nQ: What is the lowest group to earn 886.6 points?\nA: SELECT MIN Group stage FROM 1-1059743-2 WHERE Points = '886.6'"}
{"text": "table: 1-1059743-2\ncolumns: Rank, Member Association, Points, Group stage, Play-off, AFC Cup\nQ: How many countries earned 177.2 points?\nA: SELECT COUNT Member Association FROM 1-1059743-2 WHERE Points = '177.2'"}
{"text": "table: 1-10586064-2\ncolumns: County, Precincts, Lunsford, % Lunsford, McConnell, % McConnell, Total\nQ: If % lunsford is 51.82% what is the % mcconnell in Letcher?\nA: SELECT % McConnell FROM 1-10586064-2 WHERE % Lunsford = '51.82%'"}
{"text": "table: 1-10586064-2\ncolumns: County, Precincts, Lunsford, % Lunsford, McConnell, % McConnell, Total\nQ: What country had the total 18,900 (r)?\nA: SELECT County FROM 1-10586064-2 WHERE Total = '18,900 (R)'"}
{"text": "table: 1-10586064-2\ncolumns: County, Precincts, Lunsford, % Lunsford, McConnell, % McConnell, Total\nQ: When % mcconnell is 44.54% what are the total number of counties?\nA: SELECT COUNT County FROM 1-10586064-2 WHERE % McConnell = '44.54%'"}
{"text": "table: 1-10586064-2\ncolumns: County, Precincts, Lunsford, % Lunsford, McConnell, % McConnell, Total\nQ: If % mcconnell is 47.17% what is the total number of mcconnell ?\nA: SELECT COUNT McConnell FROM 1-10586064-2 WHERE % McConnell = '47.17%'"}
{"text": "table: 1-10586064-2\ncolumns: County, Precincts, Lunsford, % Lunsford, McConnell, % McConnell, Total\nQ: What is the county of precints 515?\nA: SELECT County FROM 1-10586064-2 WHERE Precincts = 515"}
{"text": "table: 1-10601843-2\ncolumns: Stadium, Capacity, City, Country, Tenant, Opening\nQ: Which city has a capacity of 41903?\nA: SELECT City FROM 1-10601843-2 WHERE Capacity = 41903"}
{"text": "table: 1-10601843-2\ncolumns: Stadium, Capacity, City, Country, Tenant, Opening\nQ: What is the maximum capacity of the Otkrytie Arena stadium?\nA: SELECT MAX Capacity FROM 1-10601843-2 WHERE Stadium = 'Otkrytie Arena'"}
{"text": "table: 1-10601843-2\ncolumns: Stadium, Capacity, City, Country, Tenant, Opening\nQ: When did the stadium where Bursaspor is the tenant open?\nA: SELECT MIN Opening FROM 1-10601843-2 WHERE Tenant = 'Bursaspor'"}
{"text": "table: 1-10601843-2\ncolumns: Stadium, Capacity, City, Country, Tenant, Opening\nQ: How many tenants are there in the city of Samsun?\nA: SELECT COUNT Tenant FROM 1-10601843-2 WHERE City = 'Samsun'"}
{"text": "table: 1-10610087-5\ncolumns: No. in series, No. in season, Title, Directed by, Written by, Original air date\nQ: what's the\u00a0original air date\u00a0with\u00a0title\u00a0 \"hell\"\nA: SELECT Original air date FROM 1-10610087-5 WHERE Title = '\"Hell\"'"}
{"text": "table: 1-10638523-1\ncolumns: Particulars and Characteristics, Shivalik Zone, Mid-Hill Zone, High hill zone, Trance- n Himalaya Zone\nQ: What is the percentage of the Shivalik Zone where the percentage of the Mid-Hill Zone is 10%?\nA: SELECT Shivalik Zone FROM 1-10638523-1 WHERE Mid-Hill Zone = '10%'"}
{"text": "table: 1-10638523-1\ncolumns: Particulars and Characteristics, Shivalik Zone, Mid-Hill Zone, High hill zone, Trance- n Himalaya Zone\nQ: For mid-hill zone what is the altitude?\nA: SELECT Mid-Hill Zone FROM 1-10638523-1 WHERE Particulars and Characteristics = 'Altitude'"}
{"text": "table: 1-10638523-1\ncolumns: Particulars and Characteristics, Shivalik Zone, Mid-Hill Zone, High hill zone, Trance- n Himalaya Zone\nQ: What are the climatic conditions for the trance- n himalaya zone?\nA: SELECT Trance- n Himalaya Zone FROM 1-10638523-1 WHERE Particulars and Characteristics = 'Climatic conditions'"}
{"text": "table: 1-10638523-1\ncolumns: Particulars and Characteristics, Shivalik Zone, Mid-Hill Zone, High hill zone, Trance- n Himalaya Zone\nQ: What is the percentage of the trance- n himalaya zone that corresponds with the high hill zone is 25%?\nA: SELECT Trance- n Himalaya Zone FROM 1-10638523-1 WHERE High hill zone = '25%'"}
{"text": "table: 1-10644188-3\ncolumns: Total tenure rank, Uninterrupted rank, Name, State represented, Dates of service, Total tenure time, Uninterrupted time\nQ: What is the state of Ted Stevens?\nA: SELECT State represented FROM 1-10644188-3 WHERE Name = 'Ted Stevens'"}
{"text": "table: 1-10682862-68\ncolumns: Country, Players, Standard, Minor, First title, Last title\nQ: What's the standard of the country who won its first title in 1992?\nA: SELECT MAX Standard FROM 1-10682862-68 WHERE First title = 1992"}
{"text": "table: 1-10682862-68\ncolumns: Country, Players, Standard, Minor, First title, Last title\nQ: What's the smallest number of players?\nA: SELECT MIN Players FROM 1-10682862-68"}
{"text": "table: 1-10682862-68\ncolumns: Country, Players, Standard, Minor, First title, Last title\nQ: In what year was the last last title received, by any of the countries?\nA: SELECT MAX Last title FROM 1-10682862-68"}
{"text": "table: 1-10710364-1\ncolumns: Religious group, Population % 1961, Population % 1971, Population % 1981, Population % 1991, Population % 2001\nQ: What religious groups made up 0.72% of the Indian population in 2001?\nA: SELECT Religious group FROM 1-10710364-1 WHERE Population % 2001 = '0.72%'"}
{"text": "table: 1-10718868-2\ncolumns: No. in series, No. in season, Title, Directed by, Written by, Original air date, U.S. viewers (millions)\nQ: What is the original air date for episode 15 of season 6?\nA: SELECT Original air date FROM 1-10718868-2 WHERE No. in season = 15"}
{"text": "table: 1-10718868-2\ncolumns: No. in series, No. in season, Title, Directed by, Written by, Original air date, U.S. viewers (millions)\nQ: How many episodes in season 6 titles \"Poppin' Tags\"?\nA: SELECT COUNT No. in season FROM 1-10718868-2 WHERE Title = '\"Poppin' Tags\"'"}
{"text": "table: 1-10753917-1\ncolumns: Season, Driver, Team, Engine, Poles, Wins, Podiums, Points, Margin of defeat\nQ: Which podiums did the Williams team have with a margin of defeat of 2?\nA: SELECT Podiums FROM 1-10753917-1 WHERE Team = 'Williams' AND Margin of defeat = '2'"}
{"text": "table: 1-10753917-1\ncolumns: Season, Driver, Team, Engine, Poles, Wins, Podiums, Points, Margin of defeat\nQ: How many drivers on the williams team had a margin of defeat of 2?\nA: SELECT COUNT Driver FROM 1-10753917-1 WHERE Team = 'Williams' AND Margin of defeat = '2'"}
{"text": "table: 1-10753917-1\ncolumns: Season, Driver, Team, Engine, Poles, Wins, Podiums, Points, Margin of defeat\nQ: How many seasons was clay regazzoni the driver?\nA: SELECT COUNT Season FROM 1-10753917-1 WHERE Driver = 'Clay Regazzoni'"}
{"text": "table: 1-10753917-1\ncolumns: Season, Driver, Team, Engine, Poles, Wins, Podiums, Points, Margin of defeat\nQ: Which margin of defeats had points of 30?\nA: SELECT Margin of defeat FROM 1-10753917-1 WHERE Points = '30'"}
{"text": "table: 1-10753917-1\ncolumns: Season, Driver, Team, Engine, Poles, Wins, Podiums, Points, Margin of defeat\nQ: Which podiums did the alfa romeo team have?\nA: SELECT Podiums FROM 1-10753917-1 WHERE Team = 'Alfa Romeo'"}
{"text": "table: 1-10797636-1\ncolumns: Village (German), Village (Slovene), Number of people 1991, Percent of Slovenes 1991, Percent of Slovenes 1951\nQ: What was the percent of slovenes 1951 for bach?\nA: SELECT Percent of Slovenes 1951 FROM 1-10797636-1 WHERE Village (German) = 'Bach'"}
{"text": "table: 1-10812403-4\ncolumns: Pick #, CFL Team, Player, Position, College\nQ: What college's team is the Saskatchewan Roughriders?\nA: SELECT College FROM 1-10812403-4 WHERE CFL Team = 'Saskatchewan Roughriders'"}
{"text": "table: 1-10812403-4\ncolumns: Pick #, CFL Team, Player, Position, College\nQ: What position did Calvin Mccarty play?\nA: SELECT Position FROM 1-10812403-4 WHERE Player = 'Calvin McCarty'"}
{"text": "table: 1-10812403-4\ncolumns: Pick #, CFL Team, Player, Position, College\nQ: How many people were pick #30?\nA: SELECT COUNT Position FROM 1-10812403-4 WHERE Pick # = 30"}
1000
Data/lora/train.jsonl
Normal file
File diff suppressed because it is too large
100
Data/lora/valid.jsonl
Normal file
@@ -0,0 +1,100 @@
{"text": "table: 1-10015132-11\ncolumns: Player, No., Nationality, Position, Years in Toronto, School/Club Team\nQ: What position does the player who played for butler cc (ks) play?\nA: SELECT Position FROM 1-10015132-11 WHERE School/Club Team = 'Butler CC (KS)'"}
{"text": "table: 1-10015132-11\ncolumns: Player, No., Nationality, Position, Years in Toronto, School/Club Team\nQ: How many schools did player number 3 play at?\nA: SELECT COUNT School/Club Team FROM 1-10015132-11 WHERE No. = '3'"}
{"text": "table: 1-10015132-11\ncolumns: Player, No., Nationality, Position, Years in Toronto, School/Club Team\nQ: What school did player number 21 play for?\nA: SELECT School/Club Team FROM 1-10015132-11 WHERE No. = '21'"}
{"text": "table: 1-10015132-11\ncolumns: Player, No., Nationality, Position, Years in Toronto, School/Club Team\nQ: Who is the player that wears number 42?\nA: SELECT Player FROM 1-10015132-11 WHERE No. = '42'"}
{"text": "table: 1-10015132-11\ncolumns: Player, No., Nationality, Position, Years in Toronto, School/Club Team\nQ: What player played guard for toronto in 1996-97?\nA: SELECT Player FROM 1-10015132-11 WHERE Position = 'Guard' AND Years in Toronto = '1996-97'"}
{"text": "table: 1-10015132-9\ncolumns: Player, No., Nationality, Position, Years in Toronto, School/Club Team\nQ: Who are all of the players on the Westchester High School club team?\nA: SELECT Player FROM 1-10015132-9 WHERE School/Club Team = 'Westchester High School'"}
{"text": "table: 1-10015132-9\ncolumns: Player, No., Nationality, Position, Years in Toronto, School/Club Team\nQ: What school/club team is Amir Johnson on?\nA: SELECT School/Club Team FROM 1-10015132-9 WHERE Player = 'Amir Johnson'"}
{"text": "table: 1-10015132-9\ncolumns: Player, No., Nationality, Position, Years in Toronto, School/Club Team\nQ: What are the total amount of numbers on the Toronto team in 2005-06?\nA: SELECT COUNT No. FROM 1-10015132-9 WHERE Years in Toronto = '2005-06'"}
{"text": "table: 1-10015132-9\ncolumns: Player, No., Nationality, Position, Years in Toronto, School/Club Team\nQ: What are the total number of positions on the Toronto team in 2006-07?\nA: SELECT COUNT Position FROM 1-10015132-9 WHERE Years in Toronto = '2006-07'"}
{"text": "table: 1-10015132-9\ncolumns: Player, No., Nationality, Position, Years in Toronto, School/Club Team\nQ: What are the nationality of the players on the Fresno State school/club team?\nA: SELECT Nationality FROM 1-10015132-9 WHERE School/Club Team = 'Fresno State'"}
{"text": "table: 1-10015132-9\ncolumns: Player, No., Nationality, Position, Years in Toronto, School/Club Team\nQ: What school/club team is Trey Johnson on?\nA: SELECT School/Club Team FROM 1-10015132-9 WHERE Player = 'Trey Johnson'"}
{"text": "table: 1-10026563-1\ncolumns: Entered office as Head of State or Government, Began time as senior G8 leader, Ended time as senior G8 leader, Person, Office\nQ: When did Jacques Chirac stop being a G8 leader?\nA: SELECT Ended time as senior G8 leader FROM 1-10026563-1 WHERE Person = 'Jacques Chirac'"}
{"text": "table: 1-10026563-1\ncolumns: Entered office as Head of State or Government, Began time as senior G8 leader, Ended time as senior G8 leader, Person, Office\nQ: When did the Prime Minister of Italy take office?\nA: SELECT Entered office as Head of State or Government FROM 1-10026563-1 WHERE Office = 'Prime Minister of Italy'"}
{"text": "table: 1-1008653-1\ncolumns: Country ( exonym ), Capital ( exonym ), Country ( endonym ), Capital ( endonym ), Official or native language(s) (alphabet/script)\nQ: What is the English name of the country whose official native language is Dutch Papiamento?\nA: SELECT Country ( exonym ) FROM 1-1008653-1 WHERE Official or native language(s) (alphabet/script) = 'Dutch Papiamento'"}
{"text": "table: 1-1008653-1\ncolumns: Country ( exonym ), Capital ( exonym ), Country ( endonym ), Capital ( endonym ), Official or native language(s) (alphabet/script)\nQ: What official or native languages are spoken in the country whose capital city is Canberra?\nA: SELECT Official or native language(s) (alphabet/script) FROM 1-1008653-1 WHERE Capital ( exonym ) = 'Canberra'"}
{"text": "table: 1-1008653-1\ncolumns: Country ( exonym ), Capital ( exonym ), Country ( endonym ), Capital ( endonym ), Official or native language(s) (alphabet/script)\nQ: What is the local name given to the city of Canberra?\nA: SELECT Capital ( endonym ) FROM 1-1008653-1 WHERE Capital ( exonym ) = 'Canberra'"}
{"text": "table: 1-1008653-1\ncolumns: Country ( exonym ), Capital ( exonym ), Country ( endonym ), Capital ( endonym ), Official or native language(s) (alphabet/script)\nQ: What is the local name given to the capital of Anguilla?\nA: SELECT Capital ( endonym ) FROM 1-1008653-1 WHERE Country ( endonym ) = 'Anguilla'"}
{"text": "table: 1-1008653-1\ncolumns: Country ( exonym ), Capital ( exonym ), Country ( endonym ), Capital ( endonym ), Official or native language(s) (alphabet/script)\nQ: What is the English name given to the city of St. John's?\nA: SELECT Capital ( exonym ) FROM 1-1008653-1 WHERE Capital ( endonym ) = 'St. John's'"}
{"text": "table: 1-1008653-1\ncolumns: Country ( exonym ), Capital ( exonym ), Country ( endonym ), Capital ( endonym ), Official or native language(s) (alphabet/script)\nQ: How many capital cities does Australia have?\nA: SELECT COUNT Capital ( endonym ) FROM 1-1008653-1 WHERE Country ( endonym ) = 'Australia'"}
{"text": "table: 1-10088101-1\ncolumns: No. in set, No. in series, Title, Directed by, Written by, Original air date, Production code\nQ: The episode with production code 9abx02 was originally aired on what date?\nA: SELECT Original air date FROM 1-10088101-1 WHERE Production code = '9ABX02'"}
{"text": "table: 1-10088101-1\ncolumns: No. in set, No. in series, Title, Directed by, Written by, Original air date, Production code\nQ: What is the episode number that has production code 8abx15?\nA: SELECT MIN No. in series FROM 1-10088101-1 WHERE Production code = '8ABX15'"}
{"text": "table: 1-10295819-2\ncolumns: Player, Highest singles ranking, Highest doubles ranking, First year played, Years played, Ties played, Total W\u2013L, Singles W\u2013L, Doubles W\u2013L\nQ: Name the minimum tiesplayed for 6 years\nA: SELECT MIN Ties played FROM 1-10295819-2 WHERE Years played = 6"}
{"text": "table: 1-10342194-3\ncolumns: District, Total amount of trees, Prevailing types, %, Amount of old trees, Amount of trees, that require replacement\nQ: What is the amount of trees, that require replacement when prevailing types, % is pine \u2014 29.37 poplar \u2014 26.12 acer negundo \u2014 13.2?\nA: SELECT Amount of trees, that require replacement FROM 1-10342194-3 WHERE Prevailing types, % = 'Pine \u2014 29.37 Poplar \u2014 26.12 Acer negundo \u2014 13.2'"}
{"text": "table: 1-10342194-3\ncolumns: District, Total amount of trees, Prevailing types, %, Amount of old trees, Amount of trees, that require replacement\nQ: What is the amount of trees, that require replacement when district is leninsky?\nA: SELECT Amount of trees, that require replacement FROM 1-10342194-3 WHERE District = 'Leninsky'"}
{"text": "table: 1-10342194-3\ncolumns: District, Total amount of trees, Prevailing types, %, Amount of old trees, Amount of trees, that require replacement\nQ: What is the district when the total amount of trees is smaller than 150817.6878461314 and amount of old trees is 1,928 (1.89%)?\nA: SELECT District FROM 1-10342194-3 WHERE Total amount of trees < 150817.6878461314 AND Amount of old trees = '1,928 (1.89%)'"}
{"text": "table: 1-10342194-3\ncolumns: District, Total amount of trees, Prevailing types, %, Amount of old trees, Amount of trees, that require replacement\nQ: What is the amount of trees, that require replacement when the district is motovilikhinsky?\nA: SELECT Amount of trees, that require replacement FROM 1-10342194-3 WHERE District = 'Motovilikhinsky'"}
{"text": "table: 1-10342194-3\ncolumns: District, Total amount of trees, Prevailing types, %, Amount of old trees, Amount of trees, that require replacement\nQ: What is the total amount of trees when district is leninsky?\nA: SELECT MAX Total amount of trees FROM 1-10342194-3 WHERE District = 'Leninsky'"}
{"text": "table: 1-10342194-3\ncolumns: District, Total amount of trees, Prevailing types, %, Amount of old trees, Amount of trees, that require replacement\nQ: What is the district when prevailing types, % is acer negundo \u2014 30.22 tilia \u2014 18.6 poplar \u2014 15.23?\nA: SELECT District FROM 1-10342194-3 WHERE Prevailing types, % = 'Acer negundo \u2014 30.22 Tilia \u2014 18.6 Poplar \u2014 15.23'"}
{"text": "table: 1-10429820-13\ncolumns: Iowa State vs., Overall Record, in Ames, at Opponents Venue, at Neutral Site, Last 5 Meetings, Last 10 Meetings, Current Streak, Since Beginning of Big 12\nQ: When the value of \"since beginning of big 12\" is synonymous with its' category, what are the in Ames values?\nA: SELECT in Ames FROM 1-10429820-13 WHERE Since Beginning of Big 12 = 'Since Beginning of Big 12'"}
{"text": "table: 1-1046170-5\ncolumns: Year, Division, League, Regular Season, Playoffs, U.S. Open Cup\nQ: what's the\u00a0u.s. open cup status\u00a0for regular season\u00a0of 4th, atlantic division \nA: SELECT U.S. Open Cup FROM 1-1046170-5 WHERE Regular Season = '4th, Atlantic Division'"}
{"text": "table: 1-1046170-5\ncolumns: Year, Division, League, Regular Season, Playoffs, U.S. Open Cup\nQ: how many division did not qualify for u.s. open cup in 2003\nA: SELECT Division FROM 1-1046170-5 WHERE U.S. Open Cup = 'Did Not Qualify' AND Year = 2003"}
{"text": "table: 1-1046170-5\ncolumns: Year, Division, League, Regular Season, Playoffs, U.S. Open Cup\nQ: which round is u.s. open cup division semifinals\nA: SELECT U.S. Open Cup FROM 1-1046170-5 WHERE Playoffs = 'Division Semifinals'"}
{"text": "table: 1-1046170-5\ncolumns: Year, Division, League, Regular Season, Playoffs, U.S. Open Cup\nQ: what are all the playoffs for regular season is 1st, atlantic division\nA: SELECT Playoffs FROM 1-1046170-5 WHERE Regular Season = '1st, Atlantic Division'"}
{"text": "table: 1-1046170-5\ncolumns: Year, Division, League, Regular Season, Playoffs, U.S. Open Cup\nQ: what are all the playoffs for u.s. open cup in 1st round\nA: SELECT Playoffs FROM 1-1046170-5 WHERE U.S. Open Cup = '1st Round'"}
{"text": "table: 1-1061075-1\ncolumns: Season, Competition, Round, Opponents, 1st leg, 2nd leg, Aggregate\nQ: what is the total number of\u00a02nd leg\u00a0where\u00a0aggregate\u00a0is 7-2\nA: SELECT COUNT 2nd leg FROM 1-1061075-1 WHERE Aggregate = '7-2'"}
{"text": "table: 1-1061075-1\ncolumns: Season, Competition, Round, Opponents, 1st leg, 2nd leg, Aggregate\nQ: what's the\u00a0aggregate\u00a0where\u00a01st leg\u00a0is 3\u20132\nA: SELECT Aggregate FROM 1-1061075-1 WHERE 1st leg = '3\u20132'"}
{"text": "table: 1-1061075-1\ncolumns: Season, Competition, Round, Opponents, 1st leg, 2nd leg, Aggregate\nQ: what's the\u00a0competition\u00a0where\u00a0aggregate\u00a0is 4\u20137\nA: SELECT Competition FROM 1-1061075-1 WHERE Aggregate = '4\u20137'"}
{"text": "table: 1-1061075-1\ncolumns: Season, Competition, Round, Opponents, 1st leg, 2nd leg, Aggregate\nQ: what's the\u00a0competition\u00a0where\u00a01st leg\u00a0is 4-1 (h)\nA: SELECT Competition FROM 1-1061075-1 WHERE 1st leg = '4-1 (h)'"}
{"text": "table: 1-1061075-1\ncolumns: Season, Competition, Round, Opponents, 1st leg, 2nd leg, Aggregate\nQ: what is the total number of\u00a0round\u00a0where\u00a0opponents\u00a0is haugar\nA: SELECT COUNT Round FROM 1-1061075-1 WHERE Opponents = 'Haugar'"}
{"text": "table: 1-1061075-1\ncolumns: Season, Competition, Round, Opponents, 1st leg, 2nd leg, Aggregate\nQ: what's the\u00a01st leg\u00a0where\u00a0opponents\u00a0is galatasaray\nA: SELECT 1st leg FROM 1-1061075-1 WHERE Opponents = 'Galatasaray'"}
{"text": "table: 1-10706961-2\ncolumns: Rd, Name, Pole Position, Fastest Lap, Winning driver, Winning team, Report\nQ: What is the highest Rd that Tom Sneva had the pole position in?\nA: SELECT MAX Rd FROM 1-10706961-2 WHERE Pole Position = 'Tom Sneva'"}
{"text": "table: 1-10706961-2\ncolumns: Rd, Name, Pole Position, Fastest Lap, Winning driver, Winning team, Report\nQ: How many winning drivers were there in the race that had a fastest lap time of 56.920?\nA: SELECT COUNT Winning driver FROM 1-10706961-2 WHERE Fastest Lap = '56.920'"}
{"text": "table: 1-10706961-2\ncolumns: Rd, Name, Pole Position, Fastest Lap, Winning driver, Winning team, Report\nQ: How many reports are there in the race that Forsythe Racing won and Teo Fabi had the pole position in?\nA: SELECT COUNT Report FROM 1-10706961-2 WHERE Winning team = 'Forsythe Racing' AND Pole Position = 'Teo Fabi'"}
{"text": "table: 1-10706961-2\ncolumns: Rd, Name, Pole Position, Fastest Lap, Winning driver, Winning team, Report\nQ: Which Rd took place at the Indianapolis 500?\nA: SELECT Rd FROM 1-10706961-2 WHERE Name = 'Indianapolis 500'"}
{"text": "table: 1-10706961-2\ncolumns: Rd, Name, Pole Position, Fastest Lap, Winning driver, Winning team, Report\nQ: Which teams won when Bobby Rahal was their winning driver?\nA: SELECT Winning team FROM 1-10706961-2 WHERE Winning driver = 'Bobby Rahal'"}
{"text": "table: 1-10706961-2\ncolumns: Rd, Name, Pole Position, Fastest Lap, Winning driver, Winning team, Report\nQ: What was the fastest lap time in the Escort Radar Warning 200?\nA: SELECT Fastest Lap FROM 1-10706961-2 WHERE Name = 'Escort Radar Warning 200'"}
{"text": "table: 1-10707176-2\ncolumns: Rnd, Race Name, Circuit, City/Location, Date, Pole position, Winning driver, Winning team, Report\nQ: What report was there for the porsche north america?\nA: SELECT Report FROM 1-10707176-2 WHERE Winning team = 'Porsche North America'"}
{"text": "table: 1-10707176-2\ncolumns: Rnd, Race Name, Circuit, City/Location, Date, Pole position, Winning driver, Winning team, Report\nQ: What rnds were there for the phoenix international raceway?\nA: SELECT Rnd FROM 1-10707176-2 WHERE Circuit = 'Phoenix International Raceway'"}
{"text": "table: 1-10707176-2\ncolumns: Rnd, Race Name, Circuit, City/Location, Date, Pole position, Winning driver, Winning team, Report\nQ: Who was the pole position for the rnd equalling 12?\nA: SELECT Pole position FROM 1-10707176-2 WHERE Rnd = '12'"}
{"text": "table: 1-10707176-2\ncolumns: Rnd, Race Name, Circuit, City/Location, Date, Pole position, Winning driver, Winning team, Report\nQ: How many reports were the for the cleveland burke lakefront airport circut?\nA: SELECT COUNT Report FROM 1-10707176-2 WHERE Circuit = 'Cleveland Burke Lakefront Airport'"}
{"text": "table: 1-10707176-2\ncolumns: Rnd, Race Name, Circuit, City/Location, Date, Pole position, Winning driver, Winning team, Report\nQ: How many winning drivers were the for the rnd equalling 5?\nA: SELECT COUNT Winning driver FROM 1-10707176-2 WHERE Rnd = '5'"}
{"text": "table: 1-10706879-3\ncolumns: Rd, Name, Pole Position, Fastest Lap, Winning driver, Winning team, Report\nQ: The race tony bettenhausen 200 has what smallest rd?\nA: SELECT MIN Rd FROM 1-10706879-3 WHERE Name = 'Tony Bettenhausen 200'"}
{"text": "table: 1-10706879-3\ncolumns: Rd, Name, Pole Position, Fastest Lap, Winning driver, Winning team, Report\nQ: The winning team of the race, los angeles times 500 is who?\nA: SELECT Winning team FROM 1-10706879-3 WHERE Name = 'Los Angeles Times 500'"}
{"text": "table: 1-10706879-3\ncolumns: Rd, Name, Pole Position, Fastest Lap, Winning driver, Winning team, Report\nQ: How many winning drivers in the kraco twin 125 (r2) race were there?\nA: SELECT COUNT Winning driver FROM 1-10706879-3 WHERE Name = 'Kraco Twin 125 (R2)'"}
{"text": "table: 1-10706879-3\ncolumns: Rd, Name, Pole Position, Fastest Lap, Winning driver, Winning team, Report\nQ: What are the races that johnny rutherford has won?\nA: SELECT Name FROM 1-10706879-3 WHERE Winning driver = 'Johnny Rutherford'"}
{"text": "table: 1-10706879-3\ncolumns: Rd, Name, Pole Position, Fastest Lap, Winning driver, Winning team, Report\nQ: How many fastest laps were there for a rd that equals 10?\nA: SELECT COUNT Fastest Lap FROM 1-10706879-3 WHERE Rd = 10"}
{"text": "table: 1-10712301-5\ncolumns: Region, Operator, Licence award date, On air date, Closure date\nQ: What is the license award date for North East England?\nA: SELECT Licence award date FROM 1-10712301-5 WHERE Region = 'North East England'"}
{"text": "table: 1-10733530-3\ncolumns: Nation, Population (thousands), Internet subscriptions (2000) (thousands of users), Internet subscriptions (2008) (thousands of users), % growth (2000\u20132008), % Internet users\nQ: What is the percentage of growth in 2000-2008 in ethiopia?\nA: SELECT % growth (2000\u20132008) FROM 1-10733530-3 WHERE Nation = 'Ethiopia'"}
{"text": "table: 1-10733530-3\ncolumns: Nation, Population (thousands), Internet subscriptions (2000) (thousands of users), Internet subscriptions (2008) (thousands of users), % growth (2000\u20132008), % Internet users\nQ: Name the total number of percentage growth 2000-2008 of uganda?\nA: SELECT COUNT % growth (2000\u20132008) FROM 1-10733530-3 WHERE Nation = 'Uganda'"}
{"text": "table: 1-10733530-3\ncolumns: Nation, Population (thousands), Internet subscriptions (2000) (thousands of users), Internet subscriptions (2008) (thousands of users), % growth (2000\u20132008), % Internet users\nQ: What is the maximum percentage grown 2000-2008 in burundi\nA: SELECT MAX % growth (2000\u20132008) FROM 1-10733530-3 WHERE Nation = 'Burundi'"}
{"text": "table: 1-10798421-1\ncolumns: Village (German), Village (Slovenian), Number of people 1991, Percent of Slovenes 1991, Percent of Slovenes 1951\nQ: Provide me with the names of all the villages (German) that has 76.3% of Slovenes in 1951.\nA: SELECT Village (German) FROM 1-10798421-1 WHERE Percent of Slovenes 1951 = '76.3%'"}
{"text": "table: 1-10798421-1\ncolumns: Village (German), Village (Slovenian), Number of people 1991, Percent of Slovenes 1991, Percent of Slovenes 1951\nQ: Give me the minimum number of people in 1991 with 92.5% of Slovenes in 1991.\nA: SELECT MIN Number of people 1991 FROM 1-10798421-1 WHERE Percent of Slovenes 1991 = '92.5%'"}
{"text": "table: 1-10798421-1\ncolumns: Village (German), Village (Slovenian), Number of people 1991, Percent of Slovenes 1991, Percent of Slovenes 1951\nQ: Provide me with the name of all the village (German) that are part of the village (Slovenian) with sele srednji kot. \nA: SELECT Village (German) FROM 1-10798421-1 WHERE Village (Slovenian) = 'Sele Srednji Kot'"}
{"text": "table: 1-10798421-1\ncolumns: Village (German), Village (Slovenian), Number of people 1991, Percent of Slovenes 1991, Percent of Slovenes 1951\nQ: Provide me with the name of all the village (German) that are part of the village (Slovenian) with sele borovnica.\nA: SELECT Village (German) FROM 1-10798421-1 WHERE Village (Slovenian) = 'Sele Borovnica'"}
|
||||
{"text": "table: 1-10798421-1\ncolumns: Village (German), Village (Slovenian), Number of people 1991, Percent of Slovenes 1991, Percent of Slovenes 1951\nQ: Provide me with the name of the village (German) where there is 96.9% Slovenes in 1951. \nA: SELECT Village (German) FROM 1-10798421-1 WHERE Percent of Slovenes 1951 = '96.9%'"}
|
||||
{"text": "table: 1-10798421-1\ncolumns: Village (German), Village (Slovenian), Number of people 1991, Percent of Slovenes 1991, Percent of Slovenes 1951\nQ: Provide with the names of the village (German) that is part of village (Slovenian) with sele srednji kot.\nA: SELECT Village (German) FROM 1-10798421-1 WHERE Village (Slovenian) = 'Sele Srednji Kot'"}
|
||||
{"text": "table: 1-10812293-3\ncolumns: Game, Date, Team, Score, High points, High rebounds, High assists, Location Attendance, Record\nQ: What was the score of the game on November 12?\nA: SELECT Score FROM 1-10812293-3 WHERE Date = 'November 12'"}
|
||||
{"text": "table: 1-10812293-3\ncolumns: Game, Date, Team, Score, High points, High rebounds, High assists, Location Attendance, Record\nQ: Who had high assists when they played against San Antonio?\nA: SELECT High assists FROM 1-10812293-3 WHERE Team = 'San Antonio'"}
|
||||
{"text": "table: 1-10812293-3\ncolumns: Game, Date, Team, Score, High points, High rebounds, High assists, Location Attendance, Record\nQ: Who scored the most points in game 4?\nA: SELECT High points FROM 1-10812293-3 WHERE Game = 4"}
|
||||
{"text": "table: 1-10812293-3\ncolumns: Game, Date, Team, Score, High points, High rebounds, High assists, Location Attendance, Record\nQ: Where was the game on November 20?\nA: SELECT Location Attendance FROM 1-10812293-3 WHERE Date = 'November 20'"}
|
||||
{"text": "table: 1-10935205-1\ncolumns: No. in season, No. in series, Title, Canadian airdate, US airdate, Production code\nQ: The canadian airdate of 11 february 2008 applied to what series number?\nA: SELECT COUNT No. in series FROM 1-10935205-1 WHERE Canadian airdate = '11 February 2008'"}
|
||||
{"text": "table: 1-10935205-1\ncolumns: No. in season, No. in series, Title, Canadian airdate, US airdate, Production code\nQ: The U.S. airdate of 4 april 2008 had a production code of what?\nA: SELECT MAX Production code FROM 1-10935205-1 WHERE US airdate = '4 April 2008'"}
|
||||
{"text": "table: 1-10935205-1\ncolumns: No. in season, No. in series, Title, Canadian airdate, US airdate, Production code\nQ: The episode titled \"don't stop believin'\" was what highest number of the season?\nA: SELECT MAX No. in season FROM 1-10935205-1 WHERE Title = '\"Don't Stop Believin'\"'"}
|
||||
{"text": "table: 1-10935205-1\ncolumns: No. in season, No. in series, Title, Canadian airdate, US airdate, Production code\nQ: The U.S. airdate of 8 august 2008 also had canadian airdates of what?\nA: SELECT Canadian airdate FROM 1-10935205-1 WHERE US airdate = '8 August 2008'"}
|
||||
{"text": "table: 1-10935205-1\ncolumns: No. in season, No. in series, Title, Canadian airdate, US airdate, Production code\nQ: The canadian airdate of 17 march 2008 had how many numbers in the season?\nA: SELECT COUNT No. in season FROM 1-10935205-1 WHERE Canadian airdate = '17 March 2008'"}
|
||||
{"text": "table: 1-10935205-1\ncolumns: No. in season, No. in series, Title, Canadian airdate, US airdate, Production code\nQ: For the episode(s) aired in the U.S. on 4 april 2008, what were the names?\nA: SELECT Title FROM 1-10935205-1 WHERE US airdate = '4 April 2008'"}
|
||||
{"text": "table: 1-10953197-5\ncolumns: No. in series, No. in season, Title, Director, Writer(s), Original air date, Production code\nQ: Who directed the episode \"Great Sexpectations (2)\"?\nA: SELECT Director FROM 1-10953197-5 WHERE Title = '\"Great Sexpectations (2)\"'"}
|
||||
{"text": "table: 1-10975034-2\ncolumns: Pick #, CFL Team, Player, Position, College\nQ: Which player from the 2004 CFL draft attended Wilfrid Laurier?\nA: SELECT Player FROM 1-10975034-2 WHERE College = 'Wilfrid Laurier'"}
|
||||
{"text": "table: 1-10975034-2\ncolumns: Pick #, CFL Team, Player, Position, College\nQ: What position does Christian Leibl-Cote play?\nA: SELECT Position FROM 1-10975034-2 WHERE Player = 'Christian Leibl-Cote'"}
|
||||
{"text": "table: 1-10975034-2\ncolumns: Pick #, CFL Team, Player, Position, College\nQ: What is the pick number for Northwestern college?\nA: SELECT MAX Pick # FROM 1-10975034-2 WHERE College = 'Northwestern'"}
|
||||
{"text": "table: 1-10992-3\ncolumns: No, City district (Stadtteil), Area in km\u00b2, Population, Foreign nationals, Foreign nationals in %, Area district (Ortsbezirk)\nQ: How many foreigners in percentage terms had a population of 4.911?\nA: SELECT COUNT Foreign nationals in % FROM 1-10992-3 WHERE Population = '4.911'"}
|
||||
{"text": "table: 1-10992-3\ncolumns: No, City district (Stadtteil), Area in km\u00b2, Population, Foreign nationals, Foreign nationals in %, Area district (Ortsbezirk)\nQ: What is the number of the city district of stadtteil where foreigners are 5.162?\nA: SELECT COUNT City district (Stadtteil) FROM 1-10992-3 WHERE Foreign nationals = '5.162'"}
|
||||
{"text": "table: 1-10992-3\ncolumns: No, City district (Stadtteil), Area in km\u00b2, Population, Foreign nationals, Foreign nationals in %, Area district (Ortsbezirk)\nQ: What is the city where the number is 47?\nA: SELECT City district (Stadtteil) FROM 1-10992-3 WHERE No = '47'"}
|
||||
{"text": "table: 1-11044765-1\ncolumns: School, Mascot, Location, League, Enrollment\nQ: Which leagues have Raiders as their mascot?\nA: SELECT League FROM 1-11044765-1 WHERE Mascot = 'Raiders'"}
|
||||
{"text": "table: 1-11044765-1\ncolumns: School, Mascot, Location, League, Enrollment\nQ: Which leagues is the Galena school in?\nA: SELECT League FROM 1-11044765-1 WHERE School = 'Galena'"}
|
||||
{"text": "table: 1-11044765-1\ncolumns: School, Mascot, Location, League, Enrollment\nQ: What city and state is the Lancers mascot located?\nA: SELECT Location FROM 1-11044765-1 WHERE Mascot = 'Lancers'"}
|
||||
{"text": "table: 1-11044765-1\ncolumns: School, Mascot, Location, League, Enrollment\nQ: What city and state are the miners located in?\nA: SELECT Location FROM 1-11044765-1 WHERE Mascot = 'Miners'"}
|
||||
{"text": "table: 1-11044765-1\ncolumns: School, Mascot, Location, League, Enrollment\nQ: Which school has the Raiders as their mascot?\nA: SELECT School FROM 1-11044765-1 WHERE Mascot = 'Raiders'"}
|
||||
{"text": "table: 1-1121352-2\ncolumns: No., Date, Tournament, Winning score, To par, Margin of victory, Runner(s)-up\nQ: Where was the tournament dated nov 3, 2002?\nA: SELECT Tournament FROM 1-1121352-2 WHERE Date = 'Nov 3, 2002'"}
|
||||
{"text": "table: 1-1121352-2\ncolumns: No., Date, Tournament, Winning score, To par, Margin of victory, Runner(s)-up\nQ: Where is the margin of victory dated mar 28, 2004?\nA: SELECT Margin of victory FROM 1-1121352-2 WHERE Date = 'Mar 28, 2004'"}
|
||||
{"text": "table: 1-1121352-2\ncolumns: No., Date, Tournament, Winning score, To par, Margin of victory, Runner(s)-up\nQ: What is the to par dated may 4, 2003?\nA: SELECT To par FROM 1-1121352-2 WHERE Date = 'May 4, 2003'"}
|
||||
{"text": "table: 1-1121352-2\ncolumns: No., Date, Tournament, Winning score, To par, Margin of victory, Runner(s)-up\nQ: What date were the runner ups pat hurst juli inkster?\nA: SELECT Date FROM 1-1121352-2 WHERE Runner(s)-up = 'Pat Hurst Juli Inkster'"}
|
||||
{"text": "table: 1-11210576-4\ncolumns: Character, Fate, Actor, First Episode, Final Episode, Duration, Final Episode Count\nQ: what's the total number of\u00a0final epbeingode count\u00a0with\u00a0character\u00a0being rick stetler\nA: SELECT COUNT Final Episode Count FROM 1-11210576-4 WHERE Character = 'Rick Stetler'"}
|
||||
{"text": "table: 1-11210576-4\ncolumns: Character, Fate, Actor, First Episode, Final Episode, Duration, Final Episode Count\nQ: what are all the actor where first episode is \"ambush\"\nA: SELECT Actor FROM 1-11210576-4 WHERE First Episode = '\"Ambush\"'"}
|
||||
{"text": "table: 1-11210576-4\ncolumns: Character, Fate, Actor, First Episode, Final Episode, Duration, Final Episode Count\nQ: what's the\u00a0character\u00a0with\u00a0fate\u00a0being deceased: knife wound\nA: SELECT Character FROM 1-11210576-4 WHERE Fate = 'Deceased: Knife Wound'"}
|
||||
{"text": "table: 1-11210576-4\ncolumns: Character, Fate, Actor, First Episode, Final Episode, Duration, Final Episode Count\nQ: what's the total number of\u00a0final epbeingode count\u00a0with\u00a0first epbeingode\u00a0being \"l.a.\"\nA: SELECT COUNT Final Episode Count FROM 1-11210576-4 WHERE First Episode = '\"L.A.\"'"}
|
||||
{"text": "table: 1-11210576-4\ncolumns: Character, Fate, Actor, First Episode, Final Episode, Duration, Final Episode Count\nQ: what's the\u00a0actor\u00a0with\u00a0character\u00a0being judge joseph ratner\nA: SELECT Actor FROM 1-11210576-4 WHERE Character = 'Judge Joseph Ratner'"}
|
||||
{"text": "table: 1-11210576-4\ncolumns: Character, Fate, Actor, First Episode, Final Episode, Duration, Final Episode Count\nQ: what's the\u00a0first epbeingode\u00a0with\u00a0final epbeingode\u00a0being \"rio\"\nA: SELECT First Episode FROM 1-11210576-4 WHERE Final Episode = '\"Rio\"'"}
|
||||
{"text": "table: 1-11214772-2\ncolumns: Year, Champion, Score, Runner-Up, Location, Semi-Finalist #1, Semi-Finalist #2\nQ: Which team was the second semi finalist in 2007?\nA: SELECT Semi-Finalist #2 FROM 1-11214772-2 WHERE Year = 2007"}
|
||||
{"text": "table: 1-11214772-2\ncolumns: Year, Champion, Score, Runner-Up, Location, Semi-Finalist #1, Semi-Finalist #2\nQ: How many teams were listed as runner up in 2005 and there the first semi finalist was Western Carolina?\nA: SELECT COUNT Runner-Up FROM 1-11214772-2 WHERE Semi-Finalist #1 = 'Western Carolina' AND Year = 2005"}
|
||||
119
Data/lora/wikisql.py
Normal file
@@ -0,0 +1,119 @@
# Copyright © 2023 Apple Inc.

"""
Code to preprocess the WikiSQL dataset adapted from
https://github.com/salesforce/WikiSQL and
https://huggingface.co/sqllama/sqllama-V0/blob/main/wikisql.ipynb .
"""


import json
import os


def load():
    """
    Load all three splits of the WikiSQL dataset.
    """
    return (WikiSQL(dn) for dn in ["train", "dev", "test"])


class WikiSQL:
    def __init__(self, dataset, save_dir="/tmp"):
        valid_sets = ("train", "dev", "test")
        if dataset not in valid_sets:
            raise ValueError(f"Dataset must be in {valid_sets}, got {dataset}")
        data_dir = os.path.join(save_dir, "wikisql")
        self._maybe_download(data_dir)

        self._parse_tables(os.path.join(data_dir, f"data/{dataset}.tables.jsonl"))
        self._parse_queries(os.path.join(data_dir, f"data/{dataset}.jsonl"))

    def _maybe_download(self, data_dir):
        if not os.path.exists(data_dir):
            import io
            import tarfile
            from urllib import request

            url = "https://raw.githubusercontent.com/salesforce/WikiSQL/master/data.tar.bz2"
            r = request.urlopen(url)
            with tarfile.open(fileobj=io.BytesIO(r.read())) as tf:
                tf.extractall(data_dir)

    def _parse_tables(self, tables):
        self._tables = {}
        with open(tables) as f:
            for line in f:
                table = json.loads(line)
                self._tables[table["id"]] = {
                    "columns": table["header"],
                    "types": table["types"],
                    "desc": f"table: {table['id']}\ncolumns: {', '.join(table['header'])}",
                }

    def _parse_queries(self, queries):
        self._queries = []
        with open(queries) as f:
            for line in f:
                query = json.loads(line)
                table = self._tables[query["table_id"]]
                question = query["question"]
                answer = self.query_to_text(
                    query["sql"], query["table_id"], table["columns"], table["types"]
                )
                self._queries.append(
                    f"<s>{table['desc']}\nQ: {question}\nA: {answer}</s>"
                )

    def query_to_text(self, query, table, columns, types):
        aggregation_ops = ["", "MAX", "MIN", "COUNT", "SUM", "AVG"]
        condition_ops = ["=", ">", "<", "OP"]
        column = columns[query["sel"]]
        aggregation = (aggregation_ops[query["agg"]] + " ") if query["agg"] > 0 else ""
        sql = f"SELECT {aggregation}{column} FROM {table}"

        conditions = query["conds"]
        if conditions:
            cs = []
            for i, o, v in conditions:
                column = columns[i]
                op = condition_ops[o]

                if types[i] == "text":
                    value = f"'{v}'"
                else:
                    value = v
                cs.append(f"{column} {op} {value}")

            sql += " WHERE " + " AND ".join(cs)

        return sql

    def __getitem__(self, idx):
        return self._queries[idx]

    def __len__(self):
        return len(self._queries)


if __name__ == "__main__":
    datanames = ["train", "dev", "test"]
    sizes = [56355, 8421, 15878]
    for dataname, size in zip(datanames, sizes):
        assert len(WikiSQL(dataname)) == size, f"Wrong {dataname} set size."

    # Write the sets to jsonl
    import json

    train, dev, test = load()
    datasets = [
        (train, "train", 1000),
        (dev, "valid", 100),
        (test, "test", 100),
    ]
    for dataset, name, size in datasets:
        with open(f"data/{name}.jsonl", "w") as fid:
            for e, t in zip(range(size), dataset):
                # Strip the <s>, </s> since the tokenizer adds them
                json.dump({"text": t[3:-4]}, fid)
                fid.write("\n")
@@ -236,3 +236,11 @@ public struct CohereConfiguration: Codable {
            Float.self, forKey: CohereConfiguration.CodingKeys.logitScale)
    }
}

// MARK: - LoRA

extension CohereModel: LoRAModel {
    public func loraLinearLayers() -> LoRALinearLayers {
        model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
    }
}

@@ -254,3 +254,11 @@ public struct GemmaConfiguration: Codable {
            Bool.self, forKey: CodingKeys.ropeTraditional) ?? false
    }
}

// MARK: - LoRA

extension GemmaModel: LoRAModel {
    public func loraLinearLayers() -> LoRALinearLayers {
        model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
    }
}

@@ -253,3 +253,11 @@ public struct LlamaConfiguration: Codable {

    }
}

// MARK: - LoRA

extension LlamaModel: LoRAModel {
    public func loraLinearLayers() -> LoRALinearLayers {
        model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
    }
}
@@ -17,47 +17,64 @@ public func load(
    hub: HubApi = HubApi(), configuration: ModelConfiguration,
    progressHandler: @escaping (Progress) -> Void = { _ in }
) async throws -> (LLMModel, Tokenizer) {
    // note: this doesn't have a way to pass the HubApi
    let tokenizer = try await loadTokenizer(configuration: configuration)
    do {
        let tokenizer = try await loadTokenizer(configuration: configuration, hub: hub)

        // download the model weights and config
        let repo = Hub.Repo(id: configuration.id)
        let modelFiles = ["config.json", "*.safetensors"]
        let modelDirectory = try await hub.snapshot(
            from: repo, matching: modelFiles, progressHandler: progressHandler)
        let modelDirectory: URL

        // create the model (no weights loaded)
        let configurationURL = modelDirectory.appending(component: "config.json")
        let baseConfig = try JSONDecoder().decode(
            BaseConfiguration.self, from: Data(contentsOf: configurationURL))
        switch configuration.id {
        case .id(let id):
            // download the model weights and config
            let repo = Hub.Repo(id: id)
            let modelFiles = ["config.json", "*.safetensors"]
            modelDirectory = try await hub.snapshot(
                from: repo, matching: modelFiles, progressHandler: progressHandler)

        let model = try baseConfig.modelType.createModel(configuration: configurationURL)
        case .directory(let directory):
            modelDirectory = directory
        }

        // load the weights
        var weights = [String: MLXArray]()
        let enumerator = FileManager.default.enumerator(
            at: modelDirectory, includingPropertiesForKeys: nil)!
        for case let url as URL in enumerator {
            if url.pathExtension == "safetensors" {
                let w = try loadArrays(url: url)
                for (key, value) in w {
                    weights[key] = value
        // create the model (no weights loaded)
        let configurationURL = modelDirectory.appending(component: "config.json")
        let baseConfig = try JSONDecoder().decode(
            BaseConfiguration.self, from: Data(contentsOf: configurationURL))

        let model = try baseConfig.modelType.createModel(configuration: configurationURL)

        // load the weights
        var weights = [String: MLXArray]()
        let enumerator = FileManager.default.enumerator(
            at: modelDirectory, includingPropertiesForKeys: nil)!
        for case let url as URL in enumerator {
            if url.pathExtension == "safetensors" {
                let w = try loadArrays(url: url)
                for (key, value) in w {
                    weights[key] = value
                }
            }
        }

        // quantize if needed
        if let quantization = baseConfig.quantization {
            quantizeIfNeeded(model: model, weights: weights, quantization: quantization)
        }

        // apply the loaded weights
        let parameters = ModuleParameters.unflattened(weights)
        try model.update(parameters: parameters, verify: [.all])

        eval(model)

        return (model, tokenizer)

    } catch Hub.HubClientError.authorizationRequired {
        // an authorizationRequired means (typically) that the named repo doesn't exist
        // on the server so retry with local only configuration
        var newConfiguration = configuration
        newConfiguration.id = .directory(configuration.modelDirectory(hub: hub))
        return try await load(
            hub: hub, configuration: newConfiguration, progressHandler: progressHandler)
    }

    // quantize if needed
    if let quantization = baseConfig.quantization {
        quantizeIfNeeded(model: model, weights: weights, quantization: quantization)
    }

    // apply the loaded weights
    let parameters = ModuleParameters.unflattened(weights)
    try model.update(parameters: parameters, verify: [.all])

    eval(model)

    return (model, tokenizer)
}

// MARK: - Quantization
61
Libraries/LLM/Lora+Data.swift
Normal file
@@ -0,0 +1,61 @@
// Copyright © 2024 Apple Inc.

import Foundation

enum LoRADataError: Error {
    case fileNotFound(URL, String)
}

/// Load a LoRA data file.
///
/// Given a directory and a base name, e.g. `train`, this will load a `.jsonl` or `.txt` file
/// if possible.
public func loadLoRAData(directory: URL, name: String) throws -> [String] {
    let extensions = ["jsonl", "txt"]

    for ext in extensions {
        let url = directory.appending(component: "\(name).\(ext)")
        if FileManager.default.fileExists(atPath: url.path()) {
            return try loadLoRAData(url: url)
        }
    }

    throw LoRADataError.fileNotFound(directory, name)
}

/// Load a .txt or .jsonl file and return the contents
public func loadLoRAData(url: URL) throws -> [String] {
    switch url.pathExtension {
    case "jsonl":
        return try loadJSONL(url: url)

    case "txt":
        return try loadLines(url: url)

    default:
        fatalError("Unable to load data file, unknown type: \(url)")

    }
}

func loadJSONL(url: URL) throws -> [String] {

    struct Line: Codable {
        let text: String?
    }

    return try String(contentsOf: url)
        .components(separatedBy: .newlines)
        .filter {
            $0.first == "{"
        }
        .compactMap {
            try JSONDecoder().decode(Line.self, from: $0.data(using: .utf8)!).text
        }
}

func loadLines(url: URL) throws -> [String] {
    try String(contentsOf: url)
        .components(separatedBy: .newlines)
        .filter { !$0.isEmpty }
}
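For orientation, a minimal usage sketch of the loaders above (not part of the commit; the directory path is illustrative):

// Load train/valid splits from a directory containing train.jsonl and valid.jsonl.
let dataDirectory = URL(filePath: "/tmp/lora-data")  // assumed location
let train = try loadLoRAData(directory: dataDirectory, name: "train")
let valid = try loadLoRAData(directory: dataDirectory, name: "valid")
print("train: \(train.count) examples, valid: \(valid.count) examples")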
639
Libraries/LLM/Lora.swift
Normal file
@@ -0,0 +1,639 @@
// Copyright © 2024 Apple Inc.

import Foundation
import MLX
import MLXNN
import MLXOptimizers
import MLXRandom
import Tokenizers

/// Layers to apply LoRA adapters to.
///
/// This is the value returned by ``LoRAModel/loraLinearLayers()``.
public typealias LoRALinearLayers = [(Module, [String])]

public protocol LoRAModel {
    /// Return the layers and keys to apply LoRA adapters to.
    ///
    /// For example this might apply the adapters to the `q` and `v` projections in the
    /// Attention layers:
    ///
    /// ```swift
    /// model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
    /// ```
    ///
    /// It is not required that a model implement this protocol to have LoRA adapters applied, but
    /// the command line driver example uses this to produce the ``LoRALinearLayers``.
    ///
    /// ### See Also
    /// - ``LoRATrain/convert(model:layers:)``
    func loraLinearLayers() -> LoRALinearLayers
}

/// Protocol for LoRA implementations that provides a method for converting back to a `Linear`
/// (or subtype).
///
/// This is normally called via ``LoRATrain/fuse(model:layers:deQuantize:)``
public protocol LoRAConvertToLinear {
    func toLinear(deQuantize: Bool) -> Linear
}

/// Implementation of LoRA `Linear` replacement layer.
///
/// This layer implements the LoRA capabilities for `Linear` layers, specifically:
///
/// - converting `Linear` or `QuantizedLinear` layers to ``LoRALinear`` / ``QLoRALinear``
/// - converting ``LoRALinear`` back to `Linear` or `QuantizedLinear` (``LoRAConvertToLinear``)
/// - implementing the LoRA evaluation
///
/// ``QLoRALinear`` is the equivalent class for `QuantizedLinear`.
///
/// This is not typically used directly -- ``LoRATrain/convert(model:layers:)`` is used to
/// add the adapter layers to a given model.
///
/// ### See Also
/// - [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685)
/// - [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)
/// - ``QLoRALinear``
/// - ``LoRATrain/convert(model:layers:)``
/// - ``LoRATrain/fuse(model:layers:deQuantize:)``
public class LoRALinear: Linear, LoRAConvertToLinear {

    let scale: Float

    @ParameterInfo(key: "lora_a") var loraA: MLXArray
    @ParameterInfo(key: "lora_b") var loraB: MLXArray

    required public init(
        _ inputDimensions: Int, _ outputDimensions: Int, rank: Int = 8, bias: Bool = false,
        scale: Float = 20.0, linear: Linear
    ) {
        // Scale for low-rank update
        self.scale = scale

        // Low rank lora weights
        let loraScale = 1 / sqrt(Float(inputDimensions))
        self._loraA.wrappedValue = MLXRandom.uniform(
            low: -loraScale, high: loraScale, [inputDimensions, rank])
        self._loraB.wrappedValue = MLXArray.zeros([rank, outputDimensions])

        super.init(weight: linear.weight, bias: linear.bias)

        freeze()
    }

    /// Freeze all parameters except the lora parameters
    public override func freeze(recursive: Bool = true, keys: [String]? = nil, strict: Bool = false)
        throws
    {
        // realize the keys and omit the lora parameters
        let keys =
            (keys ?? self.filterMap(filter: Self.filterLocalParameters).flattened().map { $0.0 })
            .filter {
                $0 != "lora_a" && $0 != "lora_b"
            }
        try super.freeze(recursive: recursive, keys: keys, strict: strict)
    }

    /// Convert a `Linear` or `QuantizedLinear` layer into a new `Linear` layer
    /// that implements the `LoRA` adapter.
    ///
    /// This is typically called via ``LoRATrain/convert(model:layers:)``.
    ///
    /// ### See Also
    /// - ``LoRATrain/convert(model:layers:)``
    /// - ``QLoRALinear/from(linear:rank:)``
    public static func from(linear: Linear, rank: Int = 8) -> Linear {
        if let linear = linear as? QuantizedLinear {
            return QLoRALinear.from(linear: linear, rank: rank)
        }
        let (outputDimensions, inputDimensions) = linear.shape
        return LoRALinear(inputDimensions, outputDimensions, rank: rank, linear: linear)
    }

    /// Convert back into a fused `Linear` layer.
    ///
    /// This is typically called via ``LoRATrain/fuse(model:layers:deQuantize:)``.
    ///
    /// ### See Also
    /// - ``LoRATrain/fuse(model:layers:deQuantize:)``
    /// - ``LoRAConvertToLinear``
    /// - ``QLoRALinear/toLinear(deQuantize:)``
    public func toLinear(deQuantize: Bool = false) -> Linear {
        let dtype = weight.dtype
        let loraB = (scale * loraB.T).asType(dtype)
        let loraA = loraA.T.asType(dtype)
        return Linear(weight: weight + matmul(loraB, loraA), bias: bias)
    }

    public override func callAsFunction(_ x: MLXArray) -> MLXArray {
        let y = super.callAsFunction(x.asType(weight.dtype))
        let z = matmul(matmul(x, self.loraA), self.loraB)
        return y + scale * z
    }
}
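// Note on the forward pass above (added for clarity): with loraA of shape
// [inputDimensions, rank] initialized uniformly in ±1/sqrt(inputDimensions)
// and loraB of shape [rank, outputDimensions] initialized to zero, the layer
// computes y = Linear(x) + scale * ((x · loraA) · loraB). Because loraB starts
// at zero, training begins exactly at the base model's output, and only the
// low-rank factors are updated while the base weights stay frozen.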

/// Implementation of LoRA `QuantizedLinear` replacement layer.
///
/// See ``LoRALinear`` (equivalent class for `Linear` layers) for more information.
public class QLoRALinear: QuantizedLinear, LoRAConvertToLinear {

    let scale: Float

    @ParameterInfo(key: "lora_a") var loraA: MLXArray
    @ParameterInfo(key: "lora_b") var loraB: MLXArray

    required public init(
        _ inputDimensions: Int, _ outputDimensions: Int, rank: Int = 8, bias: Bool = false,
        scale: Float = 20.0, linear: QuantizedLinear
    ) {

        // Scale for low-rank update
        self.scale = scale

        // Low rank lora weights
        let loraScale = 1 / sqrt(Float(inputDimensions))
        self._loraA.wrappedValue = MLXRandom.uniform(
            low: -loraScale, high: loraScale, [inputDimensions, rank])
        self._loraB.wrappedValue = MLXArray.zeros([rank, outputDimensions])

        super.init(
            weight: linear.weight, bias: linear.bias, scales: linear.scales, biases: linear.biases,
            groupSize: linear.groupSize, bits: linear.bits)

        // start frozen except for the lora keys
        freeze()
    }

    /// Freeze all parameters except the lora parameters
    public override func freeze(recursive: Bool = true, keys: [String]? = nil, strict: Bool = false)
        throws
    {
        // realize the keys and omit the lora parameters
        let keys =
            (keys ?? self.filterMap(filter: Self.filterLocalParameters).flattened().map { $0.0 })
            .filter {
                $0 != "lora_a" && $0 != "lora_b"
            }
        try super.freeze(recursive: recursive, keys: keys, strict: strict)
    }

    /// Convert a `QuantizedLinear` layer into a new `Linear` layer
    /// that implements the `LoRA` adapter.
    ///
    /// This is typically called via ``LoRATrain/convert(model:layers:)``.
    ///
    /// ### See Also
    /// - ``LoRATrain/convert(model:layers:)``
    /// - ``LoRALinear/from(linear:rank:)``
    public static func from(linear: QuantizedLinear, rank: Int = 8) -> Linear {
        var (outputDimensions, inputDimensions) = linear.shape
        // quantized weights pack 32 / bits values per stored element, so
        // recover the original input dimension before sizing the adapters
        inputDimensions = inputDimensions * 32 / linear.bits
        return QLoRALinear(inputDimensions, outputDimensions, rank: rank, linear: linear)
    }

    /// Convert back into a fused `QuantizedLinear` layer.
    ///
    /// This is typically called via ``LoRATrain/fuse(model:layers:deQuantize:)``.
    ///
    /// ### See Also
    /// - ``LoRATrain/fuse(model:layers:deQuantize:)``
    public func toLinear(deQuantize: Bool = false) -> Linear {
        // convert back into full weights
        let weight = dequantized(
            weight, scales: scales, biases: biases, groupSize: groupSize, bits: bits)

        let loraB = (scale * loraB.T).asType(.float16)
        let loraA = loraA.T.asType(.float16)

        // convert back into quantized
        return QuantizedLinear(
            weight: weight + matmul(loraB, loraA), bias: bias, groupSize: groupSize, bits: bits)
    }

    public override func callAsFunction(_ x: MLXArray) -> MLXArray {
        let y = super.callAsFunction(x.asType(scales.dtype))
        let z = matmul(matmul(x, self.loraA), self.loraB)
        return y + scale * z
    }
}
/// Equivalent to `lora.py/iterate_batches()`. Used internally by ``LoRATrain``.
struct LoRABatchIterator: Sequence, IteratorProtocol {

    let dataset: [String]
    let batchSize: Int
    let tokenizer: Tokenizer

    let train: Bool

    var indices: [Int]
    var index = 0

    public init(dataset: [String], tokenizer: Tokenizer, batchSize: Int, train: Bool) {
        self.dataset = dataset
        self.batchSize = batchSize
        self.tokenizer = tokenizer
        self.train = train

        self.indices = Array(0 ..< dataset.count)
        if train {
            indices.shuffle()
        }
    }

    mutating public func next() -> (MLXArray, MLXArray, MLXArray)? {
        if index >= indices.count {
            if !train {
                return nil
            }

            indices.shuffle()
            index = 0
        }

        let endIndex = Swift.min(index + batchSize, indices.count)

        let batch = (index ..< endIndex)
            .map { tokenizer.encode(text: dataset[indices[$0]]) }
        let lengths = batch.map { $0.count }
        let maxLength = lengths.max() ?? 0

        if maxLength > 2048 {
            print(
                """
                [WARNING] Some sequences are longer than 2048 tokens.
                Consider pre-splitting your data to save memory.
                """)
        }

        // pad to the max length
        let batchArray = MLXArray.zeros([lengths.count, maxLength], type: Int32.self)
        for (j, (b, l)) in zip(batch, lengths).enumerated() {
            batchArray[j, 0 ..< l] = MLXArray(b)
        }

        index = endIndex

        return (batchArray[0..., .stride(to: -1)], batchArray[0..., 1...], MLXArray(lengths))
    }

}
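// Note on next() above (added for clarity): inputs drop the final token
// (the .stride(to: -1) slice) and targets drop the first (the 1... slice) --
// the standard next-token prediction shift -- while lengths lets the loss
// function mask out the zero padding added here.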
/// Collection of functions for adding LoRA adapters to an LLM model, training, fusing and saving/loading weights.
///
/// The typical flow for training is:
///
/// ```swift
/// // load the base model and tokenizer
/// let (model, tokenizer) = try await LLM.load(configuration: ModelConfiguration.mistral7B4bit)
///
/// // add LoRALinear adapter layers
/// LoRATrain.convert(model: model, layers: Array(model.loraLinearLayers().suffix(4)))
///
/// // optionally load LoRA weights
/// try LoRATrain.loadLoRAWeights(model: model, url: ...)
///
/// // load the train/validation data
/// let train = try loadLoRAData(directory: data, name: "train")
/// let valid = try loadLoRAData(directory: data, name: "valid")
///
/// // train
/// let optimizer = Adam(learningRate: 1e-5)
/// try await LoRATrain.train(
///     model: model, train: train, validate: valid, optimizer: optimizer, tokenizer: tokenizer,
///     parameters: LoRATrain.Parameters()
/// ) { progress in
///     print(progress)
///     return .more
/// }
/// ```
///
/// At this point the model will be trained and you could do one of the following:
///
/// - ``saveLoRAWeights(model:url:)`` -- write the LoRA weights to a file
/// - ``fuse(model:layers:deQuantize:)`` -- fuse the LoRA weights and convert back into the original model
///   architecture. These weights can be saved and reloaded with normal model handling code.
/// - ``evaluate(model:dataset:loss:tokenizer:batchSize:batchCount:)`` -- compute the test loss
///   against a test dataset
/// - use the in-memory model as a normal `LLMModel` and evaluate a prompt
///
public enum LoRATrain {

    public typealias LoraLossFunction = (Module, MLXArray, MLXArray, MLXArray) -> (
        MLXArray, MLXArray
    )

    /// LoRA training parameters
    public struct Parameters {
        /// number of prompts to evaluate per iteration
        public var batchSize = 4

        /// number of iterations to train for
        public var iterations = 1000

        /// number of training steps between loss reporting
        public var stepsPerReport = 10

        /// number of steps between validations
        public var stepsPerEval = 100

        /// number of validation batches, `0` uses the entire validation set
        public var validationBatches = 10

        /// save the model every N iterations
        public var saveEvery = 100

        /// save path for the adapter `.safetensors`
        public var adapterURL: URL?

        public init(
            batchSize: Int = 4, iterations: Int = 1000, stepsPerReport: Int = 10,
            stepsPerEval: Int = 100, validationBatches: Int = 10, saveEvery: Int = 100,
            adapterURL: URL? = nil
        ) {
            self.batchSize = batchSize
            self.iterations = iterations
            self.stepsPerReport = stepsPerReport
            self.stepsPerEval = stepsPerEval
            self.validationBatches = validationBatches
            self.saveEvery = saveEvery
            self.adapterURL = adapterURL
        }
    }

    /// Freeze the model layers and replace the indicated modules (Linear) that should be
    /// converted to ``LoRALinear`` and remain trainable.
    ///
    /// Once a model has had the LoRA adapters applied, adapter weights can be loaded
    /// (if available):
    ///
    /// ```swift
    /// try LoRATrain.loadLoRAWeights(model: model, url: args.adapter)
    /// ```
    ///
    /// At this point the model is ready for one or more of the following:
    ///
    /// - training with ``train(model:train:validate:optimizer:loss:tokenizer:parameters:progress:)``
    /// - loss evaluation with ``evaluate(model:dataset:loss:tokenizer:batchSize:batchCount:)``
    /// - fusing with ``fuse(model:layers:deQuantize:)``
    /// - text generation with ``generate(promptTokens:parameters:model:tokenizer:didGenerate:)``
    ///     - note that this is just using normal model text generation
    ///
    /// - Parameters:
    ///   - model: model to convert
    ///   - layers: the layers (and keys) to convert
    public static func convert(model: Module, layers: LoRALinearLayers) {
        model.freeze()

        for (layer, keys) in layers {
            var update = ModuleChildren()
            let children = layer.children()
            for key in keys {
                if let item = children[key], case .value(let child) = item {
                    if let linear = child as? Linear {
                        update[key] = .value(LoRALinear.from(linear: linear))
                    } else {
                        print("\(key) on \(layer) is not Linear")
                    }
                } else {
                    print("failed to find key \(key) on \(layer)")
                }
            }
            layer.update(modules: update)
        }
    }
    /// Fuses the LoRA adapters back into the model weights.
    ///
    /// This produces a model in the original format with `Linear` or `QuantizedLinear` layer
    /// weights that incorporate the LoRA adapter.
    ///
    /// - Parameters:
    ///   - model: model to convert
    ///   - layers: the layers (and keys) to fuse
    ///   - deQuantize: if `true` will convert `QuantizedLinear` back into `Linear`
    public static func fuse(model: Module, layers: LoRALinearLayers, deQuantize: Bool = false) {
        for (layer, keys) in layers {
            var update = ModuleChildren()
            let children = layer.children()
            for key in keys {
                if let item = children[key], case .value(let child) = item {
                    if let lora = child as? LoRAConvertToLinear {
                        update[key] = .value(lora.toLinear(deQuantize: deQuantize))
                    }
                }
            }
            if !update.isEmpty {
                layer.update(modules: update)
            }
        }
    }

    public static func loss(model: Module, inputs: MLXArray, targets: MLXArray, lengths: MLXArray)
        -> (MLXArray, MLXArray)
    {
        // def loss(model, inputs, targets, lengths):

        // run model on inputs
        let model = model as! LLMModel
        let logits = model(inputs, cache: nil).0.asType(.float32)

        // mask padding tokens
        let lengthMask = MLXArray(0 ..< inputs.dim(1))[.newAxis, 0...] .< lengths[0..., .newAxis]

        // calculate the loss
        let ntoks = lengthMask.sum()
        let ce = (crossEntropy(logits: logits, targets: targets) * lengthMask).sum() / ntoks
        return (ce, ntoks)
    }
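    // Note on loss() above (added for clarity): lengthMask zeroes the cross
    // entropy at padding positions, so ce is the mean loss over real tokens
    // only, and ntoks is the count of tokens that actually contributed.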
    /// Evaluate the model on a dataset and return the loss over the entire dataset.
    ///
    /// - Parameters:
    ///   - model: the model to evaluate
    ///   - dataset: the dataset
    ///   - loss: loss function
    ///   - tokenizer: tokenizer
    ///   - batchSize: number of items from the dataset to evaluate at once
    ///   - batchCount: number of batch elements to evaluate, 0 for all
    /// - Returns: the loss over the evaluated data
    ///
    /// ### See Also
    /// - ``loadLoRAData(directory:name:)``
    public static func evaluate(
        model: Module, dataset: [String], loss: LoraLossFunction = loss, tokenizer: Tokenizer,
        batchSize: Int, batchCount: Int
    ) -> Float {
        var allLosses = [Float]()
        var tokenCount = 0

        for (iteration, (inputs, targets, lengths)) in LoRABatchIterator(
            dataset: dataset, tokenizer: tokenizer, batchSize: batchSize, train: false
        ).enumerated() {
            let (losses, tokens) = loss(model, inputs, targets, lengths)
            allLosses.append((losses * tokens).item(Float.self))
            tokenCount += tokens.item(Int.self)

            if batchCount != 0 && iteration + 1 >= batchCount {
                break
            }
        }

        return (sum(MLXArray(allLosses), stream: .cpu) / tokenCount).item(Float.self)
    }
    /// Given a model with LoRA adapters applied, load adapter weights from a `.safetensors` file.
    ///
    /// ### See Also
    /// - ``convert(model:layers:)``
    /// - ``saveLoRAWeights(model:url:)``
    public static func loadLoRAWeights(model: Module, url: URL) throws {
        let weights = try ModuleParameters.unflattened(loadArrays(url: url))
        try model.update(parameters: weights, verify: .noUnusedKeys)
        eval(model)
    }

    /// Given a model with LoRA adapters applied, write adapter weights to a `.safetensors` file.
    ///
    /// ### See Also
    /// - ``convert(model:layers:)``
    /// - ``loadLoRAWeights(model:url:)``
    public static func saveLoRAWeights(model: Module, url: URL) throws {
        let parameters = Dictionary(
            uniqueKeysWithValues: model.trainableParameters().flattened())
        try save(arrays: parameters, url: url)
    }
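    // Illustrative round trip using the two functions above (the path is
    // hypothetical):
    //   try LoRATrain.saveLoRAWeights(model: model, url: URL(filePath: "/tmp/adapters.safetensors"))
    //   try LoRATrain.loadLoRAWeights(model: model, url: URL(filePath: "/tmp/adapters.safetensors"))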
    public enum Progress: CustomStringConvertible {
        case train(
            iteration: Int, trainingLoss: Float, iterationsPerSecond: Double,
            tokensPerSecond: Double)
        case validation(iteration: Int, validationLoss: Float, validationTime: Double)
        case save(iteration: Int, url: URL)

        public var description: String {
            switch self {
            case .train(
                let iteration, let trainingLoss, let iterationsPerSecond, let tokensPerSecond):
                "Iteration \(iteration + 1): training loss \(trainingLoss.formatted()), "
                    + "iterations/sec \(iterationsPerSecond.formatted()), "
                    + "tokens/sec \(tokensPerSecond.formatted())"
            case .validation(let iteration, let validationLoss, let validationTime):
                "Iteration \(iteration + 1): "
                    + "validation loss \(validationLoss.formatted()), "
                    + "validation time \(validationTime.formatted())s"
            case .save(let iteration, let url):
                "Iteration \(iteration + 1): saved weights to \(url.path())"
            }
        }
    }

    public enum ProgressDisposition {
        case stop
        case more
    }

    /// Train (or continue training) LoRA weights.
    ///
    /// - Parameters:
    ///   - model: model to train
    ///   - train: training dataset
    ///   - validate: validation dataset
    ///   - optimizer: optimizer used in training
    ///   - loss: loss function
    ///   - tokenizer: tokenizer
    ///   - parameters: training parameters
    ///   - progress: progress callback
    public static func train(
        model: Module, train: [String], validate: [String], optimizer: Optimizer,
        loss: @escaping LoraLossFunction = loss, tokenizer: Tokenizer, parameters: Parameters,
        progress: (Progress) async -> ProgressDisposition
    ) async throws {
        // def train(model, train_set, val_set, optimizer, loss, tokenizer, args)

        let lossValueGrad = valueAndGrad(model: model) { model, arrays in
            let (ce, ntoks) = loss(model, arrays[0], arrays[1], arrays[2])
            return [ce, ntoks]
        }

        var losses = [Float]()
        var tokenCount = 0

        var start = Date.timeIntervalSinceReferenceDate

        for (iteration, (inputs, targets, lengths)) in LoRABatchIterator(
            dataset: train, tokenizer: tokenizer, batchSize: parameters.batchSize, train: true
        ).enumerated() {
            // forward and backward pass
            let (resultArray, grad) = lossValueGrad(model, [inputs, targets, lengths])
            let lvalue = resultArray[0]
            let tokens = resultArray[1]

            // model update
            optimizer.update(model: model, gradients: grad)
            eval(model, optimizer, lvalue)

            // record loss
            losses.append(lvalue.item(Float.self))
            tokenCount += tokens.item(Int.self)

            // report training loss
            if (iteration + 1) % parameters.stepsPerReport == 0 {
                let trainingLoss = MLXArray(losses).mean(stream: .cpu).item(Float.self)
                let now = Date.timeIntervalSinceReferenceDate

                let iterationsPerSecond = Double(parameters.stepsPerReport) / (now - start)
                let tokensPerSecond = Double(tokenCount) / (now - start)

                if await progress(
                    .train(
                        iteration: iteration, trainingLoss: trainingLoss,
                        iterationsPerSecond: iterationsPerSecond, tokensPerSecond: tokensPerSecond))
                    == .stop
                {
                    break
                }

                losses.removeAll()
                tokenCount = 0
                start = Date.timeIntervalSinceReferenceDate
            }

            // report validation loss
            if iteration == 0 || (iteration + 1) % parameters.stepsPerEval == 0 {
                let validationStart = Date.timeIntervalSinceReferenceDate
                let validationLoss = evaluate(
                    model: model, dataset: validate, loss: loss, tokenizer: tokenizer,
                    batchSize: parameters.batchSize, batchCount: parameters.validationBatches)
                let now = Date.timeIntervalSinceReferenceDate

                if await progress(
                    .validation(
                        iteration: iteration, validationLoss: validationLoss,
                        validationTime: now - validationStart)) == .stop
                {
                    break
                }

                start = Date.timeIntervalSinceReferenceDate
            }

            // save adapter weights if needed
            if let adapterURL = parameters.adapterURL, (iteration + 1) % parameters.saveEvery == 0 {
                try saveLoRAWeights(model: model, url: adapterURL)

                if await progress(.save(iteration: iteration, url: adapterURL)) == .stop {
                    break
                }

                start = Date.timeIntervalSinceReferenceDate
            }

            if iteration + 1 >= parameters.iterations {
                break
            }
        }
    }
}
@@ -1,6 +1,7 @@
// Copyright © 2024 Apple Inc.

import Foundation
import Hub

/// Registry of models and any overrides that go with them, e.g. prompt augmentation.
/// If asked for an unknown configuration this will use the model/tokenizer as-is.
@@ -9,7 +10,22 @@ import Foundation
/// swift-tokenizers code handles a good chunk of that and this is a place to augment that
/// implementation, if needed.
public struct ModelConfiguration {
    public let id: String

    public enum Identifier {
        case id(String)
        case directory(URL)
    }

    public var id: Identifier

    public var name: String {
        switch id {
        case .id(let string):
            string
        case .directory(let url):
            url.deletingLastPathComponent().lastPathComponent + "/" + url.lastPathComponent
        }
    }

    /// pull the tokenizer from an alternate id
    public let tokenizerId: String?
@@ -26,7 +42,17 @@ public struct ModelConfiguration {
        id: String, tokenizerId: String? = nil, overrideTokenizer: String? = nil,
        preparePrompt: ((String) -> String)? = nil
    ) {
        self.id = id
        self.id = .id(id)
        self.tokenizerId = tokenizerId
        self.overrideTokenizer = overrideTokenizer
        self.preparePrompt = preparePrompt
    }

    public init(
        directory: URL, tokenizerId: String? = nil, overrideTokenizer: String? = nil,
        preparePrompt: ((String) -> String)? = nil
    ) {
        self.id = .directory(directory)
        self.tokenizerId = tokenizerId
        self.overrideTokenizer = overrideTokenizer
        self.preparePrompt = preparePrompt
@@ -36,13 +62,25 @@ public struct ModelConfiguration {
        preparePrompt?(prompt) ?? prompt
    }

    public func modelDirectory(hub: HubApi = HubApi()) -> URL {
        switch id {
        case .id(let id):
            // download the model weights and config
            let repo = Hub.Repo(id: id)
            return hub.localRepoLocation(repo)

        case .directory(let directory):
            return directory
        }
    }

    public static var registry = [String: ModelConfiguration]()

    public static func register(configurations: [ModelConfiguration]) {
        bootstrap()

        for c in configurations {
            registry[c.id] = c
            registry[c.name] = c
        }
    }
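For reference, a short sketch of the two identifier forms introduced here (not part of the commit; the local path is illustrative, the model id appears elsewhere in this change):

// id form: resolved against the Hugging Face hub
let remote = ModelConfiguration(id: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx")
// directory form: loaded directly from local files
let local = ModelConfiguration(directory: URL(filePath: "/tmp/models/mistral"))
print(remote.name)  // "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx"
print(local.name)   // "models/mistral" (parent directory + last path component)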
@@ -240,3 +240,11 @@ public struct PhiConfiguration: Codable {

    }
}

// MARK: - LoRA

extension PhiModel: LoRAModel {
    public func loraLinearLayers() -> LoRALinearLayers {
        model.layers.map { ($0.selfAttention, ["q_proj", "v_proj"]) }
    }
}

@@ -251,3 +251,11 @@ public struct Qwen2Configuration: Codable {
            [String: StringOrNumber].self, forKey: Qwen2Configuration.CodingKeys.ropeScaling)
    }
}

// MARK: - LoRA

extension Qwen2Model: LoRAModel {
    public func loraLinearLayers() -> LoRALinearLayers {
        model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
    }
}
@@ -1,4 +1,4 @@
# Llama
# LLM

This is a port of several models from:

@@ -6,7 +6,7 @@ This is a port of several models from:

using the Hugging Face swift transformers package to provide tokenization:

https://github.com/huggingface/swift-transformers
- https://github.com/huggingface/swift-transformers

The [Models.swift](Models.swift) provides minor overrides and customization --
if you require overrides for the tokenizer or prompt customizations they can be
@@ -30,3 +30,12 @@ Currently supported model types are:
See [Configuration.swift](Configuration.swift) for more info.

See [llm-tool](../../Tools/llm-tool)

# LoRA

[Lora.swift](Lora.swift) contains an implementation of LoRA based on this example:

- https://github.com/ml-explore/mlx-examples/tree/main/lora

See [llm-tool/LoraCommands.swift](../../Tools/llm-tool/LoraCommands.swift) for an example of a driver and
[llm-tool](../../Tools/llm-tool) for examples of how to run it.
@@ -254,3 +254,11 @@ public struct Starcoder2Configuration: Codable {
            ?? true
    }
}

// MARK: - LoRA

extension Starcoder2Model: LoRAModel {
    public func loraLinearLayers() -> LoRALinearLayers {
        model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
    }
}
@@ -4,10 +4,20 @@ import Foundation
import Hub
import Tokenizers

public func loadTokenizer(configuration: ModelConfiguration) async throws -> Tokenizer {
public func loadTokenizer(configuration: ModelConfiguration, hub: HubApi) async throws -> Tokenizer
{
    // from AutoTokenizer.from() -- this lets us override parts of the configuration
    let config = LanguageModelConfigurationFromHub(
        modelName: configuration.tokenizerId ?? configuration.id)

    let config: LanguageModelConfigurationFromHub

    switch configuration.id {
    case .id(let id):
        config = LanguageModelConfigurationFromHub(
            modelName: configuration.tokenizerId ?? id, hubApi: hub)
    case .directory(let directory):
        config = LanguageModelConfigurationFromHub(modelFolder: directory, hubApi: hub)
    }

    guard var tokenizerConfig = try await config.tokenizerConfig else {
        throw LLMError(message: "missing config")
    }
@@ -27,6 +27,7 @@ let package = Package(
                .product(name: "MLX", package: "mlx-swift"),
                .product(name: "MLXFast", package: "mlx-swift"),
                .product(name: "MLXNN", package: "mlx-swift"),
                .product(name: "MLXOptimizers", package: "mlx-swift"),
                .product(name: "MLXRandom", package: "mlx-swift"),
                .product(name: "Transformers", package: "swift-transformers"),
                .product(name: "AsyncAlgorithms", package: "swift-async-algorithms"),
@@ -43,6 +44,7 @@ let package = Package(
                .product(name: "MLX", package: "mlx-swift"),
                .product(name: "MLXFast", package: "mlx-swift"),
                .product(name: "MLXNN", package: "mlx-swift"),
                .product(name: "MLXOptimizers", package: "mlx-swift"),
                .product(name: "MLXRandom", package: "mlx-swift"),
                .product(name: "Transformers", package: "swift-transformers"),
                .product(name: "AsyncAlgorithms", package: "swift-async-algorithms"),
15
Tools/llm-tool/Arguments.swift
Normal file
@@ -0,0 +1,15 @@
// Copyright © 2024 Apple Inc.

import ArgumentParser
import Foundation

/// Extension to allow URL command line arguments.
extension URL: ExpressibleByArgument {
    public init?(argument: String) {
        if argument.contains("://") {
            self.init(string: argument)
        } else {
            self.init(filePath: argument)
        }
    }
}
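With this extension, an argument containing `://` (for example an `https://` address) parses as a remote URL, while anything else (for example `/tmp/model`) parses as a file URL, so URL-typed options accept both forms.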
@@ -11,18 +11,26 @@ import Tokenizers
struct LLMTool: AsyncParsableCommand {
    static var configuration = CommandConfiguration(
        abstract: "Command line tool for generating text and manipulating LLMs",
        subcommands: [EvaluateCommand.self],
        subcommands: [EvaluateCommand.self, LoRACommand.self],
        defaultSubcommand: EvaluateCommand.self)
}

/// Command line arguments for loading a model.
struct ModelArguments: ParsableArguments {

    @Option(name: .long, help: "Name of the huggingface model")
    @Option(name: .long, help: "Name of the huggingface model or absolute path to directory")
    var model: String = "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx"

    func load() async throws -> (LLMModel, Tokenizer, ModelConfiguration) {
        let modelConfiguration = ModelConfiguration.configuration(id: model)
        let modelConfiguration: ModelConfiguration

        if self.model.hasPrefix("/") {
            // path
            modelConfiguration = ModelConfiguration(directory: URL(filePath: self.model))
        } else {
            // identifier
            modelConfiguration = ModelConfiguration.configuration(id: model)
        }
        let (model, tokenizer) = try await LLM.load(configuration: modelConfiguration)
        return (model, tokenizer, modelConfiguration)
    }
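So the same `--model` flag now accepts either form; for example (the local path shown is illustrative):

```
# Hugging Face identifier (downloaded on demand)
./mlx-run llm-tool --model mlx-community/Mistral-7B-v0.1-hf-4bit-mlx

# Absolute path to a local model directory
./mlx-run llm-tool --model /tmp/mistral-lora
```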
@@ -31,7 +39,11 @@ struct ModelArguments: ParsableArguments {

/// Command line arguments for controlling generation of text.
struct GenerateArguments: ParsableArguments {

    @Option(name: .shortAndLong, help: "The message to be processed by the model")
    @Option(
        name: .shortAndLong,
        help:
            "The message to be processed by the model. Use @path,@path to load from files, e.g. @/tmp/prompt.txt"
    )
    var prompt = "compare python and swift"

    @Option(name: .shortAndLong, help: "Maximum number of tokens to generate")

@@ -52,18 +64,32 @@ struct GenerateArguments: ParsableArguments {
    @Option(name: .long, help: "The PRNG seed")
    var seed: UInt64 = 0

    @Flag(name: .shortAndLong, help: "If true only print the generated output")
    var quiet = false

    var generateParameters: GenerateParameters {
        GenerateParameters(
            temperature: temperature, topP: topP, repetitionPenalty: repetitionPenalty,
            repetitionContextSize: repetitionContextSize)
    }

    func tokenizePrompt(configuration: ModelConfiguration, tokenizer: Tokenizer) -> (String, [Int])
    {
    func resolvePrompt() throws -> String {
        if prompt.hasPrefix("@") {
            let names = prompt.split(separator: ",").map { String($0.dropFirst()) }
            return try names.map { try String(contentsOfFile: $0) }.joined(separator: "\n")
        } else {
            return prompt
        }
    }

    func tokenizePrompt(configuration: ModelConfiguration, tokenizer: Tokenizer) throws -> (
        String, [Int]
    ) {
        MLXRandom.seed(seed)

        let prompt = configuration.prepare(prompt: self.prompt)
        let promptTokens = tokenizer.encode(text: prompt)
        let prompt = try resolvePrompt()
        let preparedPrompt = configuration.prepare(prompt: prompt)
        let promptTokens = tokenizer.encode(text: preparedPrompt)

        return (prompt, promptTokens)
    }
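`resolvePrompt` is what backs the new `@path` syntax: each comma-separated, `@`-prefixed entry is read from disk and the contents are joined with newlines. For example (file paths are illustrative):

```
# Single file
./mlx-run llm-tool eval --prompt @/tmp/prompt.txt

# Multiple files, concatenated with newlines in between
./mlx-run llm-tool eval --prompt @/tmp/table.txt,@/tmp/question.txt
```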
@@ -187,21 +213,27 @@ struct EvaluateCommand: AsyncParsableCommand {
    mutating func run() async throws {
        let (model, tokenizer, modelConfiguration) = try await memory.start(args.load)

        print("Model loaded -> \(modelConfiguration.id)")
        if !generate.quiet {
            print("Model loaded -> \(modelConfiguration.id)")
        }

        let (prompt, promptTokens) = generate.tokenizePrompt(
        let (prompt, promptTokens) = try generate.tokenizePrompt(
            configuration: modelConfiguration, tokenizer: tokenizer)

        print("Starting generation ...")
        print(prompt, terminator: "")
        if !generate.quiet {
            print("Starting generation ...")
            print(prompt, terminator: "")
        }

        let result = await generate.generate(
            promptTokens: promptTokens, model: model, tokenizer: tokenizer)

        print()
        print("------")
        print(result.summary())

        memory.reportMemoryStatistics()
        if !generate.quiet {
            print("------")
            print(result.summary())

            memory.reportMemoryStatistics()
        }
    }
}
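With `--quiet` all of the decorator text above is suppressed and only the generated text is printed, which makes the tool easy to compose in scripts; a sketch (the redirection target is just an example):

```
./mlx-run llm-tool eval --quiet --prompt @/tmp/prompt.txt > /tmp/generated.txt
```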
281
Tools/llm-tool/LoraCommands.swift
Normal file
@@ -0,0 +1,281 @@
// Copyright © 2024 Apple Inc.

import ArgumentParser
import Foundation
import Hub
import LLM
import MLX
import MLXNN
import MLXOptimizers
import MLXRandom
import Tokenizers

struct LoRACommand: AsyncParsableCommand {

    static var configuration = CommandConfiguration(
        commandName: "lora",
        abstract: "LoRA commands",
        subcommands: [
            LoRATrainCommand.self, LoRAFuseCommand.self, LoRATestCommand.self, LoRAEvalCommand.self,
        ]
    )
}

/// Common arguments for loading a LoRA model with adapter weights
struct LoRAModelArguments: ParsableArguments {

    @OptionGroup var args: ModelArguments

    @Option(name: .long, help: "Save/load path for the trained adapter weights")
    public var adapter: URL = URL(filePath: "adapters.safetensors")

    @Option(name: .long, help: "Number of layers to fine-tune")
    public var loraLayers = 16

    /// Load the model and apply the LoRA adapters.
    ///
    /// This does not load the adapter weights as they may not exist yet.
    func load() async throws -> (LLMModel, Tokenizer, ModelConfiguration) {
        let (model, tokenizer, modelConfiguration) = try await args.load()

        // convert some of the Linear layers to LoRALinear
        LoRATrain.convert(model: model, layers: loraLayers(model: model))

        return (model, tokenizer, modelConfiguration)
    }

    func loraLayers(model: Module) -> LoRALinearLayers {
        guard let layerProvider = model as? LoRAModel else {
            // the layerProvider will indicate which Linear layers need to be replaced
            fatalError(
                "Model \(type(of: model)) (\(args.model)) must implement the LoRAModel protocol"
            )
        }

        return Array(layerProvider.loraLinearLayers().suffix(loraLayers))
    }

    func describe(model: Module) {
        let totalParameterCount = model.parameters()
            .flattenedValues().map { $0.size }.reduce(0, +)
        let trainableParameterCount = model.trainableParameters()
            .flattenedValues().map { $0.size }.reduce(0, +)

        print("Model: \(args.model)")
        print("Total parameters: \((totalParameterCount / 1_000_000).formatted())M")
        print(
            "Trainable parameters: \((Float(trainableParameterCount) / 1_000_000).formatted(.number.precision(.significantDigits(1 ..< 4))))M"
        )
    }
}

struct LoRATrainCommand: AsyncParsableCommand {

    static var configuration = CommandConfiguration(
        commandName: "train",
        abstract: "LoRA training"
    )

    @OptionGroup var args: LoRAModelArguments
    @OptionGroup var memory: MemoryArguments

    @Flag(help: "Resume training with the given adapter file")
    public var resume = false

    @Option(name: .long, help: "Directory with {train, valid, test}.{jsonl,txt} files")
    public var data: URL = URL(filePath: "data")

    @Option(name: .long, help: "Learning rate for the optimizer")
    public var learningRate: Float = 1e-5

    @Option(name: .long, help: "Number of dataset items to evaluate per iteration (batch)")
    public var batchSize = 4

    @Option(name: .long, help: "Number of iterations to train for")
    public var iterations = 1000

    @Option(name: .long, help: "Number of iterations between loss reporting")
    public var stepsPerReport = 10

    @Option(name: .long, help: "Number of iterations between validations")
    public var stepsPerEval = 100

    @Option(name: .long, help: "Number of validation batches, 0 uses the entire set")
    public var validationBatches = 10

    @Option(name: .long, help: "Number of iterations between checkpointing the adapter weights")
    public var saveEvery = 100

    var parameters: LoRATrain.Parameters {
        var p = LoRATrain.Parameters()
        p.batchSize = self.batchSize
        p.iterations = self.iterations
        p.stepsPerReport = self.stepsPerReport
        p.stepsPerEval = self.stepsPerEval
        p.validationBatches = self.validationBatches
        p.saveEvery = self.saveEvery
        p.adapterURL = args.adapter
        return p
    }

    @MainActor
    mutating func run() async throws {
        let (model, tokenizer, _) = try await args.load()
        args.describe(model: model)

        memory.start()

        if resume {
            print("Loading pretrained adapters from \(args.adapter.path())")
            try LoRATrain.loadLoRAWeights(model: model, url: args.adapter)
        }

        // load the train/validation data
        let train = try loadLoRAData(directory: data, name: "train")
        let valid = try loadLoRAData(directory: data, name: "valid")

        if train.isEmpty {
            fatalError("Training set is empty: \(data.path())")
        }
        if valid.isEmpty {
            fatalError("Validation set is empty: \(data.path())")
        }

        // train
        let optimizer = Adam(learningRate: learningRate)
        try await LoRATrain.train(
            model: model, train: train, validate: valid, optimizer: optimizer, tokenizer: tokenizer,
            parameters: parameters
        ) { progress in
            print(progress)
            return .more
        }
        try LoRATrain.saveLoRAWeights(model: model, url: args.adapter)
    }
}

struct LoRAFuseCommand: AsyncParsableCommand {

    static var configuration = CommandConfiguration(
        commandName: "fuse",
        abstract: "Fuse lora adapter weights back into the original model"
    )

    @OptionGroup var args: LoRAModelArguments

    @Flag(name: .long, help: "De-quantize QuantizedLinear layers back into Linear")
    var deQuantize = false

    @Option(name: .long, help: "Hub ID (mlx-community/mistral-lora) or path (/tmp/mistral-lora)")
    var output: String

    @MainActor
    mutating func run() async throws {
        let outputURL: URL
        if output.hasPrefix("/") {
            outputURL = URL(filePath: output)
        } else {
            let repo = HubApi.Repo(id: output)
            outputURL = HubApi().localRepoLocation(repo)
        }

        let (model, _, modelConfiguration) = try await args.load()

        // load the prepared weights
        try LoRATrain.loadLoRAWeights(model: model, url: args.adapter)

        // fuse them back into Linear/QuantizedLinear
        LoRATrain.fuse(model: model, layers: args.loraLayers(model: model), deQuantize: deQuantize)

        // make the new directory and copy files from source model
        try FileManager.default.createDirectory(at: outputURL, withIntermediateDirectories: true)
        let inputURL = modelConfiguration.modelDirectory()
        let enumerator = FileManager.default.enumerator(
            at: inputURL, includingPropertiesForKeys: nil)!
        for case let url as URL in enumerator {
            // copy everything except the model weights -- we will write out the fused one below
            if url.pathExtension == "safetensors" {
                continue
            }

            try FileManager.default.copyItem(
                at: url, to: outputURL.appending(component: url.lastPathComponent))
        }

        // write them back out
        let weights = Dictionary(uniqueKeysWithValues: model.parameters().flattened())
        try save(arrays: weights, url: outputURL.appending(component: "weights.safetensors"))

        print("Fused weights written to \(outputURL.path())")
        print("Use with:\n\tllm-tool eval --model \(output)")
    }
}

struct LoRATestCommand: AsyncParsableCommand {

    static var configuration = CommandConfiguration(
        commandName: "test",
        abstract: "LoRA testing"
    )

    @OptionGroup var args: LoRAModelArguments
    @OptionGroup var memory: MemoryArguments

    @Option(name: .long, help: "Directory with {train, valid, test}.{jsonl,txt} files")
    public var data: URL = URL(filePath: "data")

    @Option(name: .long, help: "Minibatch size")
    public var batchSize = 4

    @MainActor
    mutating func run() async throws {
        let (model, tokenizer, _) = try await args.load()
        args.describe(model: model)
        try LoRATrain.loadLoRAWeights(model: model, url: args.adapter)

        memory.start()

        let test = try loadLoRAData(directory: data, name: "test")
        let loss = LoRATrain.evaluate(
            model: model, dataset: test, tokenizer: tokenizer, batchSize: batchSize, batchCount: 0)

        print("Test loss \(loss.formatted()), ppl \(exp(loss).formatted())")
    }
}

struct LoRAEvalCommand: AsyncParsableCommand {

    static var configuration = CommandConfiguration(
        commandName: "eval",
        abstract: "LoRA evaluation"
    )

    @OptionGroup var args: LoRAModelArguments
    @OptionGroup var memory: MemoryArguments
    @OptionGroup var generate: GenerateArguments

    @MainActor
    mutating func run() async throws {
        let (model, tokenizer, modelConfiguration) = try await args.load()
        args.describe(model: model)
        try LoRATrain.loadLoRAWeights(model: model, url: args.adapter)

        memory.start()

        let (prompt, promptTokens) = try generate.tokenizePrompt(
            configuration: modelConfiguration, tokenizer: tokenizer)

        if !generate.quiet {
            print("Starting generation ...")
            print(prompt, terminator: "")
        }

        // generate and print the result
        let _ = await generate.generate(
            promptTokens: promptTokens, model: model, tokenizer: tokenizer)
        print()
    }
}
@@ -8,7 +8,7 @@ See various READMEs:

Build the `llm-tool` scheme in Xcode.

### Running (Xcode)
### Running: Xcode

To run this in Xcode simply press cmd-opt-r to set the scheme arguments. For example:

@@ -30,7 +30,7 @@ The model should be a path in the Hugging Face repository, e.g.:

See [LLM](../../Libraries/LLM/README.md) for more info.

### Running (Command Line)
### Running: Command Line

Use the `mlx-run` script to run the command line tools:

@@ -60,3 +60,184 @@ Building in Release / optimizations will remove a lot of tail calls in the C++
layer. These lead to the stack overflows.

See discussion here: https://github.com/ml-explore/mlx-swift-examples/issues/3
## LoRA

`llm-tool` provides an example LoRA driver based on:

- https://github.com/ml-explore/mlx-examples/blob/main/lora/README.md

This is an example of using MLX to fine-tune an LLM with low rank adaptation
(LoRA) for a target task.[^lora] The example also supports quantized LoRA
(QLoRA).[^qlora] The example works with Llama and Mistral style models
available on Hugging Face.

In this example we'll use the WikiSQL[^wikisql] dataset to train the LLM to
generate SQL queries from natural language. However, the example is intended to
be general, should you wish to use a custom dataset.

> Note: Some of the prompts have newlines in them, which is difficult to enter when running in Xcode.

Running `llm-tool lora` will produce help:

```
SUBCOMMANDS:
  train                   LoRA training
  fuse                    Fuse lora adapter weights back into the original model
  test                    LoRA testing
  eval                    LoRA evaluation
```
### Training

The first step will be training the LoRA adapter. Example training data
is available in `$SRCROOT/Data/lora`. You can use your
own data in either `jsonl` or `txt` format with one entry per line.

We need to specify a number of parameters:

- `--model` -- which model to use. This can be quantized [^qlora] or not [^lora]
- `--data` -- directory with the test, train and valid files. These can be either `jsonl` or `txt` files
- `--adapter` -- path to a safetensors file to write the fine tuned parameters into

Additionally the performance of the fine tuning can be controlled with:

- `--batch-size` -- size of the minibatches to run in the training loop, e.g. how many prompts to process per iteration
- `--lora-layers` -- the number of layers in the Attention section of the model to adapt and train
- `--iterations` -- the number of iterations to train for

If desired, the amount of memory used can be adjusted with:

- `--cache-size` -- the number shown below limits the cache size to 1024M (see the sketch after this list)
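Under the hood a cache limit like this corresponds to MLX's GPU buffer cache setting; a minimal Swift sketch of applying the same limit programmatically (the megabyte interpretation of the flag is an assumption here):

```
import MLX

// Roughly what `--cache-size 1024` requests: cap the Metal buffer
// cache at 1024M so freed buffers beyond that are returned to the system.
MLX.GPU.set(cacheLimit: 1024 * 1024 * 1024)
```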
Here is an example run using adapters on the last 4 layers of the model:

```
./mlx-run llm-tool lora train \
    --model mlx-community/Mistral-7B-v0.1-hf-4bit-mlx \
    --data Data/lora \
    --adapter /tmp/lora-layers-4.safetensors \
    --batch-size 1 --lora-layers 4 \
    --cache-size 1024
```

giving output like this:

```
Model: mlx-community/Mistral-7B-v0.1-hf-4bit-mlx
Total parameters: 1,242M
Trainable parameters: 0.426M
Iteration 1: validation loss 2.443872, validation time 3.330629s
Iteration 10: training loss 2.356848, iterations/sec 2.640604, Tokens/sec 260.363581
Iteration 20: training loss 2.063395, iterations/sec 2.294999, Tokens/sec 232.483365
Iteration 30: training loss 1.63846, iterations/sec 2.279401, Tokens/sec 225.204788
Iteration 40: training loss 1.66366, iterations/sec 2.493669, Tokens/sec 218.196057
Iteration 50: training loss 1.470927, iterations/sec 2.301153, Tokens/sec 231.72614
Iteration 60: training loss 1.396581, iterations/sec 2.400012, Tokens/sec 230.401195
Iteration 70: training loss 1.587023, iterations/sec 2.422193, Tokens/sec 218.966258
Iteration 80: training loss 1.376895, iterations/sec 2.111973, Tokens/sec 216.477187
Iteration 90: training loss 1.245127, iterations/sec 2.383802, Tokens/sec 214.065436
Iteration 100: training loss 1.344523, iterations/sec 2.424746, Tokens/sec 223.076649
Iteration 100: validation loss 1.400582, validation time 3.489797s
Iteration 100: saved weights to /tmp/lora.safetensors
...
Iteration 910: training loss 1.181306, iterations/sec 2.355085, Tokens/sec 212.428628
Iteration 920: training loss 1.042286, iterations/sec 2.374377, Tokens/sec 222.479127
Iteration 930: training loss 0.920768, iterations/sec 2.475088, Tokens/sec 220.035347
Iteration 940: training loss 1.140762, iterations/sec 2.119886, Tokens/sec 227.039828
Iteration 950: training loss 1.068073, iterations/sec 2.523047, Tokens/sec 218.495903
Iteration 960: training loss 1.106662, iterations/sec 2.339293, Tokens/sec 221.063186
Iteration 970: training loss 0.833658, iterations/sec 2.474683, Tokens/sec 213.56517
Iteration 980: training loss 0.844026, iterations/sec 2.441064, Tokens/sec 210.663791
Iteration 990: training loss 0.903735, iterations/sec 2.253876, Tokens/sec 218.175162
Iteration 1000: training loss 0.872615, iterations/sec 2.343899, Tokens/sec 219.62336
Iteration 1000: validation loss 0.714194, validation time 3.470462s
Iteration 1000: saved weights to /tmp/lora-layers-4.safetensors
```
### Testing

You can test the LoRA adapted model against the `test` dataset using this command:

```
./mlx-run llm-tool lora test \
    --model mlx-community/Mistral-7B-v0.1-hf-4bit-mlx \
    --data Data/lora \
    --adapter /tmp/lora-layers-4.safetensors \
    --batch-size 1 --lora-layers 4 \
    --cache-size 1024
```

This will run all the items (100 in the example data we are using) in the test set and compute the loss:

```
Model: mlx-community/Mistral-7B-v0.1-hf-4bit-mlx
Total parameters: 1,242M
Trainable parameters: 0.426M
Test loss 1.327623, ppl 3.772065
```
### Evaluate

Next you can evaluate your own prompts with the fine tuned LoRA adapters. It is important to
follow the prompt example from the training data to match the format:

```
{"text": "table: 1-10015132-1\ncolumns: Player, No., Nationality, Position, Years in Toronto, School/Club Team\nQ: What school did player number 6 come from?\nA: SELECT School/Club Team FROM 1-10015132-1 WHERE No. = '6'"}
```

Given that format you might issue a command like this:

```
./mlx-run llm-tool lora eval \
    --model mlx-community/Mistral-7B-v0.1-hf-4bit-mlx \
    --adapter /tmp/lora-layers-4.safetensors \
    --lora-layers 4 \
    --prompt "table: 1-10015132-16
columns: Player, No., Nationality, Position, Years in Toronto, School/Club Team
Q: What is terrence ross' nationality
A: "
```

> Note: the prompt has newlines in it to match the format of the fine tuned prompts -- this may be easier to do with the command line than Xcode.

You might be treated to a response like this:

```
Model: mlx-community/Mistral-7B-v0.1-hf-4bit-mlx
Total parameters: 1,242M
Trainable parameters: 0.426M
Starting generation ...
table: 1-10015132-16
columns: Player, No., Nationality, Position, Years in Toronto, School/Club Team
Q: What is terrence ross' nationality
A: SELECT Nationality FROM 1-10015132-16 WHERE Player = 'Terrence Ross' AND No. = 1
```
### Fusing

Once the adapter weights are trained you can produce new weights with the original architecture that
have the adapter weights merged in:

```
./mlx-run llm-tool lora fuse \
    --model mlx-community/Mistral-7B-v0.1-hf-4bit-mlx \
    --adapter /tmp/lora-layers-4.safetensors \
    --output mlx-community/mistral-lora
```

outputs:

```
Total parameters: 1,244M
Trainable parameters: 0.426M
Use with:
	llm-tool eval --model mlx-community/mistral-lora
```

As noted in the output, these new weights can be used with the original model architecture.

[^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
[^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)
[^wikisql]: Refer to the [GitHub repo](https://github.com/salesforce/WikiSQL/tree/master) for more information about WikiSQL.
@@ -12,6 +12,15 @@
    52A776182B94B5EE00AA6E80 /* Qwen2.swift in Sources */ = {isa = PBXBuildFile; fileRef = 52A776172B94B5EE00AA6E80 /* Qwen2.swift */; };
    81695B412BA373D300F260D8 /* MarkdownUI in Frameworks */ = {isa = PBXBuildFile; productRef = 81695B402BA373D300F260D8 /* MarkdownUI */; };
    819BEFF82BAF8B4E0002CCEE /* DeviceStat.swift in Sources */ = {isa = PBXBuildFile; fileRef = 819BEFF62BAF8B4E0002CCEE /* DeviceStat.swift */; };
    C3056BAE2BCD97B700A31D04 /* LoRATrainingExampleApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3056BAD2BCD97B700A31D04 /* LoRATrainingExampleApp.swift */; };
    C3056BB02BCD97B700A31D04 /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3056BAF2BCD97B700A31D04 /* ContentView.swift */; };
    C3056BB22BCD97B800A31D04 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C3056BB12BCD97B800A31D04 /* Assets.xcassets */; };
    C3056BB62BCD97B800A31D04 /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C3056BB52BCD97B800A31D04 /* Preview Assets.xcassets */; };
    C3056BBA2BCD981900A31D04 /* train.jsonl in Resources */ = {isa = PBXBuildFile; fileRef = C3056BA22BCD973400A31D04 /* train.jsonl */; };
    C3056BBB2BCD981900A31D04 /* test.jsonl in Resources */ = {isa = PBXBuildFile; fileRef = C3056BA12BCD973400A31D04 /* test.jsonl */; };
    C3056BBC2BCD981900A31D04 /* valid.jsonl in Resources */ = {isa = PBXBuildFile; fileRef = C3056BA32BCD973400A31D04 /* valid.jsonl */; };
    C3056BBD2BCD984F00A31D04 /* LLM.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C38935C52B869C7A0037B833 /* LLM.framework */; };
    C3056BBE2BCD984F00A31D04 /* LLM.framework in Embed Frameworks */ = {isa = PBXBuildFile; fileRef = C38935C52B869C7A0037B833 /* LLM.framework */; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; };
    C3288D762B6D9313009FF608 /* LinearModelTraining.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3288D752B6D9313009FF608 /* LinearModelTraining.swift */; };
    C3288D7B2B6D9339009FF608 /* ArgumentParser in Frameworks */ = {isa = PBXBuildFile; productRef = C3288D7A2B6D9339009FF608 /* ArgumentParser */; };
    C34E48F52B696F0B00FCB841 /* LLMTool.swift in Sources */ = {isa = PBXBuildFile; fileRef = C34E48F42B696F0B00FCB841 /* LLMTool.swift */; };
@@ -22,6 +31,11 @@
    C34E49292B6A028100FCB841 /* ArgumentParser in Frameworks */ = {isa = PBXBuildFile; productRef = C34E49282B6A028100FCB841 /* ArgumentParser */; };
    C34E492A2B6A028800FCB841 /* MNIST.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C34E490D2B69A92900FCB841 /* MNIST.framework */; };
    C34E492B2B6A028800FCB841 /* MNIST.framework in Embed Frameworks */ = {isa = PBXBuildFile; fileRef = C34E490D2B69A92900FCB841 /* MNIST.framework */; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; };
    C36BEFB02BBCBAC2002D4AFE /* Lora.swift in Sources */ = {isa = PBXBuildFile; fileRef = C36BEFAF2BBCBAC2002D4AFE /* Lora.swift */; };
    C36BEFB22BBDE9D0002D4AFE /* MLXOptimizers in Frameworks */ = {isa = PBXBuildFile; productRef = C36BEFB12BBDE9D0002D4AFE /* MLXOptimizers */; };
    C36BEFB52BBDEAD8002D4AFE /* LoraCommands.swift in Sources */ = {isa = PBXBuildFile; fileRef = C36BEFB32BBDEA69002D4AFE /* LoraCommands.swift */; };
    C36BEFB82BBDED51002D4AFE /* Arguments.swift in Sources */ = {isa = PBXBuildFile; fileRef = C36BEFB62BBDECBC002D4AFE /* Arguments.swift */; };
    C36BEFBB2BBF02CC002D4AFE /* Lora+Data.swift in Sources */ = {isa = PBXBuildFile; fileRef = C36BEFBA2BBF02CC002D4AFE /* Lora+Data.swift */; };
    C382DE8A2B630889000F8F03 /* AsyncAlgorithms in Frameworks */ = {isa = PBXBuildFile; productRef = C382DE892B630889000F8F03 /* AsyncAlgorithms */; };
    C38935C82B869C7A0037B833 /* LLM.h in Headers */ = {isa = PBXBuildFile; fileRef = C38935C72B869C7A0037B833 /* LLM.h */; settings = {ATTRIBUTES = (Public, ); }; };
    C38935CC2B869C870037B833 /* Llama.swift in Sources */ = {isa = PBXBuildFile; fileRef = C34E48EE2B696E6500FCB841 /* Llama.swift */; };
@@ -69,6 +83,13 @@
/* End PBXBuildFile section */

/* Begin PBXContainerItemProxy section */
    C3056BBF2BCD984F00A31D04 /* PBXContainerItemProxy */ = {
        isa = PBXContainerItemProxy;
        containerPortal = C39273682B60697700368D5D /* Project object */;
        proxyType = 1;
        remoteGlobalIDString = C38935C42B869C7A0037B833;
        remoteInfo = LLM;
    };
    C34E492C2B6A028800FCB841 /* PBXContainerItemProxy */ = {
        isa = PBXContainerItemProxy;
        containerPortal = C39273682B60697700368D5D /* Project object */;
@@ -100,6 +121,17 @@
/* End PBXContainerItemProxy section */

/* Begin PBXCopyFilesBuildPhase section */
    C3056BC12BCD984F00A31D04 /* Embed Frameworks */ = {
        isa = PBXCopyFilesBuildPhase;
        buildActionMask = 2147483647;
        dstPath = "";
        dstSubfolderSpec = 10;
        files = (
            C3056BBE2BCD984F00A31D04 /* LLM.framework in Embed Frameworks */,
        );
        name = "Embed Frameworks";
        runOnlyForDeploymentPostprocessing = 0;
    };
    C3288D712B6D9313009FF608 /* CopyFiles */ = {
        isa = PBXCopyFilesBuildPhase;
        buildActionMask = 2147483647;
@@ -187,6 +219,17 @@
    525C1E9C2B9A010F00B5C356 /* Starcoder2.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Starcoder2.swift; sourceTree = "<group>"; };
    52A776172B94B5EE00AA6E80 /* Qwen2.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Qwen2.swift; sourceTree = "<group>"; };
    819BEFF62BAF8B4E0002CCEE /* DeviceStat.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = DeviceStat.swift; sourceTree = "<group>"; };
    C3056BA12BCD973400A31D04 /* test.jsonl */ = {isa = PBXFileReference; lastKnownFileType = text; path = test.jsonl; sourceTree = "<group>"; };
    C3056BA22BCD973400A31D04 /* train.jsonl */ = {isa = PBXFileReference; lastKnownFileType = text; path = train.jsonl; sourceTree = "<group>"; };
    C3056BA32BCD973400A31D04 /* valid.jsonl */ = {isa = PBXFileReference; lastKnownFileType = text; path = valid.jsonl; sourceTree = "<group>"; };
    C3056BA42BCD973400A31D04 /* wikisql.py */ = {isa = PBXFileReference; lastKnownFileType = text.script.python; path = wikisql.py; sourceTree = "<group>"; };
    C3056BAB2BCD97B700A31D04 /* LoRATrainingExample.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = LoRATrainingExample.app; sourceTree = BUILT_PRODUCTS_DIR; };
    C3056BAD2BCD97B700A31D04 /* LoRATrainingExampleApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LoRATrainingExampleApp.swift; sourceTree = "<group>"; };
    C3056BAF2BCD97B700A31D04 /* ContentView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ContentView.swift; sourceTree = "<group>"; };
    C3056BB12BCD97B800A31D04 /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = "<group>"; };
    C3056BB32BCD97B800A31D04 /* LoRATrainingExample.entitlements */ = {isa = PBXFileReference; lastKnownFileType = text.plist.entitlements; path = LoRATrainingExample.entitlements; sourceTree = "<group>"; };
    C3056BB52BCD97B800A31D04 /* Preview Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = "Preview Assets.xcassets"; sourceTree = "<group>"; };
    C3056BC42BCDAB8600A31D04 /* README.md */ = {isa = PBXFileReference; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = "<group>"; };
    C325DE3F2B648CDB00628871 /* README.md */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = "<group>"; };
    C3288D732B6D9313009FF608 /* LinearModelTraining */ = {isa = PBXFileReference; explicitFileType = "compiled.mach-o.executable"; includeInIndex = 0; path = LinearModelTraining; sourceTree = BUILT_PRODUCTS_DIR; };
    C3288D752B6D9313009FF608 /* LinearModelTraining.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LinearModelTraining.swift; sourceTree = "<group>"; };
@@ -202,6 +245,10 @@
    C34E49142B69C1E300FCB841 /* Files.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Files.swift; sourceTree = "<group>"; };
    C34E49212B6A026F00FCB841 /* mnist-tool */ = {isa = PBXFileReference; explicitFileType = "compiled.mach-o.executable"; includeInIndex = 0; path = "mnist-tool"; sourceTree = BUILT_PRODUCTS_DIR; };
    C34E49232B6A026F00FCB841 /* MNISTTool.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MNISTTool.swift; sourceTree = "<group>"; };
    C36BEFAF2BBCBAC2002D4AFE /* Lora.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Lora.swift; sourceTree = "<group>"; };
    C36BEFB32BBDEA69002D4AFE /* LoraCommands.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LoraCommands.swift; sourceTree = "<group>"; };
    C36BEFB62BBDECBC002D4AFE /* Arguments.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Arguments.swift; sourceTree = "<group>"; };
    C36BEFBA2BBF02CC002D4AFE /* Lora+Data.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = "Lora+Data.swift"; sourceTree = "<group>"; };
    C38935C52B869C7A0037B833 /* LLM.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = LLM.framework; sourceTree = BUILT_PRODUCTS_DIR; };
    C38935C72B869C7A0037B833 /* LLM.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = LLM.h; sourceTree = "<group>"; };
    C38935DE2B869DD00037B833 /* Phi.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Phi.swift; sourceTree = "<group>"; };
@@ -237,6 +284,14 @@
/* End PBXFileReference section */

/* Begin PBXFrameworksBuildPhase section */
    C3056BA82BCD97B700A31D04 /* Frameworks */ = {
        isa = PBXFrameworksBuildPhase;
        buildActionMask = 2147483647;
        files = (
            C3056BBD2BCD984F00A31D04 /* LLM.framework in Frameworks */,
        );
        runOnlyForDeploymentPostprocessing = 0;
    };
    C3288D702B6D9313009FF608 /* Frameworks */ = {
        isa = PBXFrameworksBuildPhase;
        buildActionMask = 2147483647;
@@ -273,6 +328,7 @@
        isa = PBXFrameworksBuildPhase;
        buildActionMask = 2147483647;
        files = (
            C36BEFB22BBDE9D0002D4AFE /* MLXOptimizers in Frameworks */,
            C38935D22B869CC40037B833 /* MLXNN in Frameworks */,
            C38935D42B869CC40037B833 /* MLXRandom in Frameworks */,
            C38935D62B869CC40037B833 /* Transformers in Frameworks */,
@@ -328,6 +384,46 @@
        path = ViewModels;
        sourceTree = "<group>";
    };
    C3056BA52BCD973400A31D04 /* lora */ = {
        isa = PBXGroup;
        children = (
            C3056BA12BCD973400A31D04 /* test.jsonl */,
            C3056BA22BCD973400A31D04 /* train.jsonl */,
            C3056BA32BCD973400A31D04 /* valid.jsonl */,
            C3056BA42BCD973400A31D04 /* wikisql.py */,
        );
        path = lora;
        sourceTree = "<group>";
    };
    C3056BA62BCD973400A31D04 /* Data */ = {
        isa = PBXGroup;
        children = (
            C3056BA52BCD973400A31D04 /* lora */,
        );
        path = Data;
        sourceTree = "<group>";
    };
    C3056BAC2BCD97B700A31D04 /* LoRATrainingExample */ = {
        isa = PBXGroup;
        children = (
            C3056BAD2BCD97B700A31D04 /* LoRATrainingExampleApp.swift */,
            C3056BAF2BCD97B700A31D04 /* ContentView.swift */,
            C3056BB12BCD97B800A31D04 /* Assets.xcassets */,
            C3056BB32BCD97B800A31D04 /* LoRATrainingExample.entitlements */,
            C3056BB42BCD97B800A31D04 /* Preview Content */,
            C3056BC42BCDAB8600A31D04 /* README.md */,
        );
        path = LoRATrainingExample;
        sourceTree = "<group>";
    };
    C3056BB42BCD97B800A31D04 /* Preview Content */ = {
        isa = PBXGroup;
        children = (
            C3056BB52BCD97B800A31D04 /* Preview Assets.xcassets */,
        );
        path = "Preview Content";
        sourceTree = "<group>";
    };
    C3288D742B6D9313009FF608 /* LinearModelTraining */ = {
        isa = PBXGroup;
        children = (
@@ -340,8 +436,10 @@
    C34E48F32B696F0B00FCB841 /* llm-tool */ = {
        isa = PBXGroup;
        children = (
            C34E48F42B696F0B00FCB841 /* LLMTool.swift */,
            C34E48F92B69930300FCB841 /* README.md */,
            C34E48F42B696F0B00FCB841 /* LLMTool.swift */,
            C36BEFB32BBDEA69002D4AFE /* LoraCommands.swift */,
            C36BEFB62BBDECBC002D4AFE /* Arguments.swift */,
        );
        path = "llm-tool";
        sourceTree = "<group>";
@@ -370,6 +468,8 @@
    C38935C62B869C7A0037B833 /* LLM */ = {
        isa = PBXGroup;
        children = (
            C36BEFAF2BBCBAC2002D4AFE /* Lora.swift */,
            C36BEFBA2BBF02CC002D4AFE /* Lora+Data.swift */,
            525C1E9C2B9A010F00B5C356 /* Starcoder2.swift */,
            C34E48EF2B696E6500FCB841 /* Configuration.swift */,
            C3A8B3AB2B9283150002EFB8 /* Models.swift */,
@@ -393,6 +493,7 @@
        children = (
            C325DE3F2B648CDB00628871 /* README.md */,
            F8D7023A2BB4E223003D7CF5 /* Package.swift */,
            C3056BA62BCD973400A31D04 /* Data */,
            C39273822B606A9200368D5D /* Libraries */,
            C3A8B3AD2B9294E30002EFB8 /* Applications */,
            C39273812B606A7400368D5D /* Tools */,
@@ -412,6 +513,7 @@
            C38935C52B869C7A0037B833 /* LLM.framework */,
            C3A8B3B22B9295090002EFB8 /* MNISTTrainer.app */,
            C3A8B3DC2B92A29E0002EFB8 /* LLMEval.app */,
            C3056BAB2BCD97B700A31D04 /* LoRATrainingExample.app */,
        );
        name = Products;
        sourceTree = "<group>";
@@ -454,6 +556,7 @@
    C3A8B3AD2B9294E30002EFB8 /* Applications */ = {
        isa = PBXGroup;
        children = (
            C3056BAC2BCD97B700A31D04 /* LoRATrainingExample */,
            C3A8B3EB2B92A2A90002EFB8 /* LLMEval */,
            C3A8B3C12B92951E0002EFB8 /* MNISTTrainer */,
        );
@@ -527,6 +630,27 @@
/* End PBXHeadersBuildPhase section */

/* Begin PBXNativeTarget section */
    C3056BAA2BCD97B700A31D04 /* LoRATrainingExample */ = {
        isa = PBXNativeTarget;
        buildConfigurationList = C3056BB72BCD97B800A31D04 /* Build configuration list for PBXNativeTarget "LoRATrainingExample" */;
        buildPhases = (
            C3056BA72BCD97B700A31D04 /* Sources */,
            C3056BA82BCD97B700A31D04 /* Frameworks */,
            C3056BA92BCD97B700A31D04 /* Resources */,
            C3056BC12BCD984F00A31D04 /* Embed Frameworks */,
        );
        buildRules = (
        );
        dependencies = (
            C3056BC02BCD984F00A31D04 /* PBXTargetDependency */,
        );
        name = LoRATrainingExample;
        packageProductDependencies = (
        );
        productName = LoRATrainingExample;
        productReference = C3056BAB2BCD97B700A31D04 /* LoRATrainingExample.app */;
        productType = "com.apple.product-type.application";
    };
    C3288D722B6D9313009FF608 /* LinearModelTraining */ = {
        isa = PBXNativeTarget;
        buildConfigurationList = C3288D792B6D9313009FF608 /* Build configuration list for PBXNativeTarget "LinearModelTraining" */;
@@ -617,6 +741,7 @@
            C38935D32B869CC40037B833 /* MLXRandom */,
            C38935D52B869CC40037B833 /* Transformers */,
            C38935DC2B869CEC0037B833 /* AsyncAlgorithms */,
            C36BEFB12BBDE9D0002D4AFE /* MLXOptimizers */,
        );
        productName = LLM;
        productReference = C38935C52B869C7A0037B833 /* LLM.framework */;
@@ -716,9 +841,12 @@
        isa = PBXProject;
        attributes = {
            BuildIndependentTargetsInParallel = 1;
            LastSwiftUpdateCheck = 1520;
            LastSwiftUpdateCheck = 1530;
            LastUpgradeCheck = 1500;
            TargetAttributes = {
                C3056BAA2BCD97B700A31D04 = {
                    CreatedOnToolsVersion = 15.3;
                };
                C3288D722B6D9313009FF608 = {
                    CreatedOnToolsVersion = 15.0.1;
                };
@@ -775,11 +903,24 @@
            C3288D722B6D9313009FF608 /* LinearModelTraining */,
            C3A8B3B12B9295090002EFB8 /* MNISTTrainer */,
            C3A8B3DB2B92A29D0002EFB8 /* LLMEval */,
            C3056BAA2BCD97B700A31D04 /* LoRATrainingExample */,
        );
    };
/* End PBXProject section */

/* Begin PBXResourcesBuildPhase section */
    C3056BA92BCD97B700A31D04 /* Resources */ = {
        isa = PBXResourcesBuildPhase;
        buildActionMask = 2147483647;
        files = (
            C3056BBA2BCD981900A31D04 /* train.jsonl in Resources */,
            C3056BBB2BCD981900A31D04 /* test.jsonl in Resources */,
            C3056BBC2BCD981900A31D04 /* valid.jsonl in Resources */,
            C3056BB62BCD97B800A31D04 /* Preview Assets.xcassets in Resources */,
            C3056BB22BCD97B800A31D04 /* Assets.xcassets in Resources */,
        );
        runOnlyForDeploymentPostprocessing = 0;
    };
    C34E490B2B69A92900FCB841 /* Resources */ = {
        isa = PBXResourcesBuildPhase;
        buildActionMask = 2147483647;
@@ -815,6 +956,15 @@
/* End PBXResourcesBuildPhase section */

/* Begin PBXSourcesBuildPhase section */
    C3056BA72BCD97B700A31D04 /* Sources */ = {
        isa = PBXSourcesBuildPhase;
        buildActionMask = 2147483647;
        files = (
            C3056BB02BCD97B700A31D04 /* ContentView.swift in Sources */,
            C3056BAE2BCD97B700A31D04 /* LoRATrainingExampleApp.swift in Sources */,
        );
        runOnlyForDeploymentPostprocessing = 0;
    };
    C3288D6F2B6D9313009FF608 /* Sources */ = {
        isa = PBXSourcesBuildPhase;
        buildActionMask = 2147483647;
@@ -850,6 +1000,8 @@
            C38935E32B86C0FE0037B833 /* Gemma.swift in Sources */,
            C38935CD2B869C870037B833 /* Configuration.swift in Sources */,
            525C1E9D2B9A011000B5C356 /* Starcoder2.swift in Sources */,
            C36BEFBB2BBF02CC002D4AFE /* Lora+Data.swift in Sources */,
            C36BEFB02BBCBAC2002D4AFE /* Lora.swift in Sources */,
            C38935DF2B869DD00037B833 /* Phi.swift in Sources */,
            C38935CE2B869C870037B833 /* Load.swift in Sources */,
            C3E786AD2B8D4AF50004D037 /* Tokenizer.swift in Sources */,
@@ -872,7 +1024,9 @@
        isa = PBXSourcesBuildPhase;
        buildActionMask = 2147483647;
        files = (
            C36BEFB82BBDED51002D4AFE /* Arguments.swift in Sources */,
            C34E48F52B696F0B00FCB841 /* LLMTool.swift in Sources */,
            C36BEFB52BBDEAD8002D4AFE /* LoraCommands.swift in Sources */,
        );
        runOnlyForDeploymentPostprocessing = 0;
    };
@@ -899,6 +1053,11 @@
/* End PBXSourcesBuildPhase section */

/* Begin PBXTargetDependency section */
    C3056BC02BCD984F00A31D04 /* PBXTargetDependency */ = {
        isa = PBXTargetDependency;
        target = C38935C42B869C7A0037B833 /* LLM */;
        targetProxy = C3056BBF2BCD984F00A31D04 /* PBXContainerItemProxy */;
    };
    C34E492D2B6A028800FCB841 /* PBXTargetDependency */ = {
        isa = PBXTargetDependency;
        target = C34E490C2B69A92900FCB841 /* MNIST */;
@@ -922,6 +1081,183 @@
/* End PBXTargetDependency section */

/* Begin XCBuildConfiguration section */
    C3056BB82BCD97B800A31D04 /* Debug */ = {
        isa = XCBuildConfiguration;
        buildSettings = {
            ALWAYS_SEARCH_USER_PATHS = NO;
            ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
            ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
            ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;
            CLANG_ANALYZER_NONNULL = YES;
            CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
            CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
            CLANG_ENABLE_MODULES = YES;
            CLANG_ENABLE_OBJC_ARC = YES;
            CLANG_ENABLE_OBJC_WEAK = YES;
            CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
            CLANG_WARN_BOOL_CONVERSION = YES;
            CLANG_WARN_COMMA = YES;
            CLANG_WARN_CONSTANT_CONVERSION = YES;
            CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
            CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
            CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
            CLANG_WARN_EMPTY_BODY = YES;
            CLANG_WARN_ENUM_CONVERSION = YES;
            CLANG_WARN_INFINITE_RECURSION = YES;
            CLANG_WARN_INT_CONVERSION = YES;
            CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
            CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
            CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
            CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
            CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
            CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
            CLANG_WARN_STRICT_PROTOTYPES = YES;
            CLANG_WARN_SUSPICIOUS_MOVE = YES;
            CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
            CLANG_WARN_UNREACHABLE_CODE = YES;
            CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
            CODE_SIGN_ENTITLEMENTS = Applications/LoRATrainingExample/LoRATrainingExample.entitlements;
            CODE_SIGN_STYLE = Automatic;
            COPY_PHASE_STRIP = NO;
            CURRENT_PROJECT_VERSION = 1;
            DEBUG_INFORMATION_FORMAT = dwarf;
            DEVELOPMENT_ASSET_PATHS = "\"Applications/LoRATrainingExample/Preview Content\"";
            DEVELOPMENT_TEAM = "";
            ENABLE_PREVIEWS = YES;
            ENABLE_STRICT_OBJC_MSGSEND = YES;
            ENABLE_TESTABILITY = YES;
            ENABLE_USER_SCRIPT_SANDBOXING = YES;
            GCC_C_LANGUAGE_STANDARD = gnu17;
            GCC_DYNAMIC_NO_PIC = NO;
            GCC_NO_COMMON_BLOCKS = YES;
            GCC_OPTIMIZATION_LEVEL = 0;
            GCC_PREPROCESSOR_DEFINITIONS = (
                "DEBUG=1",
                "$(inherited)",
            );
            GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
            GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
            GCC_WARN_UNDECLARED_SELECTOR = YES;
            GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
            GCC_WARN_UNUSED_FUNCTION = YES;
            GCC_WARN_UNUSED_VARIABLE = YES;
            GENERATE_INFOPLIST_FILE = YES;
            "INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphoneos*]" = YES;
            "INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphonesimulator*]" = YES;
            "INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphoneos*]" = YES;
            "INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphonesimulator*]" = YES;
            "INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphoneos*]" = YES;
            "INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphonesimulator*]" = YES;
            "INFOPLIST_KEY_UIStatusBarStyle[sdk=iphoneos*]" = UIStatusBarStyleDefault;
            "INFOPLIST_KEY_UIStatusBarStyle[sdk=iphonesimulator*]" = UIStatusBarStyleDefault;
            INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
            INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
            IPHONEOS_DEPLOYMENT_TARGET = 17.2;
            LD_RUNPATH_SEARCH_PATHS = "@executable_path/Frameworks";
            "LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks";
            LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
            MACOSX_DEPLOYMENT_TARGET = 14.2;
            MARKETING_VERSION = 1.0;
            MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE;
            MTL_FAST_MATH = YES;
            PRODUCT_BUNDLE_IDENTIFIER = mlx.LoRATrainingExample;
            PRODUCT_NAME = "$(TARGET_NAME)";
            SDKROOT = auto;
            SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx xros xrsimulator";
            SUPPORTS_MACCATALYST = NO;
            SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)";
            SWIFT_EMIT_LOC_STRINGS = YES;
            SWIFT_OPTIMIZATION_LEVEL = "-Onone";
            SWIFT_VERSION = 5.0;
            TARGETED_DEVICE_FAMILY = "1,2,7";
        };
        name = Debug;
    };
    C3056BB92BCD97B800A31D04 /* Release */ = {
        isa = XCBuildConfiguration;
        buildSettings = {
            ALWAYS_SEARCH_USER_PATHS = NO;
            ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
            ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
            ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;
            CLANG_ANALYZER_NONNULL = YES;
            CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
            CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
            CLANG_ENABLE_MODULES = YES;
            CLANG_ENABLE_OBJC_ARC = YES;
            CLANG_ENABLE_OBJC_WEAK = YES;
            CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
            CLANG_WARN_BOOL_CONVERSION = YES;
            CLANG_WARN_COMMA = YES;
            CLANG_WARN_CONSTANT_CONVERSION = YES;
            CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
            CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
            CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
            CLANG_WARN_EMPTY_BODY = YES;
            CLANG_WARN_ENUM_CONVERSION = YES;
            CLANG_WARN_INFINITE_RECURSION = YES;
            CLANG_WARN_INT_CONVERSION = YES;
            CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
            CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
            CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
            CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
            CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
            CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
            CLANG_WARN_STRICT_PROTOTYPES = YES;
            CLANG_WARN_SUSPICIOUS_MOVE = YES;
            CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
            CLANG_WARN_UNREACHABLE_CODE = YES;
            CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
            CODE_SIGN_ENTITLEMENTS = Applications/LoRATrainingExample/LoRATrainingExample.entitlements;
            CODE_SIGN_STYLE = Automatic;
            COPY_PHASE_STRIP = NO;
            CURRENT_PROJECT_VERSION = 1;
            DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
            DEVELOPMENT_ASSET_PATHS = "\"Applications/LoRATrainingExample/Preview Content\"";
            DEVELOPMENT_TEAM = "";
            ENABLE_NS_ASSERTIONS = NO;
            ENABLE_PREVIEWS = YES;
            ENABLE_STRICT_OBJC_MSGSEND = YES;
            ENABLE_USER_SCRIPT_SANDBOXING = YES;
            GCC_C_LANGUAGE_STANDARD = gnu17;
            GCC_NO_COMMON_BLOCKS = YES;
            GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
            GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
            GCC_WARN_UNDECLARED_SELECTOR = YES;
            GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
            GCC_WARN_UNUSED_FUNCTION = YES;
            GCC_WARN_UNUSED_VARIABLE = YES;
            GENERATE_INFOPLIST_FILE = YES;
            "INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphoneos*]" = YES;
            "INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphonesimulator*]" = YES;
            "INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphoneos*]" = YES;
            "INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphonesimulator*]" = YES;
            "INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphoneos*]" = YES;
            "INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphonesimulator*]" = YES;
            "INFOPLIST_KEY_UIStatusBarStyle[sdk=iphoneos*]" = UIStatusBarStyleDefault;
            "INFOPLIST_KEY_UIStatusBarStyle[sdk=iphonesimulator*]" = UIStatusBarStyleDefault;
            INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
            INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
            IPHONEOS_DEPLOYMENT_TARGET = 17.2;
            LD_RUNPATH_SEARCH_PATHS = "@executable_path/Frameworks";
            "LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks";
            LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
            MACOSX_DEPLOYMENT_TARGET = 14.2;
            MARKETING_VERSION = 1.0;
            MTL_ENABLE_DEBUG_INFO = NO;
            MTL_FAST_MATH = YES;
            PRODUCT_BUNDLE_IDENTIFIER = mlx.LoRATrainingExample;
            PRODUCT_NAME = "$(TARGET_NAME)";
            SDKROOT = auto;
            SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx xros xrsimulator";
            SUPPORTS_MACCATALYST = NO;
            SWIFT_COMPILATION_MODE = wholemodule;
            SWIFT_EMIT_LOC_STRINGS = YES;
            SWIFT_VERSION = 5.0;
            TARGETED_DEVICE_FAMILY = "1,2,7";
        };
        name = Release;
    };
    C3288D772B6D9313009FF608 /* Debug */ = {
        isa = XCBuildConfiguration;
        buildSettings = {
@@ -2140,6 +2476,15 @@
/* End XCBuildConfiguration section */

/* Begin XCConfigurationList section */
    C3056BB72BCD97B800A31D04 /* Build configuration list for PBXNativeTarget "LoRATrainingExample" */ = {
        isa = XCConfigurationList;
        buildConfigurations = (
            C3056BB82BCD97B800A31D04 /* Debug */,
            C3056BB92BCD97B800A31D04 /* Release */,
        );
        defaultConfigurationIsVisible = 0;
        defaultConfigurationName = Release;
    };
    C3288D792B6D9313009FF608 /* Build configuration list for PBXNativeTarget "LinearModelTraining" */ = {
        isa = XCConfigurationList;
        buildConfigurations = (
@@ -2295,6 +2640,11 @@
        package = C392736E2B60699100368D5D /* XCRemoteSwiftPackageReference "swift-argument-parser" */;
        productName = ArgumentParser;
    };
    C36BEFB12BBDE9D0002D4AFE /* MLXOptimizers */ = {
        isa = XCSwiftPackageProductDependency;
        package = C3FBCB1F2B8520B00007E490 /* XCRemoteSwiftPackageReference "mlx-swift" */;
        productName = MLXOptimizers;
    };
    C382DE892B630889000F8F03 /* AsyncAlgorithms */ = {
        isa = XCSwiftPackageProductDependency;
        package = C382DE882B630889000F8F03 /* XCRemoteSwiftPackageReference "swift-async-algorithms" */;

@@ -16,7 +16,7 @@
      "location" : "https://github.com/ml-explore/mlx-swift",
      "state" : {
        "branch" : "main",
        "revision" : "b4d3e4bbbe41e6dc7c46d5ba075049ae7177961b"
        "revision" : "cf2c5d20c8575b375cb0d97a06ae0199527b5f32"
      }
    },
    {