implement LoRA / QLoRA (#46)

* implement LoRA / QLoRA

- example of using MLX to fine-tune an LLM with low-rank adaptation (LoRA) for a target task
- see also https://arxiv.org/abs/2106.09685
- based on https://github.com/ml-explore/mlx-examples/tree/main/lora

* add some command line flags I found useful in practice
- --quiet -- don't print decorator text, just the generated text
- --prompt @/tmp/file.txt -- load the prompt from a file

* the user can specify either a local path to a model or a model identifier on Hugging Face (see the example invocation below)

* update mlx-swift reference
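* example invocation (hypothetical -- the tool name comes from Tools/llm-tool and the
  --model flag name is a guess; --quiet and --prompt @file are the flags added here):

    llm-tool --model mlx-community/Mistral-7B-v0.1-hf-4bit-mlx --quiet --prompt @/tmp/file.txt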

Co-authored-by: Ashraful Islam <ashraful.meche@gmail.com>
Co-authored-by: JustinMeans <46542161+JustinMeans@users.noreply.github.com>
Committed by David Koski on 2024-04-22 09:30:12 -07:00 via GitHub.
Parent 7e85eb8b88, commit 6c0b66f90a.
32 changed files with 3483 additions and 64 deletions.


@@ -187,7 +187,7 @@ class LLMEvaluator {
[modelConfiguration] progress in
DispatchQueue.main.sync {
self.modelInfo =
"Downloading \(modelConfiguration.id): \(Int(progress.fractionCompleted * 100))%"
"Downloading \(modelConfiguration.name): \(Int(progress.fractionCompleted * 100))%"
}
}
self.modelInfo =


@@ -0,0 +1,11 @@
{
"colors" : [
{
"idiom" : "universal"
}
],
"info" : {
"author" : "xcode",
"version" : 1
}
}


@@ -0,0 +1,63 @@
{
"images" : [
{
"idiom" : "universal",
"platform" : "ios",
"size" : "1024x1024"
},
{
"idiom" : "mac",
"scale" : "1x",
"size" : "16x16"
},
{
"idiom" : "mac",
"scale" : "2x",
"size" : "16x16"
},
{
"idiom" : "mac",
"scale" : "1x",
"size" : "32x32"
},
{
"idiom" : "mac",
"scale" : "2x",
"size" : "32x32"
},
{
"idiom" : "mac",
"scale" : "1x",
"size" : "128x128"
},
{
"idiom" : "mac",
"scale" : "2x",
"size" : "128x128"
},
{
"idiom" : "mac",
"scale" : "1x",
"size" : "256x256"
},
{
"idiom" : "mac",
"scale" : "2x",
"size" : "256x256"
},
{
"idiom" : "mac",
"scale" : "1x",
"size" : "512x512"
},
{
"idiom" : "mac",
"scale" : "2x",
"size" : "512x512"
}
],
"info" : {
"author" : "xcode",
"version" : 1
}
}


@@ -0,0 +1,6 @@
{
"info" : {
"author" : "xcode",
"version" : 1
}
}


@@ -0,0 +1,284 @@
// Copyright © 2024 Apple Inc.

import LLM
import MLX
import MLXOptimizers
import MLXRandom
import SwiftUI
import Tokenizers

struct ContentView: View {
@State var evaluator = LoRAEvaluator()
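    // default prompt uses the same table/columns/Q/A format as the Data/lora training set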
@State var prompt = """
table: 1-10015132-16
columns: Player, No., Nationality, Position, Years in Toronto, School/Club Team
Q: What is terrence ross' nationality
A:
"""
var body: some View {
VStack {
HStack {
if let progress = evaluator.progress {
if let current = progress.current, let limit = progress.limit {
ProgressView(progress.title, value: current, total: limit)
} else {
ProgressView(progress.title)
}
}
}
.frame(maxWidth: .infinity, minHeight: 25)
VStack {
ScrollView(.vertical) {
ScrollViewReader { sp in
Group {
Text(evaluator.output)
.textSelection(.enabled)
.frame(maxWidth: .infinity)
}
.onChange(of: evaluator.output) { _, _ in
sp.scrollTo("bottom")
}
.padding()
Spacer()
.frame(width: 1, height: 1)
.id("bottom")
}
}
// controls for each of the different states
VStack {
switch evaluator.state {
case .idle:
Button("Start", action: start)
case .training:
EmptyView()
case .evaluate:
Group {
TextEditor(text: $prompt)
.frame(minHeight: 60)
Button("Evaluate", action: evaluate)
}
.disabled(evaluator.progress != nil)
case .failed(let message):
Text("Failed: \(message)")
.bold()
.foregroundStyle(.red)
}
}
.padding()
.frame(maxWidth: .infinity)
}
}
.padding()
}
func start() {
Task {
await evaluator.start()
}
}
func evaluate() {
Task {
await evaluator.evaluate(prompt: prompt)
}
}
}

/// Progress reporting with a title.
struct Progress: Equatable {
    let title: String
    let current: Double?
    let limit: Double?
}

@Observable
class LoRAEvaluator {
enum State {
case idle
case training
case evaluate
case failed(String)
}
enum ModelState {
case idle
case loaded(LLMModel, Tokenizer)
}
var state = State.idle
var progress: Progress?
var output = ""
private let modelConfiguration = ModelConfiguration.mistral7B4bit
private var model: ModelState = .idle
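    // number of layers, counted from the end of the model, to equip with LoRA adapters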
private let loraLayers = 4
private let learningRate: Float = 1e-5
private let parameters = LoRATrain.Parameters(batchSize: 1, iterations: 200)
private let generateParameters = GenerateParameters(temperature: 0.6, topP: 0.9)
private let evaluateShowEvery = 8
private let maxTokens = 200
private func loadModel() async throws -> (LLMModel, Tokenizer) {
switch self.model {
case .idle:
let name = modelConfiguration.name
await MainActor.run {
progress = .init(title: "Loading \(name)", current: 0, limit: 1)
}
let (model, tokenizer) = try await LLM.load(configuration: modelConfiguration) {
progress in
if progress.fractionCompleted < 1.0 {
DispatchQueue.main.sync {
self.progress = .init(
title: "Download \(name)", current: progress.fractionCompleted,
limit: 1.0)
}
}
}
eval(model)
self.model = .loaded(model, tokenizer)
return (model, tokenizer)
case .loaded(let model, let tokenizer):
return (model, tokenizer)
}
}
private func loadLoRAData(name: String) throws -> [String]? {
if let url = Bundle.main.url(forResource: name, withExtension: "jsonl") {
return try LLM.loadLoRAData(url: url)
}
return nil
}
func start() async {
do {
try await startInner()
} catch {
self.state = .failed("Failed: \(error)")
}
}
private func startInner() async throws {
// setup
GPU.set(cacheLimit: 32 * 1024 * 1024)
await MainActor.run {
output = ""
state = .training
}
// load the model
let (model, tokenizer) = try await loadModel()
// apply LoRA adapters and train
guard let layerProvider = model as? LoRAModel else {
state = .failed("Model must implement the LoRALayerProvider protocol")
return
}
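        // wrap the selected linear layers with trainable LoRA adapters; the base
        // weights are not updated during training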
LoRATrain.convert(
model: model, layers: Array(layerProvider.loraLinearLayers().suffix(loraLayers)))
let train = try loadLoRAData(name: "train")
let valid = try loadLoRAData(name: "valid")
guard let train, let valid else {
state = .failed("Failed to load train/validation data")
return
}
let optimizer = Adam(learningRate: learningRate)
try await LoRATrain.train(
model: model, train: train, validate: valid, optimizer: optimizer, tokenizer: tokenizer,
parameters: parameters
) { progress in
await MainActor.run {
switch progress {
case .train(let i, _, _, _):
self.progress = .init(
title: "Train", current: Double(i), limit: Double(parameters.iterations))
case .validation:
output += "\n"
default:
break
}
output += progress.description + "\n"
}
return .more
}
// done training, test
await MainActor.run {
self.progress = .init(title: "Testing", current: nil, limit: nil)
}
guard let test = try loadLoRAData(name: "test") else {
state = .failed("Failed to load test data")
return
}
let loss = LoRATrain.evaluate(
model: model, dataset: test, tokenizer: tokenizer, batchSize: 1, batchCount: 0)
await MainActor.run {
self.progress = nil
self.output += "\n"
self.output += "Test loss \(loss.formatted()), ppl \(exp(loss).formatted())\n"
self.state = .evaluate
}
}
func evaluate(prompt: String) async {
do {
try await evaluateInner(prompt: prompt)
} catch {
self.state = .failed("Failed: \(error)")
}
}
func evaluateInner(prompt: String) async throws {
await MainActor.run {
self.progress = .init(title: "Evaluating", current: nil, limit: nil)
self.output = ""
}
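        // seed the RNG from the clock so repeated evaluations can produce different samples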
MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))
let (model, tokenizer) = try await loadModel()
// prepare the prompt
let preparedPrompt = modelConfiguration.prepare(prompt: prompt)
let promptTokens = tokenizer.encode(text: preparedPrompt)
// evaluate
let result = await LLM.generate(
promptTokens: promptTokens, parameters: generateParameters, model: model,
tokenizer: tokenizer,
didGenerate: { tokens in
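                // stream partial output to the UI every `evaluateShowEvery` tokens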
if tokens.count % evaluateShowEvery == 0 {
let fullOutput = tokenizer.decode(tokens: tokens)
await MainActor.run {
self.output = fullOutput
}
}
return tokens.count >= maxTokens ? .stop : .more
})
await MainActor.run {
self.output = result.output
self.progress = nil
}
}
}


@@ -0,0 +1,14 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>com.apple.developer.kernel.increased-memory-limit</key>
<true/>
<key>com.apple.security.app-sandbox</key>
<true/>
<key>com.apple.security.files.user-selected.read-only</key>
<true/>
<key>com.apple.security.network.client</key>
<true/>
</dict>
</plist>


@@ -0,0 +1,12 @@
// Copyright © 2024 Apple Inc.

import SwiftUI

@main
struct LoRATrainingExampleApp: App {
    var body: some Scene {
        WindowGroup {
            ContentView()
        }
    }
}


@@ -0,0 +1,6 @@
{
"info" : {
"author" : "xcode",
"version" : 1
}
}


@@ -0,0 +1,21 @@
# LoRATrainingExample

Example application that:

- downloads the `mlx-community/Mistral-7B-v0.1-hf-4bit-mlx` model from Hugging Face
- loads the train/valid/test data from `$SRCROOT/Data/lora` (this is copied into the build, but you can imagine how it might be downloaded instead)
- adds LoRA adapters and trains the model
- lets you evaluate a prompt against the adapted model

This roughly equates to the command line example in [Tools/llm-tool](../../Tools/llm-tool), and
you can read more about LoRA there.

This evaluates the LoRA-adapted model rather than a fused model. It doesn't persist
the LoRA weights or the fused model -- the adapters are retrained each time the program is launched.

### Troubleshooting

The `mlx-community/Mistral-7B-v0.1-hf-4bit-mlx` model requires a little over 4G of
memory to load and train -- this may require ~6G of physical RAM.
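
The in-app training boils down to a handful of calls. The sketch below condenses the
flow from `ContentView` above (progress reporting, UI state, and error presentation
omitted) using only API that appears in this commit; it is illustrative, not part of
the app:

```swift
// Condensed sketch of the train-and-test flow in the ContentView code above.
import Foundation
import LLM
import MLX
import MLXOptimizers

func loadData(_ name: String) throws -> [String]? {
    // the jsonl files are copied into the app bundle from Data/lora
    guard let url = Bundle.main.url(forResource: name, withExtension: "jsonl") else { return nil }
    return try LLM.loadLoRAData(url: url)
}

func trainAndTest() async throws {
    let (model, tokenizer) = try await LLM.load(configuration: .mistral7B4bit) { _ in }

    // wrap the last 4 layers with trainable LoRA adapters
    guard let lora = model as? LoRAModel else { return }
    LoRATrain.convert(model: model, layers: Array(lora.loraLinearLayers().suffix(4)))

    guard let train = try loadData("train"), let valid = try loadData("valid") else { return }
    try await LoRATrain.train(
        model: model, train: train, validate: valid,
        optimizer: Adam(learningRate: 1e-5), tokenizer: tokenizer,
        parameters: LoRATrain.Parameters(batchSize: 1, iterations: 200)
    ) { progress in
        print(progress)
        return .more
    }

    // loss and perplexity on the held-out test set
    guard let test = try loadData("test") else { return }
    let loss = LoRATrain.evaluate(
        model: model, dataset: test, tokenizer: tokenizer, batchSize: 1, batchCount: 0)
    print("test loss \(loss), ppl \(exp(loss))")
}
```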

Data/lora/test.jsonl (new file, 100 lines)

@@ -0,0 +1,100 @@
{"text": "table: 1-10015132-16\ncolumns: Player, No., Nationality, Position, Years in Toronto, School/Club Team\nQ: What is terrence ross' nationality\nA: SELECT Nationality FROM 1-10015132-16 WHERE Player = 'Terrence Ross'"}
{"text": "table: 1-10015132-16\ncolumns: Player, No., Nationality, Position, Years in Toronto, School/Club Team\nQ: What clu was in toronto 1995-96\nA: SELECT School/Club Team FROM 1-10015132-16 WHERE Years in Toronto = '1995-96'"}
{"text": "table: 1-10015132-16\ncolumns: Player, No., Nationality, Position, Years in Toronto, School/Club Team\nQ: which club was in toronto 2003-06\nA: SELECT School/Club Team FROM 1-10015132-16 WHERE Years in Toronto = '2003-06'"}
{"text": "table: 1-10015132-16\ncolumns: Player, No., Nationality, Position, Years in Toronto, School/Club Team\nQ: how many schools or teams had jalen rose\nA: SELECT COUNT School/Club Team FROM 1-10015132-16 WHERE Player = 'Jalen Rose'"}
{"text": "table: 1-10083598-1\ncolumns: No, Date, Round, Circuit, Pole Position, Fastest Lap, Race winner, Report\nQ: Where was Assen held?\nA: SELECT Round FROM 1-10083598-1 WHERE Circuit = 'Assen'"}
{"text": "table: 1-10083598-1\ncolumns: No, Date, Round, Circuit, Pole Position, Fastest Lap, Race winner, Report\nQ: What was the number of race that Kevin Curtain won?\nA: SELECT COUNT No FROM 1-10083598-1 WHERE Pole Position = 'Kevin Curtain'"}
{"text": "table: 1-10083598-1\ncolumns: No, Date, Round, Circuit, Pole Position, Fastest Lap, Race winner, Report\nQ: What was the date of the race in Misano?\nA: SELECT Date FROM 1-10083598-1 WHERE Circuit = 'Misano'"}
{"text": "table: 1-1013129-2\ncolumns: Pick, Player, Position, Nationality, NHL team, College/junior/club team\nQ: How many different positions did Sherbrooke Faucons (qmjhl) provide in the draft?\nA: SELECT COUNT Position FROM 1-1013129-2 WHERE College/junior/club team = 'Sherbrooke Faucons (QMJHL)'"}
{"text": "table: 1-1013129-2\ncolumns: Pick, Player, Position, Nationality, NHL team, College/junior/club team\nQ: What are the nationalities of the player picked from Thunder Bay Flyers (ushl)\nA: SELECT Nationality FROM 1-1013129-2 WHERE College/junior/club team = 'Thunder Bay Flyers (USHL)'"}
{"text": "table: 1-1013129-2\ncolumns: Pick, Player, Position, Nationality, NHL team, College/junior/club team\nQ: How many different college/junior/club teams provided a player to the Washington Capitals NHL Team?\nA: SELECT COUNT College/junior/club team FROM 1-1013129-2 WHERE NHL team = 'Washington Capitals'"}
{"text": "table: 1-1013129-3\ncolumns: Pick, Player, Position, Nationality, NHL team, College/junior/club team\nQ: How many different nationalities do the players of New Jersey Devils come from?\nA: SELECT COUNT Nationality FROM 1-1013129-3 WHERE NHL team = 'New Jersey Devils'"}
{"text": "table: 1-1013129-3\ncolumns: Pick, Player, Position, Nationality, NHL team, College/junior/club team\nQ: What's Dorain Anneck's pick number?\nA: SELECT Pick FROM 1-1013129-3 WHERE Player = 'Dorain Anneck'"}
{"text": "table: 1-1013129-3\ncolumns: Pick, Player, Position, Nationality, NHL team, College/junior/club team\nQ: What is the nationality of the player from Vancouver Canucks?\nA: SELECT Nationality FROM 1-1013129-3 WHERE NHL team = 'Vancouver Canucks'"}
{"text": "table: 1-1013129-3\ncolumns: Pick, Player, Position, Nationality, NHL team, College/junior/club team\nQ: What's the pick number of the player from Springfield Olympics (Nejhl)?\nA: SELECT Pick FROM 1-1013129-3 WHERE College/junior/club team = 'Springfield Olympics (NEJHL)'"}
{"text": "table: 1-1014206-2\ncolumns: #, Shipyard, Laid down, Launched, Commissioned, Fleet, Status\nQ: When were the ships launched that were laid down on september 1, 1964?\nA: SELECT Launched FROM 1-1014206-2 WHERE Laid down = 'September 1, 1964'"}
{"text": "table: 1-1014206-2\ncolumns: #, Shipyard, Laid down, Launched, Commissioned, Fleet, Status\nQ: List the # for ships commissioned on december 18, 1965.\nA: SELECT # FROM 1-1014206-2 WHERE Commissioned = 'December 18, 1965'"}
{"text": "table: 1-1014206-2\ncolumns: #, Shipyard, Laid down, Launched, Commissioned, Fleet, Status\nQ: List the # for ships commissioned on september 30, 1967.\nA: SELECT # FROM 1-1014206-2 WHERE Commissioned = 'September 30, 1967'"}
{"text": "table: 1-1014206-2\ncolumns: #, Shipyard, Laid down, Launched, Commissioned, Fleet, Status\nQ: When were ships laid down that were commissioned on october 29, 1965?\nA: SELECT Laid down FROM 1-1014206-2 WHERE Commissioned = 'October 29, 1965'"}
{"text": "table: 1-1015521-2\ncolumns: Equivalent NATO Rank Code, Rank in Spanish, Rank in English, Commonwealth equivalent, US Air Force equivalent\nQ: What could a spanish coronel be addressed as in the commonwealth military?\nA: SELECT Commonwealth equivalent FROM 1-1015521-2 WHERE Rank in Spanish = 'Coronel'"}
{"text": "table: 1-1015521-2\ncolumns: Equivalent NATO Rank Code, Rank in Spanish, Rank in English, Commonwealth equivalent, US Air Force equivalent\nQ: Give me a list of all spanish officer titles that could receive recognition as group captain in english\nA: SELECT Rank in English FROM 1-1015521-2 WHERE Commonwealth equivalent = 'Group Captain'"}
{"text": "table: 1-1015521-2\ncolumns: Equivalent NATO Rank Code, Rank in Spanish, Rank in English, Commonwealth equivalent, US Air Force equivalent\nQ: If you are a pilot officer in the commonwealth then what will you called as in the US air force?\nA: SELECT US Air Force equivalent FROM 1-1015521-2 WHERE Commonwealth equivalent = 'Pilot Officer'"}
{"text": "table: 1-1015521-2\ncolumns: Equivalent NATO Rank Code, Rank in Spanish, Rank in English, Commonwealth equivalent, US Air Force equivalent\nQ: If you're a major general in the US air force then what ranking will you receive in the commonwealth's air force?\nA: SELECT Commonwealth equivalent FROM 1-1015521-2 WHERE US Air Force equivalent = 'Major General'"}
{"text": "table: 1-1015521-2\ncolumns: Equivalent NATO Rank Code, Rank in Spanish, Rank in English, Commonwealth equivalent, US Air Force equivalent\nQ: If you get a ranking as major in the english military then what would the spanish military address you as? \nA: SELECT Rank in Spanish FROM 1-1015521-2 WHERE Rank in English = 'Major'"}
{"text": "table: 1-10182508-5\ncolumns: Rank Each wrestlers total number of days as champion are ranked highest to lowest; wrestlers with the same number mean that they are tied for that certain rank., Wrestler, # of reigns, Combined defenses, Combined days\nQ: Which wrestlers have had 2 reigns?\nA: SELECT Wrestler FROM 1-10182508-5 WHERE # of reigns = 2"}
{"text": "table: 1-10182508-5\ncolumns: Rank Each wrestlers total number of days as champion are ranked highest to lowest; wrestlers with the same number mean that they are tied for that certain rank., Wrestler, # of reigns, Combined defenses, Combined days\nQ: In terms of reigns, what is the lowest number listed?\nA: SELECT MIN # of reigns FROM 1-10182508-5"}
{"text": "table: 1-10182508-5\ncolumns: Rank Each wrestlers total number of days as champion are ranked highest to lowest; wrestlers with the same number mean that they are tied for that certain rank., Wrestler, # of reigns, Combined defenses, Combined days\nQ: What rank was Bryan Danielson in this chart?\nA: SELECT Rank Each wrestlers total number of days as champion are ranked highest to lowest; wrestlers with the same number mean that they are tied for that certain rank. FROM 1-10182508-5 WHERE Wrestler = 'Bryan Danielson'"}
{"text": "table: 1-10182508-5\ncolumns: Rank Each wrestlers total number of days as champion are ranked highest to lowest; wrestlers with the same number mean that they are tied for that certain rank., Wrestler, # of reigns, Combined defenses, Combined days\nQ: How many combined days did Go Shiozaki have?\nA: SELECT Combined days FROM 1-10182508-5 WHERE Wrestler = 'Go Shiozaki'"}
{"text": "table: 1-10182508-5\ncolumns: Rank Each wrestlers total number of days as champion are ranked highest to lowest; wrestlers with the same number mean that they are tied for that certain rank., Wrestler, # of reigns, Combined defenses, Combined days\nQ: What was Go Shiozaki's rank?\nA: SELECT MIN Rank Each wrestlers total number of days as champion are ranked highest to lowest; wrestlers with the same number mean that they are tied for that certain rank. FROM 1-10182508-5 WHERE Wrestler = 'Go Shiozaki'"}
{"text": "table: 1-1024710-2\ncolumns: Member, Electorate, Province, MPs term, Election date\nQ: Which province is grey and bell electorate in\nA: SELECT Province FROM 1-1024710-2 WHERE Electorate = 'Grey and Bell'"}
{"text": "table: 1-1024710-2\ncolumns: Member, Electorate, Province, MPs term, Election date\nQ: Which province is bay of islands in\nA: SELECT Province FROM 1-1024710-2 WHERE Electorate = 'Bay of Islands'"}
{"text": "table: 1-10294071-1\ncolumns: Player, Total W\u2013L, Singles W\u2013L, Doubles W\u2013L, Ties played, Debut, Years played\nQ: what is the total number of\u00a0total w\u2013l\u00a0where\u00a0doubles w\u2013l\u00a0is 11\u201311\nA: SELECT COUNT Total W\u2013L FROM 1-10294071-1 WHERE Doubles W\u2013L = '11\u201311'"}
{"text": "table: 1-10294071-1\ncolumns: Player, Total W\u2013L, Singles W\u2013L, Doubles W\u2013L, Ties played, Debut, Years played\nQ: what is the total number of\u00a0singles w\u2013l\u00a0where\u00a0doubles w\u2013l\u00a0is 11\u201314\nA: SELECT COUNT Singles W\u2013L FROM 1-10294071-1 WHERE Doubles W\u2013L = '11\u201314'"}
{"text": "table: 1-10294071-1\ncolumns: Player, Total W\u2013L, Singles W\u2013L, Doubles W\u2013L, Ties played, Debut, Years played\nQ: what's the\u00a0total w\u2013l\u00a0where\u00a0player\u00a0is boro jovanovi\u0107 category:articles with hcards\nA: SELECT Total W\u2013L FROM 1-10294071-1 WHERE Player = 'Boro Jovanovi\u0107 Category:Articles with hCards'"}
{"text": "table: 1-10294071-1\ncolumns: Player, Total W\u2013L, Singles W\u2013L, Doubles W\u2013L, Ties played, Debut, Years played\nQ: what is the maximum\u00a0ties played\u00a0where\u00a0player\u00a0is josip palada category:articles with hcards\nA: SELECT MAX Ties played FROM 1-10294071-1 WHERE Player = 'Josip Palada Category:Articles with hCards'"}
{"text": "table: 1-10294071-1\ncolumns: Player, Total W\u2013L, Singles W\u2013L, Doubles W\u2013L, Ties played, Debut, Years played\nQ: what is the total number of\u00a0ties played\u00a0where\u00a0total w\u2013l\u00a0is 38\u201324\nA: SELECT COUNT Ties played FROM 1-10294071-1 WHERE Total W\u2013L = '38\u201324'"}
{"text": "table: 1-10333757-1\ncolumns: Calls, Frequency, Branding, Format, Market/Rank, Timeslot, Group owner\nQ: What is the Frequency at the Market/Rank of Burlington - Plattsburgh , Vermont - New York /143?\nA: SELECT COUNT Frequency FROM 1-10333757-1 WHERE Market/Rank = 'Burlington - Plattsburgh , Vermont - New York /143'"}
{"text": "table: 1-10333757-1\ncolumns: Calls, Frequency, Branding, Format, Market/Rank, Timeslot, Group owner\nQ: What is the Branding for Group Owner Qantam of Cape Cod, LLC?\nA: SELECT Branding FROM 1-10333757-1 WHERE Group owner = 'Qantam of Cape Cod, LLC'"}
{"text": "table: 1-10333757-1\ncolumns: Calls, Frequency, Branding, Format, Market/Rank, Timeslot, Group owner\nQ: What Branding does WRKO calls use?\nA: SELECT Branding FROM 1-10333757-1 WHERE Calls = 'WRKO'"}
{"text": "table: 1-10333757-1\ncolumns: Calls, Frequency, Branding, Format, Market/Rank, Timeslot, Group owner\nQ: What is the Format for Branding of 1290 wkbk w281au 104.1?\nA: SELECT Format FROM 1-10333757-1 WHERE Branding = '1290 WKBK W281AU 104.1'"}
{"text": "table: 1-10333757-1\ncolumns: Calls, Frequency, Branding, Format, Market/Rank, Timeslot, Group owner\nQ: Which Market/Rank is associated with WCRN calls?\nA: SELECT Market/Rank FROM 1-10333757-1 WHERE Calls = 'WCRN'"}
{"text": "table: 1-10333757-1\ncolumns: Calls, Frequency, Branding, Format, Market/Rank, Timeslot, Group owner\nQ: Which Frequency is used for WEGP calls?\nA: SELECT Frequency FROM 1-10333757-1 WHERE Calls = 'WEGP'"}
{"text": "table: 1-10408617-5\ncolumns: Scheme, Tariff code, BTs retail price (regulated), Approx premium, Prefixes\nQ: What is the regulated retail price for the tariff code ff0 prs?\nA: SELECT BTs retail price (regulated) FROM 1-10408617-5 WHERE Tariff code = 'ff0 PRS'"}
{"text": "table: 1-10408617-5\ncolumns: Scheme, Tariff code, BTs retail price (regulated), Approx premium, Prefixes\nQ: What is the premium associated with tariff code g9?\nA: SELECT Approx premium FROM 1-10408617-5 WHERE Tariff code = 'g9'"}
{"text": "table: 1-10408617-5\ncolumns: Scheme, Tariff code, BTs retail price (regulated), Approx premium, Prefixes\nQ: How many tariff codes have a bts retail price of 2p/min or inclusive?\nA: SELECT COUNT Tariff code FROM 1-10408617-5 WHERE BTs retail price (regulated) = '2p/min or inclusive'"}
{"text": "table: 1-10408617-5\ncolumns: Scheme, Tariff code, BTs retail price (regulated), Approx premium, Prefixes\nQ: How many tariff codes have a bts retail price of 2.553p/min?\nA: SELECT COUNT Tariff code FROM 1-10408617-5 WHERE BTs retail price (regulated) = '2.553p/min'"}
{"text": "table: 1-10408617-5\ncolumns: Scheme, Tariff code, BTs retail price (regulated), Approx premium, Prefixes\nQ: What prefixes are priced at pence per minute, fixed at all times with a premium of 3p/min?\nA: SELECT Prefixes FROM 1-10408617-5 WHERE Scheme = 'Pence per minute, fixed at all times' AND Approx premium = '3p/min'"}
{"text": "table: 1-10408617-5\ncolumns: Scheme, Tariff code, BTs retail price (regulated), Approx premium, Prefixes\nQ: What is the bts retail price (regulated) for tariff code g10?\nA: SELECT BTs retail price (regulated) FROM 1-10408617-5 WHERE Tariff code = 'g10'"}
{"text": "table: 1-10409754-5\ncolumns: Nominative, Old orthography, New orthography, /e/ or /\u00e6/ (IPA), Tone (Latvian notation: /~/ - level, /^/ - broken), Translation\nQ: What is the tone for gen.sing. plague?\nA: SELECT Tone (Latvian notation: /~/ - level, /^/ - broken) FROM 1-10409754-5 WHERE Translation = 'Gen.Sing. plague'"}
{"text": "table: 1-10432351-1\ncolumns: Star (Pismis24-#), Spectral type, Magnitude (M bol ), Temperature (K), Radius (R \u2609 ), Mass (M \u2609 )\nQ: What is the smallest possible radius?\nA: SELECT MIN Radius (R \u2609 ) FROM 1-10432351-1"}
{"text": "table: 1-10432351-1\ncolumns: Star (Pismis24-#), Spectral type, Magnitude (M bol ), Temperature (K), Radius (R \u2609 ), Mass (M \u2609 )\nQ: What are all the spectral types for star mismis24-# is 1sw?\nA: SELECT Spectral type FROM 1-10432351-1 WHERE Star (Pismis24-#) = '1SW'"}
{"text": "table: 1-10432351-1\ncolumns: Star (Pismis24-#), Spectral type, Magnitude (M bol ), Temperature (K), Radius (R \u2609 ), Mass (M \u2609 )\nQ: If a radius is 10, what is the lowest possible mass?\nA: SELECT MIN Mass (M \u2609 ) FROM 1-10432351-1 WHERE Radius (R \u2609 ) = 10"}
{"text": "table: 1-105344-2\ncolumns: Year, Aircraft kilometers, Departures, Flying hours, Passengers, Seat factor, Employees, Profit/loss\nQ: What percentage of seats were filled in 2006?\nA: SELECT Seat factor FROM 1-105344-2 WHERE Year = 2006"}
{"text": "table: 1-105344-2\ncolumns: Year, Aircraft kilometers, Departures, Flying hours, Passengers, Seat factor, Employees, Profit/loss\nQ: How many hours were flown in each of the years where more than 64379058.0 kilometers were flown?\nA: SELECT Flying hours FROM 1-105344-2 WHERE Aircraft kilometers > 64379058.0"}
{"text": "table: 1-105344-2\ncolumns: Year, Aircraft kilometers, Departures, Flying hours, Passengers, Seat factor, Employees, Profit/loss\nQ: Of the years that had exactly 17096 departures, what is the greatest number of aircraft kilometers flown?\nA: SELECT MAX Aircraft kilometers FROM 1-105344-2 WHERE Departures = 17096"}
{"text": "table: 1-10548224-1\ncolumns: Year, Game or event, Date contested, League or governing body, Sport, Winning team, Losing team, Final score\nQ: Which winning team beat the New York Yankees?\nA: SELECT Winning team FROM 1-10548224-1 WHERE Losing team = 'New York Yankees'"}
{"text": "table: 1-10548224-1\ncolumns: Year, Game or event, Date contested, League or governing body, Sport, Winning team, Losing team, Final score\nQ: What was the final score for the game that was contested on February 1, 2009?\nA: SELECT Final score FROM 1-10548224-1 WHERE Date contested = 'February 1, 2009'"}
{"text": "table: 1-10548224-1\ncolumns: Year, Game or event, Date contested, League or governing body, Sport, Winning team, Losing team, Final score\nQ: What sport had a final score of 3-2?\nA: SELECT Sport FROM 1-10548224-1 WHERE Final score = '3-2'"}
{"text": "table: 1-10548224-1\ncolumns: Year, Game or event, Date contested, League or governing body, Sport, Winning team, Losing team, Final score\nQ: Who was the winning team of the game that was contested on February 1, 2009?\nA: SELECT Winning team FROM 1-10548224-1 WHERE Date contested = 'February 1, 2009'"}
{"text": "table: 1-10548224-1\ncolumns: Year, Game or event, Date contested, League or governing body, Sport, Winning team, Losing team, Final score\nQ: Who was the losing team of the game that was contested on February 1, 2004?\nA: SELECT Losing team FROM 1-10548224-1 WHERE Date contested = 'February 1, 2004'"}
{"text": "table: 1-1057262-2\ncolumns: Crop (kilotonnes), New South Wales, Victoria, Queensland, Western Australia, South Australia, Tasmania, Total\nQ: what's the minimum\u00a0total\u00a0with\u00a0crop (kilotonnes)\u00a0being s lupin\nA: SELECT MIN Total FROM 1-1057262-2 WHERE Crop (kilotonnes) = 's Lupin'"}
{"text": "table: 1-1057262-2\ncolumns: Crop (kilotonnes), New South Wales, Victoria, Queensland, Western Australia, South Australia, Tasmania, Total\nQ: what's the\u00a0new south wales\u00a0with\u00a0crop (kilotonnes)\u00a0being canola\nA: SELECT New South Wales FROM 1-1057262-2 WHERE Crop (kilotonnes) = 'Canola'"}
{"text": "table: 1-1057262-2\ncolumns: Crop (kilotonnes), New South Wales, Victoria, Queensland, Western Australia, South Australia, Tasmania, Total\nQ: what's the total number of\u00a0south australia\u00a0with\u00a0victoria\u00a0value of 2173\nA: SELECT COUNT South Australia FROM 1-1057262-2 WHERE Victoria = 2173"}
{"text": "table: 1-1057262-2\ncolumns: Crop (kilotonnes), New South Wales, Victoria, Queensland, Western Australia, South Australia, Tasmania, Total\nQ: what's the minimum\u00a0tasmania value\nA: SELECT MIN Tasmania FROM 1-1057262-2"}
{"text": "table: 1-1057262-2\ncolumns: Crop (kilotonnes), New South Wales, Victoria, Queensland, Western Australia, South Australia, Tasmania, Total\nQ: what's the total number of\u00a0tasmania\u00a0with\u00a0new south wales\u00a0crop of 190 kilotonnes\nA: SELECT COUNT Tasmania FROM 1-1057262-2 WHERE New South Wales = 190"}
{"text": "table: 1-1058787-1\ncolumns: Approximate Age, Virtues, Psycho Social Crisis, Significant Relationship, Existential Question [ not in citation given ], Examples\nQ: How many significant relationships list Will as a virtue?\nA: SELECT COUNT Significant Relationship FROM 1-1058787-1 WHERE Virtues = 'Will'"}
{"text": "table: 1-1058787-1\ncolumns: Approximate Age, Virtues, Psycho Social Crisis, Significant Relationship, Existential Question [ not in citation given ], Examples\nQ: Which examples ask the existential question \"Can I Love?\"\nA: SELECT Examples FROM 1-1058787-1 WHERE Existential Question [ not in citation given ] = 'Can I Love?'"}
{"text": "table: 1-1059743-2\ncolumns: Rank, Member Association, Points, Group stage, Play-off, AFC Cup\nQ: How many countries got 796.7 points?\nA: SELECT COUNT Rank FROM 1-1059743-2 WHERE Points = '796.7'"}
{"text": "table: 1-1059743-2\ncolumns: Rank, Member Association, Points, Group stage, Play-off, AFC Cup\nQ: In what group stage were 177.2 points awarded?\nA: SELECT COUNT Group stage FROM 1-1059743-2 WHERE Points = '177.2'"}
{"text": "table: 1-1059743-2\ncolumns: Rank, Member Association, Points, Group stage, Play-off, AFC Cup\nQ: What is the lowest group to earn 886.6 points?\nA: SELECT MIN Group stage FROM 1-1059743-2 WHERE Points = '886.6'"}
{"text": "table: 1-1059743-2\ncolumns: Rank, Member Association, Points, Group stage, Play-off, AFC Cup\nQ: How many countries earned 177.2 points?\nA: SELECT COUNT Member Association FROM 1-1059743-2 WHERE Points = '177.2'"}
{"text": "table: 1-10586064-2\ncolumns: County, Precincts, Lunsford, % Lunsford, McConnell, % McConnell, Total\nQ: If % lunsford is 51.82% what is the % mcconnell in Letcher?\nA: SELECT % McConnell FROM 1-10586064-2 WHERE % Lunsford = '51.82%'"}
{"text": "table: 1-10586064-2\ncolumns: County, Precincts, Lunsford, % Lunsford, McConnell, % McConnell, Total\nQ: What country had the total 18,900 (r)?\nA: SELECT County FROM 1-10586064-2 WHERE Total = '18,900 (R)'"}
{"text": "table: 1-10586064-2\ncolumns: County, Precincts, Lunsford, % Lunsford, McConnell, % McConnell, Total\nQ: When % mcconnell is 44.54% what are the total number of counties?\nA: SELECT COUNT County FROM 1-10586064-2 WHERE % McConnell = '44.54%'"}
{"text": "table: 1-10586064-2\ncolumns: County, Precincts, Lunsford, % Lunsford, McConnell, % McConnell, Total\nQ: If % mcconnell is 47.17% what is the total number of mcconnell ?\nA: SELECT COUNT McConnell FROM 1-10586064-2 WHERE % McConnell = '47.17%'"}
{"text": "table: 1-10586064-2\ncolumns: County, Precincts, Lunsford, % Lunsford, McConnell, % McConnell, Total\nQ: What is the county of precints 515?\nA: SELECT County FROM 1-10586064-2 WHERE Precincts = 515"}
{"text": "table: 1-10601843-2\ncolumns: Stadium, Capacity, City, Country, Tenant, Opening\nQ: Which city has a capacity of 41903?\nA: SELECT City FROM 1-10601843-2 WHERE Capacity = 41903"}
{"text": "table: 1-10601843-2\ncolumns: Stadium, Capacity, City, Country, Tenant, Opening\nQ: What is the maximum capacity of the Otkrytie Arena stadium?\nA: SELECT MAX Capacity FROM 1-10601843-2 WHERE Stadium = 'Otkrytie Arena'"}
{"text": "table: 1-10601843-2\ncolumns: Stadium, Capacity, City, Country, Tenant, Opening\nQ: When did the stadium where Bursaspor is the tenant open?\nA: SELECT MIN Opening FROM 1-10601843-2 WHERE Tenant = 'Bursaspor'"}
{"text": "table: 1-10601843-2\ncolumns: Stadium, Capacity, City, Country, Tenant, Opening\nQ: How many tenants are there in the city of Samsun?\nA: SELECT COUNT Tenant FROM 1-10601843-2 WHERE City = 'Samsun'"}
{"text": "table: 1-10610087-5\ncolumns: No. in series, No. in season, Title, Directed by, Written by, Original air date\nQ: what's the\u00a0original air date\u00a0with\u00a0title\u00a0 \"hell\"\nA: SELECT Original air date FROM 1-10610087-5 WHERE Title = '\"Hell\"'"}
{"text": "table: 1-10638523-1\ncolumns: Particulars and Characteristics, Shivalik Zone, Mid-Hill Zone, High hill zone, Trance- n Himalaya Zone\nQ: What is the percentage of the Shivalik Zone where the percentage of the Mid-Hill Zone is 10%?\nA: SELECT Shivalik Zone FROM 1-10638523-1 WHERE Mid-Hill Zone = '10%'"}
{"text": "table: 1-10638523-1\ncolumns: Particulars and Characteristics, Shivalik Zone, Mid-Hill Zone, High hill zone, Trance- n Himalaya Zone\nQ: For mid-hill zone what is the altitude?\nA: SELECT Mid-Hill Zone FROM 1-10638523-1 WHERE Particulars and Characteristics = 'Altitude'"}
{"text": "table: 1-10638523-1\ncolumns: Particulars and Characteristics, Shivalik Zone, Mid-Hill Zone, High hill zone, Trance- n Himalaya Zone\nQ: What are the climatic conditions for the trance- n himalaya zone?\nA: SELECT Trance- n Himalaya Zone FROM 1-10638523-1 WHERE Particulars and Characteristics = 'Climatic conditions'"}
{"text": "table: 1-10638523-1\ncolumns: Particulars and Characteristics, Shivalik Zone, Mid-Hill Zone, High hill zone, Trance- n Himalaya Zone\nQ: What is the percentage of the trance- n himalaya zone that corresponds with the high hill zone is 25%?\nA: SELECT Trance- n Himalaya Zone FROM 1-10638523-1 WHERE High hill zone = '25%'"}
{"text": "table: 1-10644188-3\ncolumns: Total tenure rank, Uninterrupted rank, Name, State represented, Dates of service, Total tenure time, Uninterrupted time\nQ: What is the state of Ted Stevens?\nA: SELECT State represented FROM 1-10644188-3 WHERE Name = 'Ted Stevens'"}
{"text": "table: 1-10682862-68\ncolumns: Country, Players, Standard, Minor, First title, Last title\nQ: What's the standard of the country who won its first title in 1992?\nA: SELECT MAX Standard FROM 1-10682862-68 WHERE First title = 1992"}
{"text": "table: 1-10682862-68\ncolumns: Country, Players, Standard, Minor, First title, Last title\nQ: What's the smallest number of players?\nA: SELECT MIN Players FROM 1-10682862-68"}
{"text": "table: 1-10682862-68\ncolumns: Country, Players, Standard, Minor, First title, Last title\nQ: In what year was the last last title received, by any of the countries?\nA: SELECT MAX Last title FROM 1-10682862-68"}
{"text": "table: 1-10710364-1\ncolumns: Religious group, Population % 1961, Population % 1971, Population % 1981, Population % 1991, Population % 2001\nQ: What religious groups made up 0.72% of the Indian population in 2001?\nA: SELECT Religious group FROM 1-10710364-1 WHERE Population % 2001 = '0.72%'"}
{"text": "table: 1-10718868-2\ncolumns: No. in series, No. in season, Title, Directed by, Written by, Original air date, U.S. viewers (millions)\nQ: What is the original air date for episode 15 of season 6?\nA: SELECT Original air date FROM 1-10718868-2 WHERE No. in season = 15"}
{"text": "table: 1-10718868-2\ncolumns: No. in series, No. in season, Title, Directed by, Written by, Original air date, U.S. viewers (millions)\nQ: How many episodes in season 6 titles \"Poppin' Tags\"?\nA: SELECT COUNT No. in season FROM 1-10718868-2 WHERE Title = '\"Poppin' Tags\"'"}
{"text": "table: 1-10753917-1\ncolumns: Season, Driver, Team, Engine, Poles, Wins, Podiums, Points, Margin of defeat\nQ: Which podiums did the Williams team have with a margin of defeat of 2?\nA: SELECT Podiums FROM 1-10753917-1 WHERE Team = 'Williams' AND Margin of defeat = '2'"}
{"text": "table: 1-10753917-1\ncolumns: Season, Driver, Team, Engine, Poles, Wins, Podiums, Points, Margin of defeat\nQ: How many drivers on the williams team had a margin of defeat of 2?\nA: SELECT COUNT Driver FROM 1-10753917-1 WHERE Team = 'Williams' AND Margin of defeat = '2'"}
{"text": "table: 1-10753917-1\ncolumns: Season, Driver, Team, Engine, Poles, Wins, Podiums, Points, Margin of defeat\nQ: How many seasons was clay regazzoni the driver?\nA: SELECT COUNT Season FROM 1-10753917-1 WHERE Driver = 'Clay Regazzoni'"}
{"text": "table: 1-10753917-1\ncolumns: Season, Driver, Team, Engine, Poles, Wins, Podiums, Points, Margin of defeat\nQ: Which margin of defeats had points of 30?\nA: SELECT Margin of defeat FROM 1-10753917-1 WHERE Points = '30'"}
{"text": "table: 1-10753917-1\ncolumns: Season, Driver, Team, Engine, Poles, Wins, Podiums, Points, Margin of defeat\nQ: Which podiums did the alfa romeo team have?\nA: SELECT Podiums FROM 1-10753917-1 WHERE Team = 'Alfa Romeo'"}
{"text": "table: 1-10797636-1\ncolumns: Village (German), Village (Slovene), Number of people 1991, Percent of Slovenes 1991, Percent of Slovenes 1951\nQ: What was the percent of slovenes 1951 for bach?\nA: SELECT Percent of Slovenes 1951 FROM 1-10797636-1 WHERE Village (German) = 'Bach'"}
{"text": "table: 1-10812403-4\ncolumns: Pick #, CFL Team, Player, Position, College\nQ: What college's team is the Saskatchewan Roughriders?\nA: SELECT College FROM 1-10812403-4 WHERE CFL Team = 'Saskatchewan Roughriders'"}
{"text": "table: 1-10812403-4\ncolumns: Pick #, CFL Team, Player, Position, College\nQ: What position did Calvin Mccarty play?\nA: SELECT Position FROM 1-10812403-4 WHERE Player = 'Calvin McCarty'"}
{"text": "table: 1-10812403-4\ncolumns: Pick #, CFL Team, Player, Position, College\nQ: How many people were pick #30?\nA: SELECT COUNT Position FROM 1-10812403-4 WHERE Pick # = 30"}

Data/lora/train.jsonl (new file, 1000 lines) -- diff suppressed because it is too large

Data/lora/valid.jsonl (new file, 100 lines)

@@ -0,0 +1,100 @@
{"text": "table: 1-10015132-11\ncolumns: Player, No., Nationality, Position, Years in Toronto, School/Club Team\nQ: What position does the player who played for butler cc (ks) play?\nA: SELECT Position FROM 1-10015132-11 WHERE School/Club Team = 'Butler CC (KS)'"}
{"text": "table: 1-10015132-11\ncolumns: Player, No., Nationality, Position, Years in Toronto, School/Club Team\nQ: How many schools did player number 3 play at?\nA: SELECT COUNT School/Club Team FROM 1-10015132-11 WHERE No. = '3'"}
{"text": "table: 1-10015132-11\ncolumns: Player, No., Nationality, Position, Years in Toronto, School/Club Team\nQ: What school did player number 21 play for?\nA: SELECT School/Club Team FROM 1-10015132-11 WHERE No. = '21'"}
{"text": "table: 1-10015132-11\ncolumns: Player, No., Nationality, Position, Years in Toronto, School/Club Team\nQ: Who is the player that wears number 42?\nA: SELECT Player FROM 1-10015132-11 WHERE No. = '42'"}
{"text": "table: 1-10015132-11\ncolumns: Player, No., Nationality, Position, Years in Toronto, School/Club Team\nQ: What player played guard for toronto in 1996-97?\nA: SELECT Player FROM 1-10015132-11 WHERE Position = 'Guard' AND Years in Toronto = '1996-97'"}
{"text": "table: 1-10015132-9\ncolumns: Player, No., Nationality, Position, Years in Toronto, School/Club Team\nQ: Who are all of the players on the Westchester High School club team?\nA: SELECT Player FROM 1-10015132-9 WHERE School/Club Team = 'Westchester High School'"}
{"text": "table: 1-10015132-9\ncolumns: Player, No., Nationality, Position, Years in Toronto, School/Club Team\nQ: What school/club team is Amir Johnson on?\nA: SELECT School/Club Team FROM 1-10015132-9 WHERE Player = 'Amir Johnson'"}
{"text": "table: 1-10015132-9\ncolumns: Player, No., Nationality, Position, Years in Toronto, School/Club Team\nQ: What are the total amount of numbers on the Toronto team in 2005-06?\nA: SELECT COUNT No. FROM 1-10015132-9 WHERE Years in Toronto = '2005-06'"}
{"text": "table: 1-10015132-9\ncolumns: Player, No., Nationality, Position, Years in Toronto, School/Club Team\nQ: What are the total number of positions on the Toronto team in 2006-07?\nA: SELECT COUNT Position FROM 1-10015132-9 WHERE Years in Toronto = '2006-07'"}
{"text": "table: 1-10015132-9\ncolumns: Player, No., Nationality, Position, Years in Toronto, School/Club Team\nQ: What are the nationality of the players on the Fresno State school/club team?\nA: SELECT Nationality FROM 1-10015132-9 WHERE School/Club Team = 'Fresno State'"}
{"text": "table: 1-10015132-9\ncolumns: Player, No., Nationality, Position, Years in Toronto, School/Club Team\nQ: What school/club team is Trey Johnson on?\nA: SELECT School/Club Team FROM 1-10015132-9 WHERE Player = 'Trey Johnson'"}
{"text": "table: 1-10026563-1\ncolumns: Entered office as Head of State or Government, Began time as senior G8 leader, Ended time as senior G8 leader, Person, Office\nQ: When did Jacques Chirac stop being a G8 leader?\nA: SELECT Ended time as senior G8 leader FROM 1-10026563-1 WHERE Person = 'Jacques Chirac'"}
{"text": "table: 1-10026563-1\ncolumns: Entered office as Head of State or Government, Began time as senior G8 leader, Ended time as senior G8 leader, Person, Office\nQ: When did the Prime Minister of Italy take office?\nA: SELECT Entered office as Head of State or Government FROM 1-10026563-1 WHERE Office = 'Prime Minister of Italy'"}
{"text": "table: 1-1008653-1\ncolumns: Country ( exonym ), Capital ( exonym ), Country ( endonym ), Capital ( endonym ), Official or native language(s) (alphabet/script)\nQ: What is the English name of the country whose official native language is Dutch Papiamento?\nA: SELECT Country ( exonym ) FROM 1-1008653-1 WHERE Official or native language(s) (alphabet/script) = 'Dutch Papiamento'"}
{"text": "table: 1-1008653-1\ncolumns: Country ( exonym ), Capital ( exonym ), Country ( endonym ), Capital ( endonym ), Official or native language(s) (alphabet/script)\nQ: What official or native languages are spoken in the country whose capital city is Canberra?\nA: SELECT Official or native language(s) (alphabet/script) FROM 1-1008653-1 WHERE Capital ( exonym ) = 'Canberra'"}
{"text": "table: 1-1008653-1\ncolumns: Country ( exonym ), Capital ( exonym ), Country ( endonym ), Capital ( endonym ), Official or native language(s) (alphabet/script)\nQ: What is the local name given to the city of Canberra?\nA: SELECT Capital ( endonym ) FROM 1-1008653-1 WHERE Capital ( exonym ) = 'Canberra'"}
{"text": "table: 1-1008653-1\ncolumns: Country ( exonym ), Capital ( exonym ), Country ( endonym ), Capital ( endonym ), Official or native language(s) (alphabet/script)\nQ: What is the local name given to the capital of Anguilla?\nA: SELECT Capital ( endonym ) FROM 1-1008653-1 WHERE Country ( endonym ) = 'Anguilla'"}
{"text": "table: 1-1008653-1\ncolumns: Country ( exonym ), Capital ( exonym ), Country ( endonym ), Capital ( endonym ), Official or native language(s) (alphabet/script)\nQ: What is the English name given to the city of St. John's?\nA: SELECT Capital ( exonym ) FROM 1-1008653-1 WHERE Capital ( endonym ) = 'St. John's'"}
{"text": "table: 1-1008653-1\ncolumns: Country ( exonym ), Capital ( exonym ), Country ( endonym ), Capital ( endonym ), Official or native language(s) (alphabet/script)\nQ: How many capital cities does Australia have?\nA: SELECT COUNT Capital ( endonym ) FROM 1-1008653-1 WHERE Country ( endonym ) = 'Australia'"}
{"text": "table: 1-10088101-1\ncolumns: No. in set, No. in series, Title, Directed by, Written by, Original air date, Production code\nQ: The episode with production code 9abx02 was originally aired on what date?\nA: SELECT Original air date FROM 1-10088101-1 WHERE Production code = '9ABX02'"}
{"text": "table: 1-10088101-1\ncolumns: No. in set, No. in series, Title, Directed by, Written by, Original air date, Production code\nQ: What is the episode number that has production code 8abx15?\nA: SELECT MIN No. in series FROM 1-10088101-1 WHERE Production code = '8ABX15'"}
{"text": "table: 1-10295819-2\ncolumns: Player, Highest singles ranking, Highest doubles ranking, First year played, Years played, Ties played, Total W\u2013L, Singles W\u2013L, Doubles W\u2013L\nQ: Name the minimum tiesplayed for 6 years\nA: SELECT MIN Ties played FROM 1-10295819-2 WHERE Years played = 6"}
{"text": "table: 1-10342194-3\ncolumns: District, Total amount of trees, Prevailing types, %, Amount of old trees, Amount of trees, that require replacement\nQ: What is the amount of trees, that require replacement when prevailing types, % is pine \u2014 29.37 poplar \u2014 26.12 acer negundo \u2014 13.2?\nA: SELECT Amount of trees, that require replacement FROM 1-10342194-3 WHERE Prevailing types, % = 'Pine \u2014 29.37 Poplar \u2014 26.12 Acer negundo \u2014 13.2'"}
{"text": "table: 1-10342194-3\ncolumns: District, Total amount of trees, Prevailing types, %, Amount of old trees, Amount of trees, that require replacement\nQ: What is the amount of trees, that require replacement when district is leninsky?\nA: SELECT Amount of trees, that require replacement FROM 1-10342194-3 WHERE District = 'Leninsky'"}
{"text": "table: 1-10342194-3\ncolumns: District, Total amount of trees, Prevailing types, %, Amount of old trees, Amount of trees, that require replacement\nQ: What is the district when the total amount of trees is smaller than 150817.6878461314 and amount of old trees is 1,928 (1.89%)?\nA: SELECT District FROM 1-10342194-3 WHERE Total amount of trees < 150817.6878461314 AND Amount of old trees = '1,928 (1.89%)'"}
{"text": "table: 1-10342194-3\ncolumns: District, Total amount of trees, Prevailing types, %, Amount of old trees, Amount of trees, that require replacement\nQ: What is the amount of trees, that require replacement when the district is motovilikhinsky?\nA: SELECT Amount of trees, that require replacement FROM 1-10342194-3 WHERE District = 'Motovilikhinsky'"}
{"text": "table: 1-10342194-3\ncolumns: District, Total amount of trees, Prevailing types, %, Amount of old trees, Amount of trees, that require replacement\nQ: What is the total amount of trees when district is leninsky?\nA: SELECT MAX Total amount of trees FROM 1-10342194-3 WHERE District = 'Leninsky'"}
{"text": "table: 1-10342194-3\ncolumns: District, Total amount of trees, Prevailing types, %, Amount of old trees, Amount of trees, that require replacement\nQ: What is the district when prevailing types, % is acer negundo \u2014 30.22 tilia \u2014 18.6 poplar \u2014 15.23?\nA: SELECT District FROM 1-10342194-3 WHERE Prevailing types, % = 'Acer negundo \u2014 30.22 Tilia \u2014 18.6 Poplar \u2014 15.23'"}
{"text": "table: 1-10429820-13\ncolumns: Iowa State vs., Overall Record, in Ames, at Opponents Venue, at Neutral Site, Last 5 Meetings, Last 10 Meetings, Current Streak, Since Beginning of Big 12\nQ: When the value of \"since beginning of big 12\" is synonymous with its' category, what are the in Ames values?\nA: SELECT in Ames FROM 1-10429820-13 WHERE Since Beginning of Big 12 = 'Since Beginning of Big 12'"}
{"text": "table: 1-1046170-5\ncolumns: Year, Division, League, Regular Season, Playoffs, U.S. Open Cup\nQ: what's the\u00a0u.s. open cup status\u00a0for regular season\u00a0of 4th, atlantic division \nA: SELECT U.S. Open Cup FROM 1-1046170-5 WHERE Regular Season = '4th, Atlantic Division'"}
{"text": "table: 1-1046170-5\ncolumns: Year, Division, League, Regular Season, Playoffs, U.S. Open Cup\nQ: how many division did not qualify for u.s. open cup in 2003\nA: SELECT Division FROM 1-1046170-5 WHERE U.S. Open Cup = 'Did Not Qualify' AND Year = 2003"}
{"text": "table: 1-1046170-5\ncolumns: Year, Division, League, Regular Season, Playoffs, U.S. Open Cup\nQ: which round is u.s. open cup division semifinals\nA: SELECT U.S. Open Cup FROM 1-1046170-5 WHERE Playoffs = 'Division Semifinals'"}
{"text": "table: 1-1046170-5\ncolumns: Year, Division, League, Regular Season, Playoffs, U.S. Open Cup\nQ: what are all the playoffs for regular season is 1st, atlantic division\nA: SELECT Playoffs FROM 1-1046170-5 WHERE Regular Season = '1st, Atlantic Division'"}
{"text": "table: 1-1046170-5\ncolumns: Year, Division, League, Regular Season, Playoffs, U.S. Open Cup\nQ: what are all the playoffs for u.s. open cup in 1st round\nA: SELECT Playoffs FROM 1-1046170-5 WHERE U.S. Open Cup = '1st Round'"}
{"text": "table: 1-1061075-1\ncolumns: Season, Competition, Round, Opponents, 1st leg, 2nd leg, Aggregate\nQ: what is the total number of\u00a02nd leg\u00a0where\u00a0aggregate\u00a0is 7-2\nA: SELECT COUNT 2nd leg FROM 1-1061075-1 WHERE Aggregate = '7-2'"}
{"text": "table: 1-1061075-1\ncolumns: Season, Competition, Round, Opponents, 1st leg, 2nd leg, Aggregate\nQ: what's the\u00a0aggregate\u00a0where\u00a01st leg\u00a0is 3\u20132\nA: SELECT Aggregate FROM 1-1061075-1 WHERE 1st leg = '3\u20132'"}
{"text": "table: 1-1061075-1\ncolumns: Season, Competition, Round, Opponents, 1st leg, 2nd leg, Aggregate\nQ: what's the\u00a0competition\u00a0where\u00a0aggregate\u00a0is 4\u20137\nA: SELECT Competition FROM 1-1061075-1 WHERE Aggregate = '4\u20137'"}
{"text": "table: 1-1061075-1\ncolumns: Season, Competition, Round, Opponents, 1st leg, 2nd leg, Aggregate\nQ: what's the\u00a0competition\u00a0where\u00a01st leg\u00a0is 4-1 (h)\nA: SELECT Competition FROM 1-1061075-1 WHERE 1st leg = '4-1 (h)'"}
{"text": "table: 1-1061075-1\ncolumns: Season, Competition, Round, Opponents, 1st leg, 2nd leg, Aggregate\nQ: what is the total number of\u00a0round\u00a0where\u00a0opponents\u00a0is haugar\nA: SELECT COUNT Round FROM 1-1061075-1 WHERE Opponents = 'Haugar'"}
{"text": "table: 1-1061075-1\ncolumns: Season, Competition, Round, Opponents, 1st leg, 2nd leg, Aggregate\nQ: what's the\u00a01st leg\u00a0where\u00a0opponents\u00a0is galatasaray\nA: SELECT 1st leg FROM 1-1061075-1 WHERE Opponents = 'Galatasaray'"}
{"text": "table: 1-10706961-2\ncolumns: Rd, Name, Pole Position, Fastest Lap, Winning driver, Winning team, Report\nQ: What is the highest Rd that Tom Sneva had the pole position in?\nA: SELECT MAX Rd FROM 1-10706961-2 WHERE Pole Position = 'Tom Sneva'"}
{"text": "table: 1-10706961-2\ncolumns: Rd, Name, Pole Position, Fastest Lap, Winning driver, Winning team, Report\nQ: How many winning drivers were there in the race that had a fastest lap time of 56.920?\nA: SELECT COUNT Winning driver FROM 1-10706961-2 WHERE Fastest Lap = '56.920'"}
{"text": "table: 1-10706961-2\ncolumns: Rd, Name, Pole Position, Fastest Lap, Winning driver, Winning team, Report\nQ: How many reports are there in the race that Forsythe Racing won and Teo Fabi had the pole position in?\nA: SELECT COUNT Report FROM 1-10706961-2 WHERE Winning team = 'Forsythe Racing' AND Pole Position = 'Teo Fabi'"}
{"text": "table: 1-10706961-2\ncolumns: Rd, Name, Pole Position, Fastest Lap, Winning driver, Winning team, Report\nQ: Which Rd took place at the Indianapolis 500?\nA: SELECT Rd FROM 1-10706961-2 WHERE Name = 'Indianapolis 500'"}
{"text": "table: 1-10706961-2\ncolumns: Rd, Name, Pole Position, Fastest Lap, Winning driver, Winning team, Report\nQ: Which teams won when Bobby Rahal was their winning driver?\nA: SELECT Winning team FROM 1-10706961-2 WHERE Winning driver = 'Bobby Rahal'"}
{"text": "table: 1-10706961-2\ncolumns: Rd, Name, Pole Position, Fastest Lap, Winning driver, Winning team, Report\nQ: What was the fastest lap time in the Escort Radar Warning 200?\nA: SELECT Fastest Lap FROM 1-10706961-2 WHERE Name = 'Escort Radar Warning 200'"}
{"text": "table: 1-10707176-2\ncolumns: Rnd, Race Name, Circuit, City/Location, Date, Pole position, Winning driver, Winning team, Report\nQ: What report was there for the porsche north america?\nA: SELECT Report FROM 1-10707176-2 WHERE Winning team = 'Porsche North America'"}
{"text": "table: 1-10707176-2\ncolumns: Rnd, Race Name, Circuit, City/Location, Date, Pole position, Winning driver, Winning team, Report\nQ: What rnds were there for the phoenix international raceway?\nA: SELECT Rnd FROM 1-10707176-2 WHERE Circuit = 'Phoenix International Raceway'"}
{"text": "table: 1-10707176-2\ncolumns: Rnd, Race Name, Circuit, City/Location, Date, Pole position, Winning driver, Winning team, Report\nQ: Who was the pole position for the rnd equalling 12?\nA: SELECT Pole position FROM 1-10707176-2 WHERE Rnd = '12'"}
{"text": "table: 1-10707176-2\ncolumns: Rnd, Race Name, Circuit, City/Location, Date, Pole position, Winning driver, Winning team, Report\nQ: How many reports were the for the cleveland burke lakefront airport circut?\nA: SELECT COUNT Report FROM 1-10707176-2 WHERE Circuit = 'Cleveland Burke Lakefront Airport'"}
{"text": "table: 1-10707176-2\ncolumns: Rnd, Race Name, Circuit, City/Location, Date, Pole position, Winning driver, Winning team, Report\nQ: How many winning drivers were the for the rnd equalling 5?\nA: SELECT COUNT Winning driver FROM 1-10707176-2 WHERE Rnd = '5'"}
{"text": "table: 1-10706879-3\ncolumns: Rd, Name, Pole Position, Fastest Lap, Winning driver, Winning team, Report\nQ: The race tony bettenhausen 200 has what smallest rd?\nA: SELECT MIN Rd FROM 1-10706879-3 WHERE Name = 'Tony Bettenhausen 200'"}
{"text": "table: 1-10706879-3\ncolumns: Rd, Name, Pole Position, Fastest Lap, Winning driver, Winning team, Report\nQ: The winning team of the race, los angeles times 500 is who?\nA: SELECT Winning team FROM 1-10706879-3 WHERE Name = 'Los Angeles Times 500'"}
{"text": "table: 1-10706879-3\ncolumns: Rd, Name, Pole Position, Fastest Lap, Winning driver, Winning team, Report\nQ: How many winning drivers in the kraco twin 125 (r2) race were there?\nA: SELECT COUNT Winning driver FROM 1-10706879-3 WHERE Name = 'Kraco Twin 125 (R2)'"}
{"text": "table: 1-10706879-3\ncolumns: Rd, Name, Pole Position, Fastest Lap, Winning driver, Winning team, Report\nQ: What are the races that johnny rutherford has won?\nA: SELECT Name FROM 1-10706879-3 WHERE Winning driver = 'Johnny Rutherford'"}
{"text": "table: 1-10706879-3\ncolumns: Rd, Name, Pole Position, Fastest Lap, Winning driver, Winning team, Report\nQ: How many fastest laps were there for a rd that equals 10?\nA: SELECT COUNT Fastest Lap FROM 1-10706879-3 WHERE Rd = 10"}
{"text": "table: 1-10712301-5\ncolumns: Region, Operator, Licence award date, On air date, Closure date\nQ: What is the license award date for North East England?\nA: SELECT Licence award date FROM 1-10712301-5 WHERE Region = 'North East England'"}
{"text": "table: 1-10733530-3\ncolumns: Nation, Population (thousands), Internet subscriptions (2000) (thousands of users), Internet subscriptions (2008) (thousands of users), % growth (2000\u20132008), % Internet users\nQ: What is the percentage of growth in 2000-2008 in ethiopia?\nA: SELECT % growth (2000\u20132008) FROM 1-10733530-3 WHERE Nation = 'Ethiopia'"}
{"text": "table: 1-10733530-3\ncolumns: Nation, Population (thousands), Internet subscriptions (2000) (thousands of users), Internet subscriptions (2008) (thousands of users), % growth (2000\u20132008), % Internet users\nQ: Name the total number of percentage growth 2000-2008 of uganda?\nA: SELECT COUNT % growth (2000\u20132008) FROM 1-10733530-3 WHERE Nation = 'Uganda'"}
{"text": "table: 1-10733530-3\ncolumns: Nation, Population (thousands), Internet subscriptions (2000) (thousands of users), Internet subscriptions (2008) (thousands of users), % growth (2000\u20132008), % Internet users\nQ: What is the maximum percentage grown 2000-2008 in burundi\nA: SELECT MAX % growth (2000\u20132008) FROM 1-10733530-3 WHERE Nation = 'Burundi'"}
{"text": "table: 1-10798421-1\ncolumns: Village (German), Village (Slovenian), Number of people 1991, Percent of Slovenes 1991, Percent of Slovenes 1951\nQ: Provide me with the names of all the villages (German) that has 76.3% of Slovenes in 1951.\nA: SELECT Village (German) FROM 1-10798421-1 WHERE Percent of Slovenes 1951 = '76.3%'"}
{"text": "table: 1-10798421-1\ncolumns: Village (German), Village (Slovenian), Number of people 1991, Percent of Slovenes 1991, Percent of Slovenes 1951\nQ: Give me the minimum number of people in 1991 with 92.5% of Slovenes in 1991.\nA: SELECT MIN Number of people 1991 FROM 1-10798421-1 WHERE Percent of Slovenes 1991 = '92.5%'"}
{"text": "table: 1-10798421-1\ncolumns: Village (German), Village (Slovenian), Number of people 1991, Percent of Slovenes 1991, Percent of Slovenes 1951\nQ: Provide me with the name of all the village (German) that are part of the village (Slovenian) with sele srednji kot. \nA: SELECT Village (German) FROM 1-10798421-1 WHERE Village (Slovenian) = 'Sele Srednji Kot'"}
{"text": "table: 1-10798421-1\ncolumns: Village (German), Village (Slovenian), Number of people 1991, Percent of Slovenes 1991, Percent of Slovenes 1951\nQ: Provide me with the name of all the village (German) that are part of the village (Slovenian) with sele borovnica.\nA: SELECT Village (German) FROM 1-10798421-1 WHERE Village (Slovenian) = 'Sele Borovnica'"}
{"text": "table: 1-10798421-1\ncolumns: Village (German), Village (Slovenian), Number of people 1991, Percent of Slovenes 1991, Percent of Slovenes 1951\nQ: Provide me with the name of the village (German) where there is 96.9% Slovenes in 1951. \nA: SELECT Village (German) FROM 1-10798421-1 WHERE Percent of Slovenes 1951 = '96.9%'"}
{"text": "table: 1-10798421-1\ncolumns: Village (German), Village (Slovenian), Number of people 1991, Percent of Slovenes 1991, Percent of Slovenes 1951\nQ: Provide with the names of the village (German) that is part of village (Slovenian) with sele srednji kot.\nA: SELECT Village (German) FROM 1-10798421-1 WHERE Village (Slovenian) = 'Sele Srednji Kot'"}
{"text": "table: 1-10812293-3\ncolumns: Game, Date, Team, Score, High points, High rebounds, High assists, Location Attendance, Record\nQ: What was the score of the game on November 12?\nA: SELECT Score FROM 1-10812293-3 WHERE Date = 'November 12'"}
{"text": "table: 1-10812293-3\ncolumns: Game, Date, Team, Score, High points, High rebounds, High assists, Location Attendance, Record\nQ: Who had high assists when they played against San Antonio?\nA: SELECT High assists FROM 1-10812293-3 WHERE Team = 'San Antonio'"}
{"text": "table: 1-10812293-3\ncolumns: Game, Date, Team, Score, High points, High rebounds, High assists, Location Attendance, Record\nQ: Who scored the most points in game 4?\nA: SELECT High points FROM 1-10812293-3 WHERE Game = 4"}
{"text": "table: 1-10812293-3\ncolumns: Game, Date, Team, Score, High points, High rebounds, High assists, Location Attendance, Record\nQ: Where was the game on November 20?\nA: SELECT Location Attendance FROM 1-10812293-3 WHERE Date = 'November 20'"}
{"text": "table: 1-10935205-1\ncolumns: No. in season, No. in series, Title, Canadian airdate, US airdate, Production code\nQ: The canadian airdate of 11 february 2008 applied to what series number?\nA: SELECT COUNT No. in series FROM 1-10935205-1 WHERE Canadian airdate = '11 February 2008'"}
{"text": "table: 1-10935205-1\ncolumns: No. in season, No. in series, Title, Canadian airdate, US airdate, Production code\nQ: The U.S. airdate of 4 april 2008 had a production code of what?\nA: SELECT MAX Production code FROM 1-10935205-1 WHERE US airdate = '4 April 2008'"}
{"text": "table: 1-10935205-1\ncolumns: No. in season, No. in series, Title, Canadian airdate, US airdate, Production code\nQ: The episode titled \"don't stop believin'\" was what highest number of the season?\nA: SELECT MAX No. in season FROM 1-10935205-1 WHERE Title = '\"Don't Stop Believin'\"'"}
{"text": "table: 1-10935205-1\ncolumns: No. in season, No. in series, Title, Canadian airdate, US airdate, Production code\nQ: The U.S. airdate of 8 august 2008 also had canadian airdates of what?\nA: SELECT Canadian airdate FROM 1-10935205-1 WHERE US airdate = '8 August 2008'"}
{"text": "table: 1-10935205-1\ncolumns: No. in season, No. in series, Title, Canadian airdate, US airdate, Production code\nQ: The canadian airdate of 17 march 2008 had how many numbers in the season?\nA: SELECT COUNT No. in season FROM 1-10935205-1 WHERE Canadian airdate = '17 March 2008'"}
{"text": "table: 1-10935205-1\ncolumns: No. in season, No. in series, Title, Canadian airdate, US airdate, Production code\nQ: For the episode(s) aired in the U.S. on 4 april 2008, what were the names?\nA: SELECT Title FROM 1-10935205-1 WHERE US airdate = '4 April 2008'"}
{"text": "table: 1-10953197-5\ncolumns: No. in series, No. in season, Title, Director, Writer(s), Original air date, Production code\nQ: Who directed the episode \"Great Sexpectations (2)\"?\nA: SELECT Director FROM 1-10953197-5 WHERE Title = '\"Great Sexpectations (2)\"'"}
{"text": "table: 1-10975034-2\ncolumns: Pick #, CFL Team, Player, Position, College\nQ: Which player from the 2004 CFL draft attended Wilfrid Laurier?\nA: SELECT Player FROM 1-10975034-2 WHERE College = 'Wilfrid Laurier'"}
{"text": "table: 1-10975034-2\ncolumns: Pick #, CFL Team, Player, Position, College\nQ: What position does Christian Leibl-Cote play?\nA: SELECT Position FROM 1-10975034-2 WHERE Player = 'Christian Leibl-Cote'"}
{"text": "table: 1-10975034-2\ncolumns: Pick #, CFL Team, Player, Position, College\nQ: What is the pick number for Northwestern college?\nA: SELECT MAX Pick # FROM 1-10975034-2 WHERE College = 'Northwestern'"}
{"text": "table: 1-10992-3\ncolumns: No, City district (Stadtteil), Area in km\u00b2, Population, Foreign nationals, Foreign nationals in %, Area district (Ortsbezirk)\nQ: How many foreigners in percentage terms had a population of 4.911?\nA: SELECT COUNT Foreign nationals in % FROM 1-10992-3 WHERE Population = '4.911'"}
{"text": "table: 1-10992-3\ncolumns: No, City district (Stadtteil), Area in km\u00b2, Population, Foreign nationals, Foreign nationals in %, Area district (Ortsbezirk)\nQ: What is the number of the city district of stadtteil where foreigners are 5.162?\nA: SELECT COUNT City district (Stadtteil) FROM 1-10992-3 WHERE Foreign nationals = '5.162'"}
{"text": "table: 1-10992-3\ncolumns: No, City district (Stadtteil), Area in km\u00b2, Population, Foreign nationals, Foreign nationals in %, Area district (Ortsbezirk)\nQ: What is the city where the number is 47?\nA: SELECT City district (Stadtteil) FROM 1-10992-3 WHERE No = '47'"}
{"text": "table: 1-11044765-1\ncolumns: School, Mascot, Location, League, Enrollment\nQ: Which leagues have Raiders as their mascot?\nA: SELECT League FROM 1-11044765-1 WHERE Mascot = 'Raiders'"}
{"text": "table: 1-11044765-1\ncolumns: School, Mascot, Location, League, Enrollment\nQ: Which leagues is the Galena school in?\nA: SELECT League FROM 1-11044765-1 WHERE School = 'Galena'"}
{"text": "table: 1-11044765-1\ncolumns: School, Mascot, Location, League, Enrollment\nQ: What city and state is the Lancers mascot located?\nA: SELECT Location FROM 1-11044765-1 WHERE Mascot = 'Lancers'"}
{"text": "table: 1-11044765-1\ncolumns: School, Mascot, Location, League, Enrollment\nQ: What city and state are the miners located in?\nA: SELECT Location FROM 1-11044765-1 WHERE Mascot = 'Miners'"}
{"text": "table: 1-11044765-1\ncolumns: School, Mascot, Location, League, Enrollment\nQ: Which school has the Raiders as their mascot?\nA: SELECT School FROM 1-11044765-1 WHERE Mascot = 'Raiders'"}
{"text": "table: 1-1121352-2\ncolumns: No., Date, Tournament, Winning score, To par, Margin of victory, Runner(s)-up\nQ: Where was the tournament dated nov 3, 2002?\nA: SELECT Tournament FROM 1-1121352-2 WHERE Date = 'Nov 3, 2002'"}
{"text": "table: 1-1121352-2\ncolumns: No., Date, Tournament, Winning score, To par, Margin of victory, Runner(s)-up\nQ: Where is the margin of victory dated mar 28, 2004?\nA: SELECT Margin of victory FROM 1-1121352-2 WHERE Date = 'Mar 28, 2004'"}
{"text": "table: 1-1121352-2\ncolumns: No., Date, Tournament, Winning score, To par, Margin of victory, Runner(s)-up\nQ: What is the to par dated may 4, 2003?\nA: SELECT To par FROM 1-1121352-2 WHERE Date = 'May 4, 2003'"}
{"text": "table: 1-1121352-2\ncolumns: No., Date, Tournament, Winning score, To par, Margin of victory, Runner(s)-up\nQ: What date were the runner ups pat hurst juli inkster?\nA: SELECT Date FROM 1-1121352-2 WHERE Runner(s)-up = 'Pat Hurst Juli Inkster'"}
{"text": "table: 1-11210576-4\ncolumns: Character, Fate, Actor, First Episode, Final Episode, Duration, Final Episode Count\nQ: what's the total number of\u00a0final epbeingode count\u00a0with\u00a0character\u00a0being rick stetler\nA: SELECT COUNT Final Episode Count FROM 1-11210576-4 WHERE Character = 'Rick Stetler'"}
{"text": "table: 1-11210576-4\ncolumns: Character, Fate, Actor, First Episode, Final Episode, Duration, Final Episode Count\nQ: what are all the actor where first episode is \"ambush\"\nA: SELECT Actor FROM 1-11210576-4 WHERE First Episode = '\"Ambush\"'"}
{"text": "table: 1-11210576-4\ncolumns: Character, Fate, Actor, First Episode, Final Episode, Duration, Final Episode Count\nQ: what's the\u00a0character\u00a0with\u00a0fate\u00a0being deceased: knife wound\nA: SELECT Character FROM 1-11210576-4 WHERE Fate = 'Deceased: Knife Wound'"}
{"text": "table: 1-11210576-4\ncolumns: Character, Fate, Actor, First Episode, Final Episode, Duration, Final Episode Count\nQ: what's the total number of\u00a0final epbeingode count\u00a0with\u00a0first epbeingode\u00a0being \"l.a.\"\nA: SELECT COUNT Final Episode Count FROM 1-11210576-4 WHERE First Episode = '\"L.A.\"'"}
{"text": "table: 1-11210576-4\ncolumns: Character, Fate, Actor, First Episode, Final Episode, Duration, Final Episode Count\nQ: what's the\u00a0actor\u00a0with\u00a0character\u00a0being judge joseph ratner\nA: SELECT Actor FROM 1-11210576-4 WHERE Character = 'Judge Joseph Ratner'"}
{"text": "table: 1-11210576-4\ncolumns: Character, Fate, Actor, First Episode, Final Episode, Duration, Final Episode Count\nQ: what's the\u00a0first epbeingode\u00a0with\u00a0final epbeingode\u00a0being \"rio\"\nA: SELECT First Episode FROM 1-11210576-4 WHERE Final Episode = '\"Rio\"'"}
{"text": "table: 1-11214772-2\ncolumns: Year, Champion, Score, Runner-Up, Location, Semi-Finalist #1, Semi-Finalist #2\nQ: Which team was the second semi finalist in 2007?\nA: SELECT Semi-Finalist #2 FROM 1-11214772-2 WHERE Year = 2007"}
{"text": "table: 1-11214772-2\ncolumns: Year, Champion, Score, Runner-Up, Location, Semi-Finalist #1, Semi-Finalist #2\nQ: How many teams were listed as runner up in 2005 and there the first semi finalist was Western Carolina?\nA: SELECT COUNT Runner-Up FROM 1-11214772-2 WHERE Semi-Finalist #1 = 'Western Carolina' AND Year = 2005"}

119
Data/lora/wikisql.py Normal file
View File

@@ -0,0 +1,119 @@
# Copyright © 2023 Apple Inc.
"""
Code to preprocess the WikiSQL dataset adapted from
https://github.com/salesforce/WikiSQL and
https://huggingface.co/sqllama/sqllama-V0/blob/main/wikisql.ipynb .
"""
import json
import os
def load():
"""
Load all three splits of the WikiSQL dataset.
"""
return (WikiSQL(dn) for dn in ["train", "dev", "test"])
class WikiSQL:
def __init__(self, dataset, save_dir="/tmp"):
valid_sets = ("train", "dev", "test")
if dataset not in valid_sets:
raise ValueError(f"Dataset must be in {valid_sets}, got {dataset}")
data_dir = os.path.join(save_dir, "wikisql")
self._maybe_download(data_dir)
self._parse_tables(os.path.join(data_dir, f"data/{dataset}.tables.jsonl"))
self._parse_queries(os.path.join(data_dir, f"data/{dataset}.jsonl"))
def _maybe_download(self, data_dir):
if not os.path.exists(data_dir):
import io
import tarfile
from urllib import request
url = "https://raw.githubusercontent.com/salesforce/WikiSQL/master/data.tar.bz2"
r = request.urlopen(url)
with tarfile.open(fileobj=io.BytesIO(r.read())) as tf:
tf.extractall(data_dir)
def _parse_tables(self, tables):
self._tables = {}
with open(tables) as f:
for line in f:
table = json.loads(line)
self._tables[table["id"]] = {
"columns": table["header"],
"types": table["types"],
"desc": f"table: {table['id']}\ncolumns: {', '.join(table['header'])}",
}
def _parse_queries(self, queries):
self._queries = []
with open(queries) as f:
for line in f:
query = json.loads(line)
table = self._tables[query["table_id"]]
question = query["question"]
answer = self.query_to_text(
query["sql"], query["table_id"], table["columns"], table["types"]
)
self._queries.append(
f"<s>{table['desc']}\nQ: {question}\nA: {answer}</s>"
)
def query_to_text(self, query, table, columns, types):
aggregation_ops = ["", "MAX", "MIN", "COUNT", "SUM", "AVG"]
condition_ops = ["=", ">", "<", "OP"]
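# illustrative example (hypothetical query dict): sel=0, agg=3 ("COUNT"),
# conds=[[1, 0, "10"]] on table "t" with columns ["Lap", "Rd"] and
# types ["real", "real"] produces:
#   SELECT COUNT Lap FROM t WHERE Rd = 10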
column = columns[query["sel"]]
aggregation = (aggregation_ops[query["agg"]] + " ") if query["agg"] > 0 else ""
sql = f"SELECT {aggregation}{column} FROM {table}"
conditions = query["conds"]
if conditions:
cs = []
for i, o, v in conditions:
column = columns[i]
op = condition_ops[o]
if types[i] == "text":
value = f"'{v}'"
else:
value = v
cs.append(f"{column} {op} {value}")
sql += " WHERE " + " AND ".join(cs)
return sql
def __getitem__(self, idx):
return self._queries[idx]
def __len__(self):
return len(self._queries)
if __name__ == "__main__":
datanames = ["train", "dev", "test"]
sizes = [56355, 8421, 15878]
for dataname, size in zip(datanames, sizes):
assert len(WikiSQL(dataname)) == size, f"Wrong {dataname} set size."
# Write the sets to jsonl
import json
train, dev, test = load()
datasets = [
(train, "train", 1000),
(dev, "valid", 100),
(test, "test", 100),
]
for dataset, name, size in datasets:
with open(f"data/{name}.jsonl", "w") as fid:
for e, t in zip(range(size), dataset):
# Strip the <s>, </s> since the tokenizer adds them
json.dump({"text": t[3:-4]}, fid)
fid.write("\n")

View File

@@ -236,3 +236,11 @@ public struct CohereConfiguration: Codable {
Float.self, forKey: CohereConfiguration.CodingKeys.logitScale)
}
}
// MARK: - LoRA
extension CohereModel: LoRAModel {
public func loraLinearLayers() -> LoRALinearLayers {
model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
}
}

View File

@@ -254,3 +254,11 @@ public struct GemmaConfiguration: Codable {
Bool.self, forKey: CodingKeys.ropeTraditional) ?? false
}
}
// MARK: - LoRA
extension GemmaModel: LoRAModel {
public func loraLinearLayers() -> LoRALinearLayers {
model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
}
}

View File

@@ -253,3 +253,11 @@ public struct LlamaConfiguration: Codable {
}
}
// MARK: - LoRA
extension LlamaModel: LoRAModel {
public func loraLinearLayers() -> LoRALinearLayers {
model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
}
}

View File

@@ -17,47 +17,64 @@ public func load(
hub: HubApi = HubApi(), configuration: ModelConfiguration,
progressHandler: @escaping (Progress) -> Void = { _ in }
) async throws -> (LLMModel, Tokenizer) {
// note: this doesn't have a way to pass the HubApi
let tokenizer = try await loadTokenizer(configuration: configuration)
do {
let tokenizer = try await loadTokenizer(configuration: configuration, hub: hub)
// download the model weights and config
let repo = Hub.Repo(id: configuration.id)
let modelFiles = ["config.json", "*.safetensors"]
let modelDirectory = try await hub.snapshot(
from: repo, matching: modelFiles, progressHandler: progressHandler)
let modelDirectory: URL
// create the model (no weights loaded)
let configurationURL = modelDirectory.appending(component: "config.json")
let baseConfig = try JSONDecoder().decode(
BaseConfiguration.self, from: Data(contentsOf: configurationURL))
switch configuration.id {
case .id(let id):
// download the model weights and config
let repo = Hub.Repo(id: id)
let modelFiles = ["config.json", "*.safetensors"]
modelDirectory = try await hub.snapshot(
from: repo, matching: modelFiles, progressHandler: progressHandler)
let model = try baseConfig.modelType.createModel(configuration: configurationURL)
case .directory(let directory):
modelDirectory = directory
}
// load the weights
var weights = [String: MLXArray]()
let enumerator = FileManager.default.enumerator(
at: modelDirectory, includingPropertiesForKeys: nil)!
for case let url as URL in enumerator {
if url.pathExtension == "safetensors" {
let w = try loadArrays(url: url)
for (key, value) in w {
weights[key] = value
// create the model (no weights loaded)
let configurationURL = modelDirectory.appending(component: "config.json")
let baseConfig = try JSONDecoder().decode(
BaseConfiguration.self, from: Data(contentsOf: configurationURL))
let model = try baseConfig.modelType.createModel(configuration: configurationURL)
// load the weights
var weights = [String: MLXArray]()
let enumerator = FileManager.default.enumerator(
at: modelDirectory, includingPropertiesForKeys: nil)!
for case let url as URL in enumerator {
if url.pathExtension == "safetensors" {
let w = try loadArrays(url: url)
for (key, value) in w {
weights[key] = value
}
}
}
// quantize if needed
if let quantization = baseConfig.quantization {
quantizeIfNeeded(model: model, weights: weights, quantization: quantization)
}
// apply the loaded weights
let parameters = ModuleParameters.unflattened(weights)
try model.update(parameters: parameters, verify: [.all])
eval(model)
return (model, tokenizer)
} catch Hub.HubClientError.authorizationRequired {
// an authorizationRequired means (typically) that the named repo doesn't exist on
// on the server so retry with local only configuration
var newConfiguration = configuration
newConfiguration.id = .directory(configuration.modelDirectory(hub: hub))
return try await load(
hub: hub, configuration: newConfiguration, progressHandler: progressHandler)
}
// quantize if needed
if let quantization = baseConfig.quantization {
quantizeIfNeeded(model: model, weights: weights, quantization: quantization)
}
// apply the loaded weights
let parameters = ModuleParameters.unflattened(weights)
try model.update(parameters: parameters, verify: [.all])
eval(model)
return (model, tokenizer)
}
// MARK: - Quantization

View File

@@ -0,0 +1,61 @@
// Copyright © 2024 Apple Inc.
import Foundation
enum LoRADataError: Error {
case fileNotFound(URL, String)
}
/// Load a LoRA data file.
///
/// Given a directory and a base name, e.g. `train`, this will load a `.jsonl` or `.txt` file
/// if possible.
public func loadLoRAData(directory: URL, name: String) throws -> [String] {
let extensions = ["jsonl", "txt"]
for ext in extensions {
let url = directory.appending(component: "\(name).\(ext)")
if FileManager.default.fileExists(atPath: url.path()) {
return try loadLoRAData(url: url)
}
}
throw LoRADataError.fileNotFound(directory, name)
}
/// Load a .txt or .jsonl file and return the contents
public func loadLoRAData(url: URL) throws -> [String] {
switch url.pathExtension {
case "jsonl":
return try loadJSONL(url: url)
case "txt":
return try loadLines(url: url)
default:
fatalError("Unable to load data file, unknown type: \(url)")
}
}
func loadJSONL(url: URL) throws -> [String] {
struct Line: Codable {
let text: String?
}
return try String(contentsOf: url)
.components(separatedBy: .newlines)
.filter {
$0.first == "{"
}
.compactMap {
try JSONDecoder().decode(Line.self, from: $0.data(using: .utf8)!).text
}
}
func loadLines(url: URL) throws -> [String] {
try String(contentsOf: url)
.components(separatedBy: .newlines)
.filter { !$0.isEmpty }
}

639
Libraries/LLM/Lora.swift Normal file
View File

@@ -0,0 +1,639 @@
// Copyright © 2024 Apple Inc.
import Foundation
import MLX
import MLXNN
import MLXOptimizers
import MLXRandom
import Tokenizers
/// Layers to apply LoRA adapters to.
///
/// This is the value returned by ``LoRAModel/loraLinearLayers()``.
public typealias LoRALinearLayers = [(Module, [String])]
public protocol LoRAModel {
/// Return the layers and keys to apply LoRA adapters to.
///
/// For example, this might apply the adapters to the `q` and `v` projections in the
/// Attention layers:
///
/// ```swift
/// model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
/// ```
///
/// It is not required that a model implement this protocol to have LoRA adapters applied, but
/// the command line driver example uses this to produce the ``LoRALinearLayers``.
///
/// ### See Also
/// - ``LoRATrain/convert(model:layers:)``
func loraLinearLayers() -> LoRALinearLayers
}
/// Protocol for LoRA implementations that provides a method for converting back to a `Linear`
/// (or subtype).
///
/// This is normally called via ``LoRATrain/fuse(model:layers:deQuantize:)``
public protocol LoRAConvertToLinear {
func toLinear(deQuantize: Bool) -> Linear
}
/// Implementation of LoRA `Linear` replacement layer.
///
/// This layer implements the LoRA capabilities for `Linear` layers, specifically:
///
/// - converting `Linear` or `QuantizedLinear` layers to ``LoRALinear`` / ``QLoRALinear``
/// - converting ``LoRALinear`` back to `Linear` or `QuantizedLinear` (``LoRAConvertToLinear``)
/// - implementing the LoRA evaluation
///
/// ``QLoRALinear`` is the equivalent class for `QuantizedLinear`.
///
/// This is not typically used directly -- ``LoRATrain/convert(model:layers:)`` is used to
/// add the adapter layers to a given model.
///
/// ### See Also
/// - [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685)
/// - [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)
/// - ``QLoRALinear``
/// - ``LoRATrain/convert(model:layers:)``
/// - ``LoRATrain/fuse(model:layers:deQuantize:)``
public class LoRALinear: Linear, LoRAConvertToLinear {
let scale: Float
@ParameterInfo(key: "lora_a") var loraA: MLXArray
@ParameterInfo(key: "lora_b") var loraB: MLXArray
required public init(
_ inputDimensions: Int, _ outputDimensions: Int, rank: Int = 8, bias: Bool = false,
scale: Float = 20.0, linear: Linear
) {
// Scale for low-rank update
self.scale = scale
// Low rank lora weights
let loraScale = 1 / sqrt(Float(inputDimensions))
self._loraA.wrappedValue = MLXRandom.uniform(
low: -loraScale, high: loraScale, [inputDimensions, rank])
self._loraB.wrappedValue = MLXArray.zeros([rank, outputDimensions])
super.init(weight: linear.weight, bias: linear.bias)
freeze()
}
/// Freeze all parameters except the lora parameters
public override func freeze(recursive: Bool = true, keys: [String]? = nil, strict: Bool = false)
throws
{
// realize the keys and omit the lora parameters
let keys =
(keys ?? self.filterMap(filter: Self.filterLocalParameters).flattened().map { $0.0 })
.filter {
$0 != "lora_a" && $0 != "lora_b"
}
try super.freeze(recursive: recursive, keys: keys, strict: strict)
}
/// Convert a `Linear` or `QuantizedLinear` layer into a new `Linear` layer
/// that implements the `LoRA` adapter.
///
/// This is typically called via ``LoRATrain/convert(model:layers:)``.
///
/// ### See Also
/// - ``LoRATrain/convert(model:layers:)``
/// - ``QLoRALinear/from(linear:rank:)``
public static func from(linear: Linear, rank: Int = 8) -> Linear {
if let linear = linear as? QuantizedLinear {
return QLoRALinear.from(linear: linear, rank: rank)
}
let (outputDimensions, inputDimensions) = linear.shape
return LoRALinear(inputDimensions, outputDimensions, rank: rank, linear: linear)
}
/// Convert back into a fused `Linear` layer.
///
/// This is typically called via ``LoRATrain/fuse(model:layers:deQuantize:)``.
///
/// ### See Also
/// - ``LoRATrain/fuse(model:layers:deQuantize:)``
/// - ``LoRAConvertToLinear``
/// - ``QLoRALinear/toLinear(deQuantize:)``
public func toLinear(deQuantize: Bool = false) -> Linear {
let dtype = weight.dtype
let loraB = (scale * loraB.T).asType(dtype)
let loraA = loraA.T.asType(dtype)
return Linear(weight: weight + matmul(loraB, loraA), bias: bias)
}
public override func callAsFunction(_ x: MLXArray) -> MLXArray {
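// frozen base projection plus the scaled low-rank update:
// y = x W^T (+ b), z = (x A) B, output = y + scale * z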
let y = super.callAsFunction(x.asType(weight.dtype))
let z = matmul(matmul(x, self.loraA), self.loraB)
return y + scale * z
}
}
/// Implementation of LoRA `QuantizedLinear` replacement layer.
///
/// See ``LoRALinear`` (equivalent class for `Linear` layers) for more information.
public class QLoRALinear: QuantizedLinear, LoRAConvertToLinear {
let scale: Float
@ParameterInfo(key: "lora_a") var loraA: MLXArray
@ParameterInfo(key: "lora_b") var loraB: MLXArray
required public init(
_ inputDimensions: Int, _ outputDimensions: Int, rank: Int = 8, bias: Bool = false,
scale: Float = 20.0, linear: QuantizedLinear
) {
// Scale for low-rank update
self.scale = scale
// Low rank lora weights
let loraScale = 1 / sqrt(Float(inputDimensions))
self._loraA.wrappedValue = MLXRandom.uniform(
low: -loraScale, high: loraScale, [inputDimensions, rank])
self._loraB.wrappedValue = MLXArray.zeros([rank, outputDimensions])
super.init(
weight: linear.weight, bias: linear.bias, scales: linear.scales, biases: linear.biases,
groupSize: linear.groupSize, bits: linear.bits)
// start frozen except for the lora keys
freeze()
}
/// Freeze all parameters except the lora parameters
public override func freeze(recursive: Bool = true, keys: [String]? = nil, strict: Bool = false)
throws
{
// realize the keys and omit the lora parameters
let keys =
(keys ?? self.filterMap(filter: Self.filterLocalParameters).flattened().map { $0.0 })
.filter {
$0 != "lora_a" && $0 != "lora_b"
}
try super.freeze(recursive: recursive, keys: keys, strict: strict)
}
/// Convert a `QuantizedLinear` layer into a new `Linear` layer
/// that implements the `LoRA` adapter.
///
/// This is typically called via ``LoRATrain/convert(model:layers:)``.
///
/// ### See Also
/// - ``LoRATrain/convert(model:layers:)``
/// - ``LoRALinear/from(linear:rank:)``
public static func from(linear: QuantizedLinear, rank: Int = 8) -> Linear {
var (outputDimensions, inputDimensions) = linear.shape
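// the quantized weight packs 32 / bits values into each UInt32 element,
// so recover the true input dimension from the stored shape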
inputDimensions = inputDimensions * 32 / linear.bits
return QLoRALinear(inputDimensions, outputDimensions, rank: rank, linear: linear)
}
/// Convert back into a fused `QuantizedLinear` layer.
///
/// This is typically called via ``LoRATrain/fuse(model:layers:deQuantize:)``.
///
/// ### See Also
/// - ``LoRATrain/fuse(model:layers:deQuantize:)``
public func toLinear(deQuantize: Bool = false) -> Linear {
// convert back into full weights
let weight = dequantized(
weight, scales: scales, biases: biases, groupSize: groupSize, bits: bits)
let loraB = (scale * loraB.T).asType(.float16)
let loraA = loraA.T.asType(.float16)
// convert back into quantized
return QuantizedLinear(
weight: weight + matmul(loraB, loraA), bias: bias, groupSize: groupSize, bits: bits)
}
public override func callAsFunction(_ x: MLXArray) -> MLXArray {
let y = super.callAsFunction(x.asType(scales.dtype))
let z = matmul(matmul(x, self.loraA), self.loraB)
return y + scale * z
}
}
/// Equivalent to `lora.py/iterate_batches()`. Used internally by ``LoRATrain``.
struct LoRABatchIterator: Sequence, IteratorProtocol {
let dataset: [String]
let batchSize: Int
let tokenizer: Tokenizer
let train: Bool
var indices: [Int]
var index = 0
public init(dataset: [String], tokenizer: Tokenizer, batchSize: Int, train: Bool) {
self.dataset = dataset
self.batchSize = batchSize
self.tokenizer = tokenizer
self.train = train
self.indices = Array(0 ..< dataset.count)
if train {
indices.shuffle()
}
}
mutating public func next() -> (MLXArray, MLXArray, MLXArray)? {
if index >= indices.count {
if !train {
return nil
}
indices.shuffle()
index = 0
}
let endIndex = Swift.min(index + batchSize, indices.count)
let batch = (index ..< endIndex)
.map { tokenizer.encode(text: dataset[indices[$0]]) }
let lengths = batch.map { $0.count }
let maxLength = lengths.max() ?? 0
if maxLength > 2048 {
print(
"""
[WARNING] Some sequences are longer than 2048 tokens.
Consider pre-splitting your data to save memory.
""")
}
// pad to the max length
let batchArray = MLXArray.zeros([lengths.count, maxLength], type: Int32.self)
for (j, (b, l)) in zip(batch, lengths).enumerated() {
batchArray[j, 0 ..< l] = MLXArray(b)
}
index = endIndex
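// inputs are all tokens but the last, targets are all tokens but the
// first -- standard next-token prediction pairs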
return (batchArray[0..., .stride(to: -1)], batchArray[0..., 1...], MLXArray(lengths))
}
}
/// Collection of functions for adding LoRA adapters to an LLM model, training, fusing and saving/loading weights.
///
/// The typical flow for training is:
///
/// ```swift
/// // load the base model and tokenizer
/// let (model, tokenizer) = try await LLM.load(configuration: ModelConfiguration.mistral7B4bit)
///
/// // add LoRALinear adapter layers
/// LoRATrain.convert(model: model, layers: Array(model.loraLinearLayers().suffix(4)))
///
/// // optionally load LoRA weights
/// try LoRATrain.loadLoRAWeights(model: model, url: ...)
///
/// // load the train/validation data
/// let train = try loadLoRAData(directory: data, name: "train")
/// let valid = try loadLoRAData(directory: data, name: "valid")
///
/// // train
/// let optimizer = Adam(learningRate: 1e-5)
/// try await LoRATrain.train(
/// model: model, train: train, validate: valid, optimizer: optimizer, tokenizer: tokenizer,
/// parameters: LoRATrain.Parameters()
/// ) { progress in
/// print(progress)
/// return .more
/// }
/// ```
///
/// At this point the model will be trained and you could do one of the following:
///
/// - ``saveLoRAWeights(model:url:)`` -- write the LoRA weights to a file
/// - ``fuse(model:layers:deQuantize:)`` -- fuse the LoRA weights and convert back into the original model
/// architecture. These weights can be saved and reloaded with normal model handling code.
/// - ``evaluate(model:dataset:loss:tokenizer:batchSize:batchCount:)`` -- compute the test loss
/// against a test dataset
/// - use the in memory model as a normal `LLMModel` and evaluate a prompt
///
public enum LoRATrain {
public typealias LoraLossFunction = (Module, MLXArray, MLXArray, MLXArray) -> (
MLXArray, MLXArray
)
/// LoRA training parameters
public struct Parameters {
/// number of prompts to evaluate per iteration
public var batchSize = 4
/// number of iterations to train for
public var iterations = 1000
/// number of training steps between loss reporting
public var stepsPerReport = 10
/// number of steps between validations
public var stepsPerEval = 100
/// number of validations batches, `0` uses the entire validation set
public var validationBatches = 10
/// save the model every N iterations
public var saveEvery = 100
/// save path for the adapter `.safetensors`
public var adapterURL: URL?
public init(
batchSize: Int = 4, iterations: Int = 1000, stepsPerReport: Int = 10,
stepsPerEval: Int = 100, validationBatches: Int = 10, saveEvery: Int = 100,
adapterURL: URL? = nil
) {
self.batchSize = batchSize
self.iterations = iterations
self.stepsPerReport = stepsPerReport
self.stepsPerEval = stepsPerEval
self.validationBatches = validationBatches
self.saveEvery = saveEvery
self.adapterURL = adapterURL
}
}
/// Freeze the model layers and replace the indicated modules (Linear) that should be
/// converted to ``LoRALinear`` and remain trainable.
///
/// Once a model has had the LoRA adapters applied, adapter weights can be loaded
/// (if available):
///
/// ```swift
/// try LoRATrain.loadLoRAWeights(model: model, url: args.adapter)
/// ```
///
/// At this point the model is ready for one or more of the following:
///
/// - training with ``train(model:train:validate:optimizer:loss:tokenizer:parameters:progress:)``
/// - loss evaluation with ``evaluate(model:dataset:loss:tokenizer:batchSize:batchCount:)``
/// - fusing with ``fuse(model:layers:deQuantize:)``
/// - text generation with ``generate(promptTokens:parameters:model:tokenizer:didGenerate:)``
/// - note that this is just using normal model text generation
///
/// - Parameters:
/// - model: model to convert
///   - layers: the layers and keys to apply LoRA adapters to (see ``LoRAModel/loraLinearLayers()``)
public static func convert(model: Module, layers: LoRALinearLayers) {
model.freeze()
for (layer, keys) in layers {
var update = ModuleChildren()
let children = layer.children()
for key in keys {
if let item = children[key], case .value(let child) = item {
if let linear = child as? Linear {
update[key] = .value(LoRALinear.from(linear: linear))
} else {
print("\(key) on \(layer) is not Linear")
}
} else {
print("failed to find key \(key) on \(layer)")
}
}
layer.update(modules: update)
}
}
/// Fuses the LoRA adapters back into the model weights.
///
/// This produces a model in the original format with `Linear` or `QuantizedLinear` layer
/// weights that incorporate the LoRA adapter.
///
/// - Parameters:
///   - model: model to fuse
///   - layers: the layers and keys whose LoRA adapters should be fused
///   - deQuantize: if `true` will convert `QuantizedLinear` back into `Linear`
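///
/// For example, assuming `model` conforms to ``LoRAModel`` and the same layers
/// were passed to ``convert(model:layers:)``:
///
/// ```swift
/// LoRATrain.fuse(model: model, layers: Array(model.loraLinearLayers().suffix(4)))
/// ```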
public static func fuse(model: Module, layers: LoRALinearLayers, deQuantize: Bool = false) {
for (layer, keys) in layers {
var update = ModuleChildren()
let children = layer.children()
for key in keys {
if let item = children[key], case .value(let child) = item {
if let lora = child as? LoRAConvertToLinear {
update[key] = .value(lora.toLinear(deQuantize: deQuantize))
}
}
}
if !update.isEmpty {
layer.update(modules: update)
}
}
}
public static func loss(model: Module, inputs: MLXArray, targets: MLXArray, lengths: MLXArray)
-> (
MLXArray, MLXArray
)
{
// def loss(model, inputs, targets, lengths):
// run model on inputs
let model = model as! LLMModel
let logits = model(inputs, cache: nil).0.asType(.float32)
// mask padding tokens
let lengthMask = MLXArray(0 ..< inputs.dim(1))[.newAxis, 0...] .< lengths[0..., .newAxis]
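// e.g. padded width 5 with lengths [3, 5] broadcasts to
// [[T, T, T, F, F], [T, T, T, T, T]] so padding contributes nothing to the loss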
// calculate the loss
let ntoks = lengthMask.sum()
let ce = (crossEntropy(logits: logits, targets: targets) * lengthMask).sum() / ntoks
return (ce, ntoks)
}
/// Evaluate the model and dataset and return the loss over the entire dataset.
///
/// - Parameters:
/// - model: the model to evaluate
/// - dataset: the dataset
/// - loss: loss function
/// - tokenizer: tokenizer
/// - batchSize: number of items from the dataset to evaluate at once
/// - batchCount: number of batch elements to evaluate, 0 for all
/// - Returns: the loss over the evaluated data
///
/// ### See Also
/// - ``loadLoRAData(directory:name:)``
public static func evaluate(
model: Module, dataset: [String], loss: LoraLossFunction = loss, tokenizer: Tokenizer,
batchSize: Int, batchCount: Int
) -> Float {
var allLosses = [Float]()
var tokenCount = 0
for (iteration, (inputs, targets, lengths)) in LoRABatchIterator(
dataset: dataset, tokenizer: tokenizer, batchSize: batchSize, train: false
).enumerated() {
let (losses, tokens) = loss(model, inputs, targets, lengths)
allLosses.append((losses * tokens).item(Float.self))
tokenCount += tokens.item(Int.self)
if batchCount != 0 && iteration + 1 >= batchCount {
break
}
}
return (sum(MLXArray(allLosses), stream: .cpu) / tokenCount).item(Float.self)
}
/// Given a model with LoRA adaptors applied, load adapter weights from a `.safetensors` file.
///
/// ### See Also
/// - ``convert(model:layers:)``
/// - ``saveLoRAWeights(model:url:)``
public static func loadLoRAWeights(model: Module, url: URL) throws {
let weights = try ModuleParameters.unflattened(loadArrays(url: url))
try model.update(parameters: weights, verify: .noUnusedKeys)
eval(model)
}
/// Given a model with LoRA adaptors applied, write adapter weights to a `.safetensors` file.
///
/// ### See Also
/// - ``convert(model:layers:)``
/// - ``loadLoRAWeights(model:url:)``
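///
/// For example, to checkpoint and later restore the adapters (the path is illustrative):
///
/// ```swift
/// try LoRATrain.saveLoRAWeights(model: model, url: URL(filePath: "/tmp/adapters.safetensors"))
/// // later, after convert(model:layers:) on a fresh model:
/// try LoRATrain.loadLoRAWeights(model: model, url: URL(filePath: "/tmp/adapters.safetensors"))
/// ```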
public static func saveLoRAWeights(model: Module, url: URL) throws {
let parameters = Dictionary(
uniqueKeysWithValues: model.trainableParameters().flattened())
try save(arrays: parameters, url: url)
}
public enum Progress: CustomStringConvertible {
case train(
iteration: Int, trainingLoss: Float, iterationsPerSecond: Double,
tokensPerSecond: Double)
case validation(iteration: Int, validationLoss: Float, validationTime: Double)
case save(iteration: Int, url: URL)
public var description: String {
switch self {
case .train(
let iteration, let trainingLoss, let iterationsPerSecond, let tokensPerSecond):
"Iteration \(iteration + 1): training loss \(trainingLoss.formatted()), "
+ "iterations/sec \(iterationsPerSecond.formatted()), "
+ "Tokens/sec \(tokensPerSecond.formatted())"
case .validation(let iteration, let validationLoss, let validationTime):
"Iteration \(iteration + 1): "
+ "validation loss \(validationLoss.formatted()), "
+ "validation time \(validationTime.formatted())s"
case .save(let iteration, let url):
"Iteration \(iteration + 1): saved weights to \(url.path())"
}
}
}
public enum ProgressDisposition {
case stop
case more
}
/// Train (or continue training) LoRA weights.
///
/// - Parameters:
/// - model: model to train
/// - train: training dataset
/// - validate: validate dataset
/// - optimizer: optimizer used in training
/// - loss: loss function
/// - tokenizer: tokenizer
/// - parameters: training parameters
/// - progress: progress callback
public static func train(
model: Module, train: [String], validate: [String], optimizer: Optimizer,
loss: @escaping LoraLossFunction = loss, tokenizer: Tokenizer, parameters: Parameters,
progress: (Progress) async -> ProgressDisposition
) async throws {
// def train(model, train_set, val_set, optimizer, loss, tokenizer, args)
let lossValueGrad = valueAndGrad(model: model) { model, arrays in
let (ce, ntoks) = loss(model, arrays[0], arrays[1], arrays[2])
return [ce, ntoks]
}
var losses = [Float]()
var tokenCount = 0
var start = Date.timeIntervalSinceReferenceDate
for (iteration, (inputs, targets, lengths)) in LoRABatchIterator(
dataset: train, tokenizer: tokenizer, batchSize: parameters.batchSize, train: true
).enumerated() {
// forward and backward pass
let (resultArray, grad) = lossValueGrad(model, [inputs, targets, lengths])
let lvalue = resultArray[0]
let tokens = resultArray[1]
// model update
optimizer.update(model: model, gradients: grad)
eval(model, optimizer, lvalue)
// record loss
losses.append(lvalue.item(Float.self))
tokenCount += tokens.item(Int.self)
// report training loss
if (iteration + 1) % parameters.stepsPerReport == 0 {
let trainingLoss = MLXArray(losses).mean(stream: .cpu).item(Float.self)
let now = Date.timeIntervalSinceReferenceDate
let iterationsPerSecond = Double(parameters.stepsPerReport) / (now - start)
let tokensPerSecond = Double(tokenCount) / (now - start)
if await progress(
.train(
iteration: iteration, trainingLoss: trainingLoss,
iterationsPerSecond: iterationsPerSecond, tokensPerSecond: tokensPerSecond))
== .stop
{
break
}
losses.removeAll()
tokenCount = 0
start = Date.timeIntervalSinceReferenceDate
}
// report validation loss
if iteration == 0 || (iteration + 1) % parameters.stepsPerEval == 0 {
let validationStart = Date.timeIntervalSinceReferenceDate
let validationLoss = evaluate(
model: model, dataset: validate, loss: loss, tokenizer: tokenizer,
batchSize: parameters.batchSize, batchCount: parameters.validationBatches)
let now = Date.timeIntervalSinceReferenceDate
if await progress(
.validation(
iteration: iteration, validationLoss: validationLoss,
validationTime: now - validationStart)) == .stop
{
break
}
start = Date.timeIntervalSinceReferenceDate
}
// save adapter weights if needed
if let adapterURL = parameters.adapterURL, (iteration + 1) % parameters.saveEvery == 0 {
try saveLoRAWeights(model: model, url: adapterURL)
if await progress(.save(iteration: iteration, url: adapterURL)) == .stop {
break
}
start = Date.timeIntervalSinceReferenceDate
}
if iteration + 1 >= parameters.iterations {
break
}
}
}
}

View File

@@ -1,6 +1,7 @@
// Copyright © 2024 Apple Inc.
import Foundation
import Hub
/// Registry of models and any overrides that go with them, e.g. prompt augmentation.
/// If asked for an unknown configuration this will use the model/tokenizer as-is.
@@ -9,7 +10,22 @@ import Foundation
/// swift-tokenizers code handles a good chunk of that and this is a place to augment that
/// implementation, if needed.
public struct ModelConfiguration {
public let id: String
public enum Identifier {
case id(String)
case directory(URL)
}
public var id: Identifier
public var name: String {
switch id {
case .id(let string):
string
case .directory(let url):
url.deletingLastPathComponent().lastPathComponent + "/" + url.lastPathComponent
}
}
/// pull the tokenizer from an alternate id
public let tokenizerId: String?
@@ -26,7 +42,17 @@ public struct ModelConfiguration {
id: String, tokenizerId: String? = nil, overrideTokenizer: String? = nil,
preparePrompt: ((String) -> String)? = nil
) {
self.id = id
self.id = .id(id)
self.tokenizerId = tokenizerId
self.overrideTokenizer = overrideTokenizer
self.preparePrompt = preparePrompt
}
public init(
directory: URL, tokenizerId: String? = nil, overrideTokenizer: String? = nil,
preparePrompt: ((String) -> String)? = nil
) {
self.id = .directory(directory)
self.tokenizerId = tokenizerId
self.overrideTokenizer = overrideTokenizer
self.preparePrompt = preparePrompt
@@ -36,13 +62,25 @@ public struct ModelConfiguration {
preparePrompt?(prompt) ?? prompt
}
public func modelDirectory(hub: HubApi = HubApi()) -> URL {
switch id {
case .id(let id):
// download the model weights and config
let repo = Hub.Repo(id: id)
return hub.localRepoLocation(repo)
case .directory(let directory):
return directory
}
}
public static var registry = [String: ModelConfiguration]()
public static func register(configurations: [ModelConfiguration]) {
bootstrap()
for c in configurations {
registry[c.id] = c
registry[c.name] = c
}
}

View File

@@ -240,3 +240,11 @@ public struct PhiConfiguration: Codable {
}
}
// MARK: - LoRA
extension PhiModel: LoRAModel {
public func loraLinearLayers() -> LoRALinearLayers {
model.layers.map { ($0.selfAttention, ["q_proj", "v_proj"]) }
}
}

View File

@@ -251,3 +251,11 @@ public struct Qwen2Configuration: Codable {
[String: StringOrNumber].self, forKey: Qwen2Configuration.CodingKeys.ropeScaling)
}
}
// MARK: - LoRA
extension Qwen2Model: LoRAModel {
public func loraLinearLayers() -> LoRALinearLayers {
model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
}
}

View File

@@ -1,4 +1,4 @@
# Llama
# LLM
This is a port of several models from:
@@ -6,7 +6,7 @@ This is a port of several models from:
using the Hugging Face swift transformers package to provide tokenization:
https://github.com/huggingface/swift-transformers
- https://github.com/huggingface/swift-transformers
[Models.swift](Models.swift) provides minor overrides and customization --
if you require overrides for the tokenizer or prompt customizations they can be
@@ -30,3 +30,12 @@ Currently supported model types are:
See [Configuration.swift](Configuration.swift) for more info.
See [llm-tool](../../Tools/llm-tool)
# LoRA
[Lora.swift](Lora.swift) contains an implementation of LoRA based on this example:
- https://github.com/ml-explore/mlx-examples/tree/main/lora
See [llm-tool/LoraCommands.swift](../../Tools/llm-tool/LoraCommands.swift) for an example of a driver and
[llm-tool](../../Tools/llm-tool) for examples of how to run it.

View File

@@ -254,3 +254,11 @@ public struct Starcoder2Configuration: Codable {
?? true
}
}
// MARK: - LoRA
extension Starcoder2Model: LoRAModel {
public func loraLinearLayers() -> LoRALinearLayers {
model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
}
}

View File

@@ -4,10 +4,20 @@ import Foundation
import Hub
import Tokenizers
public func loadTokenizer(configuration: ModelConfiguration) async throws -> Tokenizer {
public func loadTokenizer(configuration: ModelConfiguration, hub: HubApi) async throws -> Tokenizer
{
// from AutoTokenizer.from() -- this lets us override parts of the configuration
let config = LanguageModelConfigurationFromHub(
modelName: configuration.tokenizerId ?? configuration.id)
let config: LanguageModelConfigurationFromHub
switch configuration.id {
case .id(let id):
config = LanguageModelConfigurationFromHub(
modelName: configuration.tokenizerId ?? id, hubApi: hub)
case .directory(let directory):
config = LanguageModelConfigurationFromHub(modelFolder: directory, hubApi: hub)
}
guard var tokenizerConfig = try await config.tokenizerConfig else {
throw LLMError(message: "missing config")
}

View File

@@ -27,6 +27,7 @@ let package = Package(
.product(name: "MLX", package: "mlx-swift"),
.product(name: "MLXFast", package: "mlx-swift"),
.product(name: "MLXNN", package: "mlx-swift"),
.product(name: "MLXOptimizers", package: "mlx-swift"),
.product(name: "MLXRandom", package: "mlx-swift"),
.product(name: "Transformers", package: "swift-transformers"),
.product(name: "AsyncAlgorithms", package: "swift-async-algorithms"),
@@ -43,6 +44,7 @@ let package = Package(
.product(name: "MLX", package: "mlx-swift"),
.product(name: "MLXFast", package: "mlx-swift"),
.product(name: "MLXNN", package: "mlx-swift"),
.product(name: "MLXOptimizers", package: "mlx-swift"),
.product(name: "MLXRandom", package: "mlx-swift"),
.product(name: "Transformers", package: "swift-transformers"),
.product(name: "AsyncAlgorithms", package: "swift-async-algorithms"),

View File

@@ -0,0 +1,15 @@
// Copyright © 2024 Apple Inc.
import ArgumentParser
import Foundation
/// Extension to allow URL command line arguments.
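///
/// For example, `--adapter /tmp/adapters.safetensors` parses as a file URL while
/// `--adapter https://example.com/weights` (containing `://`) parses as a remote URL.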
extension URL: ExpressibleByArgument {
public init?(argument: String) {
if argument.contains("://") {
self.init(string: argument)
} else {
self.init(filePath: argument)
}
}
}

View File

@@ -11,18 +11,26 @@ import Tokenizers
struct LLMTool: AsyncParsableCommand {
static var configuration = CommandConfiguration(
abstract: "Command line tool for generating text and manipulating LLMs",
subcommands: [EvaluateCommand.self],
subcommands: [EvaluateCommand.self, LoRACommand.self],
defaultSubcommand: EvaluateCommand.self)
}
/// Command line arguments for loading a model.
struct ModelArguments: ParsableArguments {
@Option(name: .long, help: "Name of the huggingface model")
@Option(name: .long, help: "Name of the huggingface model or absolute path to directory")
var model: String = "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx"
func load() async throws -> (LLMModel, Tokenizer, ModelConfiguration) {
let modelConfiguration = ModelConfiguration.configuration(id: model)
let modelConfiguration: ModelConfiguration
if self.model.hasPrefix("/") {
// path
modelConfiguration = ModelConfiguration(directory: URL(filePath: self.model))
} else {
// identifier
modelConfiguration = ModelConfiguration.configuration(id: model)
}
let (model, tokenizer) = try await LLM.load(configuration: modelConfiguration)
return (model, tokenizer, modelConfiguration)
}
@@ -31,7 +39,11 @@ struct ModelArguments: ParsableArguments {
/// Command line arguments for controlling generation of text.
struct GenerateArguments: ParsableArguments {
@Option(name: .shortAndLong, help: "The message to be processed by the model")
@Option(
name: .shortAndLong,
help:
"The message to be processed by the model. Use @path,@path to load from files, e.g. @/tmp/prompt.txt"
)
var prompt = "compare python and swift"
@Option(name: .shortAndLong, help: "Maximum number of tokens to generate")
@@ -52,18 +64,32 @@ struct GenerateArguments: ParsableArguments {
@Option(name: .long, help: "The PRNG seed")
var seed: UInt64 = 0
@Flag(name: .shortAndLong, help: "If true only print the generated output")
var quiet = false
var generateParameters: GenerateParameters {
GenerateParameters(
temperature: temperature, topP: topP, repetitionPenalty: repetitionPenalty,
repetitionContextSize: repetitionContextSize)
}
func tokenizePrompt(configuration: ModelConfiguration, tokenizer: Tokenizer) -> (String, [Int])
{
func resolvePrompt() throws -> String {
if prompt.hasPrefix("@") {
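// e.g. "@/tmp/a.txt,@/tmp/b.txt" -> read each file (dropping the leading @)
// and join the contents with newlines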
let names = prompt.split(separator: ",").map { String($0.dropFirst()) }
return try names.map { try String(contentsOfFile: $0) }.joined(separator: "\n")
} else {
return prompt
}
}
func tokenizePrompt(configuration: ModelConfiguration, tokenizer: Tokenizer) throws -> (
String, [Int]
) {
MLXRandom.seed(seed)
let prompt = configuration.prepare(prompt: self.prompt)
let promptTokens = tokenizer.encode(text: prompt)
let prompt = try resolvePrompt()
let preparedPrompt = configuration.prepare(prompt: prompt)
let promptTokens = tokenizer.encode(text: preparedPrompt)
return (prompt, promptTokens)
}
@@ -187,21 +213,27 @@ struct EvaluateCommand: AsyncParsableCommand {
mutating func run() async throws {
let (model, tokenizer, modelConfiguration) = try await memory.start(args.load)
print("Model loaded -> \(modelConfiguration.id)")
if !generate.quiet {
print("Model loaded -> \(modelConfiguration.id)")
}
let (prompt, promptTokens) = generate.tokenizePrompt(
let (prompt, promptTokens) = try generate.tokenizePrompt(
configuration: modelConfiguration, tokenizer: tokenizer)
print("Starting generation ...")
print(prompt, terminator: "")
if !generate.quiet {
print("Starting generation ...")
print(prompt, terminator: "")
}
let result = await generate.generate(
promptTokens: promptTokens, model: model, tokenizer: tokenizer)
print()
print("------")
print(result.summary())
memory.reportMemoryStatistics()
if !generate.quiet {
print("------")
print(result.summary())
memory.reportMemoryStatistics()
}
}
}

View File

@@ -0,0 +1,281 @@
// Copyright © 2024 Apple Inc.
import ArgumentParser
import Foundation
import Hub
import LLM
import MLX
import MLXNN
import MLXOptimizers
import MLXRandom
import Tokenizers
struct LoRACommand: AsyncParsableCommand {
static var configuration = CommandConfiguration(
commandName: "lora",
abstract: "LoRA commands",
subcommands: [
LoRATrainCommand.self, LoRAFuseCommand.self, LoRATestCommand.self, LoRAEvalCommand.self,
]
)
}
/// Common arguments for loading a LoRA model with adapter weights
struct LoRAModelArguments: ParsableArguments {
@OptionGroup var args: ModelArguments
@Option(name: .long, help: "Save/load path for the trained adapter weights")
public var adapter: URL = URL(filePath: "adapters.safetensors")
@Option(name: .long, help: "Number of layers to fine-tune")
public var loraLayers = 16
/// Load the model and apply the LoRA adapters.
///
/// This does not load the adapter weights as they may not exist yet.
func load() async throws -> (LLMModel, Tokenizer, ModelConfiguration) {
let (model, tokenizer, modelConfiguration) = try await args.load()
// convert some of the Linear layers to LoRALinear
LoRATrain.convert(model: model, layers: loraLayers(model: model))
return (model, tokenizer, modelConfiguration)
}
func loraLayers(model: Module) -> LoRALinearLayers {
guard let layerProvider = model as? LoRAModel else {
// the layerProvider will indicate which Linear layers need to be replaced
fatalError(
"Model \(type(of: model)) (\(args.model)) must implement the LoRALayerProvider protocol"
)
}
return Array(layerProvider.loraLinearLayers().suffix(loraLayers))
}
func describe(model: Module) {
let totalParameterCount = model.parameters()
.flattenedValues().map { $0.size }.reduce(0, +)
let trainableParameterCount = model.trainableParameters()
.flattenedValues().map { $0.size }.reduce(0, +)
print("Model: \(args.model)")
print("Total parameters: \((totalParameterCount / 1_000_000).formatted())M")
print(
"Trainable parameters: \((Float(trainableParameterCount) / 1_000_000).formatted(.number.precision(.significantDigits(1 ..< 4))))M"
)
}
}
struct LoRATrainCommand: AsyncParsableCommand {
static var configuration = CommandConfiguration(
commandName: "train",
abstract: "LoRA training"
)
@OptionGroup var args: LoRAModelArguments
@OptionGroup var memory: MemoryArguments
@Flag(help: "Resume training with the given adapter file")
public var resume = false
@Option(name: .long, help: "Directory with {train, valid, test}.{jsonl,txt} files")
public var data: URL = URL(filePath: "data")
@Option(name: .long, help: "Learning rate for the optimizer")
public var learningRate: Float = 1e-5
@Option(name: .long, help: "Number of dataset items to evaluate per iteration (batch)")
public var batchSize = 4
@Option(name: .long, help: "Number iterations to train for")
public var iterations = 1000
@Option(name: .long, help: "Number of iterations between loss reporting")
public var stepsPerReport = 10
@Option(name: .long, help: "Number of iterations between validations")
public var stepsPerEval = 100
@Option(name: .long, help: "Number of validation batches, 0 uses the entire set")
public var validationBatches = 10
@Option(name: .long, help: "Number of iterations between checkpointing the adapter weights")
public var saveEvery = 100
var parameters: LoRATrain.Parameters {
var p = LoRATrain.Parameters()
p.batchSize = self.batchSize
p.iterations = self.iterations
p.stepsPerReport = self.stepsPerReport
p.stepsPerEval = self.stepsPerEval
p.validationBatches = self.validationBatches
p.saveEvery = self.saveEvery
p.adapterURL = args.adapter
return p
}
@MainActor
mutating func run() async throws {
let (model, tokenizer, _) = try await args.load()
args.describe(model: model)
memory.start()
if resume {
print("Loading pretrained adapters from \(args.adapter.path())")
try LoRATrain.loadLoRAWeights(model: model, url: args.adapter)
}
// load the train/validation data
let train = try loadLoRAData(directory: data, name: "train")
let valid = try loadLoRAData(directory: data, name: "valid")
if train.isEmpty {
fatalError("Training set is empty: \(data.path()))")
}
if valid.isEmpty {
fatalError("Validation set is empty: \(data.path()))")
}
// train
let optimizer = Adam(learningRate: learningRate)
try await LoRATrain.train(
model: model, train: train, validate: valid, optimizer: optimizer, tokenizer: tokenizer,
parameters: parameters
) { progress in
print(progress)
return .more
}
try LoRATrain.saveLoRAWeights(model: model, url: args.adapter)
}
}
struct LoRAFuseCommand: AsyncParsableCommand {
static var configuration = CommandConfiguration(
commandName: "fuse",
abstract: "Fuse lora adapter weights back in to original model"
)
@OptionGroup var args: LoRAModelArguments
@Flag(name: .long, help: "De-quantize QuantizedLinear layers back into Linear")
var deQuantize = false
@Option(name: .long, help: "Hub ID (mlx-community/mistral-lora) or path (/tmp/mistral-lora)")
var output: String
@MainActor
mutating func run() async throws {
let outputURL: URL
if output.hasPrefix("/") {
outputURL = URL(filePath: output)
} else {
let repo = Hub.Repo(id: output)
outputURL = HubApi().localRepoLocation(repo)
}
let (model, _, modelConfiguration) = try await args.load()
// load the prepared weights
try LoRATrain.loadLoRAWeights(model: model, url: args.adapter)
// fuse them back into Linear/QuantizedLinear
LoRATrain.fuse(model: model, layers: args.loraLayers(model: model), deQuantize: deQuantize)
// make the new directory and copy files from source model
try FileManager.default.createDirectory(at: outputURL, withIntermediateDirectories: true)
let inputURL = modelConfiguration.modelDirectory()
let enumerator = FileManager.default.enumerator(
at: inputURL, includingPropertiesForKeys: nil)!
for case let url as URL in enumerator {
// copy everything except the model weights -- we will write out the fused one below
if url.pathExtension == "safetensors" {
continue
}
try FileManager.default.copyItem(
at: url, to: outputURL.appending(component: url.lastPathComponent))
}
// write them back out
let weights = Dictionary(uniqueKeysWithValues: model.parameters().flattened())
try save(arrays: weights, url: outputURL.appending(component: "weights.safetensors"))
print("Fused weights written to \(outputURL.path())")
print("Use with:\n\tllm-tool eval --model \(output)")
}
}
struct LoRATestCommand: AsyncParsableCommand {
static var configuration = CommandConfiguration(
commandName: "test",
abstract: "LoRA testing"
)
@OptionGroup var args: LoRAModelArguments
@OptionGroup var memory: MemoryArguments
@Option(name: .long, help: "Directory with {train, valid, test}.{jsonl,txt} files")
public var data: URL = URL(filePath: "data")
@Option(name: .long, help: "Minibatch size")
public var batchSize = 4
@MainActor
mutating func run() async throws {
let (model, tokenizer, _) = try await args.load()
args.describe(model: model)
try LoRATrain.loadLoRAWeights(model: model, url: args.adapter)
memory.start()
let test = try loadLoRAData(directory: data, name: "test")
let loss = LoRATrain.evaluate(
model: model, dataset: test, tokenizer: tokenizer, batchSize: batchSize, batchCount: 0)
print("Test loss \(loss.formatted()), ppl \(exp(loss).formatted())")
}
}
struct LoRAEvalCommand: AsyncParsableCommand {
static var configuration = CommandConfiguration(
commandName: "eval",
abstract: "LoRA evaluation"
)
@OptionGroup var args: LoRAModelArguments
@OptionGroup var memory: MemoryArguments
@OptionGroup var generate: GenerateArguments
@MainActor
mutating func run() async throws {
let (model, tokenizer, modelConfiguration) = try await args.load()
args.describe(model: model)
try LoRATrain.loadLoRAWeights(model: model, url: args.adapter)
memory.start()
let (prompt, promptTokens) = try generate.tokenizePrompt(
configuration: modelConfiguration, tokenizer: tokenizer)
if !generate.quiet {
print("Starting generation ...")
print(prompt, terminator: "")
}
// generate and print the result
let _ = await generate.generate(
promptTokens: promptTokens, model: model, tokenizer: tokenizer)
print()
}
}

View File

@@ -8,7 +8,7 @@ See various READMEs:
Build the `llm-tool` scheme in Xcode.
### Running (Xcode)
### Running: Xcode
To run this in Xcode simply press cmd-opt-r to set the scheme arguments. For example:
@@ -30,7 +30,7 @@ The model should be a path in the Hugging Face repository, e.g.:
See [LLM](../../Libraries/LLM/README.md) for more info.
### Running (Command Line)
### Running: Command Line
Use the `mlx-run` script to run the command line tools:
@@ -60,3 +60,184 @@ Building in Release / optimizations will remove a lot of tail calls in the C++
layer. These lead to the stack overflows.
See discussion here: https://github.com/ml-explore/mlx-swift-examples/issues/3
## LoRA
`llm-tool` provides an example LoRA driver based on:
- https://github.com/ml-explore/mlx-examples/blob/main/lora/README.md
This is an example of using MLX to fine-tune an LLM with low rank adaptation
(LoRA) for a target task.[^lora] The example also supports quantized LoRA
(QLoRA).[^qlora] The example works with Llama and Mistral style models
available on Hugging Face.
In this example we'll use the WikiSQL[^wikisql] dataset to train the LLM to
generate SQL queries from natural language. However, the example is intended to
be general, should you wish to use a custom dataset.
> Note: Some of the prompts contain newlines, which are difficult to enter when running in Xcode.
Running `llm-tool lora` will produce help:
```
SUBCOMMANDS:
train LoRA training
fuse Fuse LoRA adapter weights back into the original model
test LoRA testing
eval LoRA evaluation
```
### Training
The first step will be training the LoRA adapter. Example training data
is available in $SRCROOT/Data/lora. You can use your
own data in either `jsonl` or `txt` format with one entry per line.
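Each line of a `jsonl` file is a JSON object whose `text` field holds one complete
training example; for the WikiSQL data that is the table description, the question
and the target SQL, e.g.:

```
{"text": "table: 1-10812293-3\ncolumns: Game, Date, Team, Score, High points, High rebounds, High assists, Location Attendance, Record\nQ: What was the score of the game on November 12?\nA: SELECT Score FROM 1-10812293-3 WHERE Date = 'November 12'"}
```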
We need to specify a number of parameters:

- `--model` -- which model to use. This can be quantized[^qlora] or not[^lora].
- `--data` -- directory with the `test`, `train` and `valid` files. These can be either `jsonl` or `txt` files.
- `--adapter` -- path to a safetensors file to write the fine-tuned parameters into.

Additionally, the performance of the fine-tuning can be controlled with:

- `--batch-size` -- size of the minibatches used in the training loop, i.e. how many prompts are processed per iteration.
- `--lora-layers` -- the number of attention layers (at the end of the model) to adapt and train.
- `--iterations` -- the number of iterations to train for.
If desired, the amount of memory used can be adjusted with:

- `--cache-size` -- limits the size of the MLX cache, in megabytes; the value shown below limits the cache to 1024M (see the sketch after this list)
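Under the hood this just caps the MLX buffer cache before anything runs; a minimal sketch, assuming the MLX Swift `GPU.set(cacheLimit:)` API:

```swift
import MLX

// Cap the MLX buffer cache. The flag takes megabytes on the command
// line, so convert to bytes -- a sketch of what --cache-size 1024 does.
let cacheSizeMB = 1024
MLX.GPU.set(cacheLimit: cacheSizeMB * 1024 * 1024)
```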
Here is an example run using adapters on the last 4 layers of the model:
```
./mlx-run llm-tool lora train \
--model mlx-community/Mistral-7B-v0.1-hf-4bit-mlx \
--data Data/lora \
--adapter /tmp/lora-layers-4.safetensors \
--batch-size 1 --lora-layers 4 \
--cache-size 1024
```
giving output like this:
```
Model: mlx-community/Mistral-7B-v0.1-hf-4bit-mlx
Total parameters: 1,242M
Trainable parameters: 0.426M
Iteration 1: validation loss 2.443872, validation time 3.330629s
Iteration 10: training loss 2.356848, iterations/sec 2.640604, Tokens/sec 260.363581
Iteration 20: training loss 2.063395, iterations/sec 2.294999, Tokens/sec 232.483365
Iteration 30: training loss 1.63846, iterations/sec 2.279401, Tokens/sec 225.204788
Iteration 40: training loss 1.66366, iterations/sec 2.493669, Tokens/sec 218.196057
Iteration 50: training loss 1.470927, iterations/sec 2.301153, Tokens/sec 231.72614
Iteration 60: training loss 1.396581, iterations/sec 2.400012, Tokens/sec 230.401195
Iteration 70: training loss 1.587023, iterations/sec 2.422193, Tokens/sec 218.966258
Iteration 80: training loss 1.376895, iterations/sec 2.111973, Tokens/sec 216.477187
Iteration 90: training loss 1.245127, iterations/sec 2.383802, Tokens/sec 214.065436
Iteration 100: training loss 1.344523, iterations/sec 2.424746, Tokens/sec 223.076649
Iteration 100: validation loss 1.400582, validation time 3.489797s
Iteration 100: saved weights to /tmp/lora.safetensors
...
Iteration 910: training loss 1.181306, iterations/sec 2.355085, Tokens/sec 212.428628
Iteration 920: training loss 1.042286, iterations/sec 2.374377, Tokens/sec 222.479127
Iteration 930: training loss 0.920768, iterations/sec 2.475088, Tokens/sec 220.035347
Iteration 940: training loss 1.140762, iterations/sec 2.119886, Tokens/sec 227.039828
Iteration 950: training loss 1.068073, iterations/sec 2.523047, Tokens/sec 218.495903
Iteration 960: training loss 1.106662, iterations/sec 2.339293, Tokens/sec 221.063186
Iteration 970: training loss 0.833658, iterations/sec 2.474683, Tokens/sec 213.56517
Iteration 980: training loss 0.844026, iterations/sec 2.441064, Tokens/sec 210.663791
Iteration 990: training loss 0.903735, iterations/sec 2.253876, Tokens/sec 218.175162
Iteration 1000: training loss 0.872615, iterations/sec 2.343899, Tokens/sec 219.62336
Iteration 1000: validation loss 0.714194, validation time 3.470462s
Iteration 1000: saved weights to /tmp/lora-layers-4.safetensors
```
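The tiny trainable-parameter count (0.426M out of 1,242M) is the point of LoRA: each adapted projection keeps its base weight frozen and trains only two small low-rank matrices. A minimal sketch of the forward pass, using a hypothetical helper rather than the library's actual implementation:

```swift
import MLX

// LoRA forward pass: y = x Wᵀ + scale * (x A) B, where W ([out, in]) is
// the frozen base weight and only A ([in, rank]) and B ([rank, out])
// are trained.
func loraForward(
    _ x: MLXArray, weight: MLXArray, loraA: MLXArray, loraB: MLXArray, scale: Float
) -> MLXArray {
    matmul(x, weight.T) + matmul(matmul(x, loraA), loraB) * scale
}
```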
### Testing
You can test the LoRA adapted model against the `test` dataset using this command:
```
./mlx-run llm-tool lora test \
--model mlx-community/Mistral-7B-v0.1-hf-4bit-mlx \
--data Data/lora \
--adapter /tmp/lora-layers-4.safetensors \
--batch-size 1 --lora-layers 4 \
--cache-size 1024
```
This will run all of the items in the test set (100 in the example data we are using) and compute the loss:
```
Model: mlx-community/Mistral-7B-v0.1-hf-4bit-mlx
Total parameters: 1,242M
Trainable parameters: 0.426M
Test loss 1.327623, ppl 3.772065
```
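The reported perplexity is simply the exponential of the test loss (the mean per-token cross-entropy), as the `exp(loss)` in the test command's output code shows:

```swift
import Foundation

// Perplexity is e raised to the mean per-token cross-entropy loss.
let testLoss = 1.327623
let perplexity = exp(testLoss)  // ≈ 3.772, matching the output above
```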
### Evaluate
Next you can evaluate your own prompts with the fine-tuned LoRA adapters. It is important to
match the prompt format used in the training data:
```
{"text": "table: 1-10015132-1\ncolumns: Player, No., Nationality, Position, Years in Toronto, School/Club Team\nQ: What school did player number 6 come from?\nA: SELECT School/Club Team FROM 1-10015132-1 WHERE No. = '6'"}
```
Given that format, you might issue a command like this:
```
./mlx-run llm-tool lora eval \
--model mlx-community/Mistral-7B-v0.1-hf-4bit-mlx \
--adapter /tmp/lora-layers-4.safetensors \
--lora-layers 4 \
--prompt "table: 1-10015132-16
columns: Player, No., Nationality, Position, Years in Toronto, School/Club Team
Q: What is terrence ross' nationality
A: "
```
> Note: the prompt has newlines in it to match the format of the training prompts -- this may be easier to do from the command line than in Xcode.
You might be treated to a response like this:
```
Model: mlx-community/Mistral-7B-v0.1-hf-4bit-mlx
Total parameters: 1,242M
Trainable parameters: 0.426M
Starting generation ...
table: 1-10015132-16
columns: Player, No., Nationality, Position, Years in Toronto, School/Club Team
Q: What is terrence ross' nationality
A: SELECT Nationality FROM 1-10015132-16 WHERE Player = 'Terrence Ross' AND No. = 1
```
### Fusing
Once the adapter weights are trained, you can produce new weights with the original architecture that
have the adapter weights merged in:
```
./mlx-run llm-tool lora fuse \
--model mlx-community/Mistral-7B-v0.1-hf-4bit-mlx \
--adapter /tmp/lora-layers-4.safetensors \
--output mlx-community/mistral-lora
```
outputs:
```
Total parameters: 1,244M
Trainable parameters: 0.426M
Use with:
llm-tool eval --model mlx-community/mistral-lora
```
As noted in the output, these new weights can be used with the original model architecture.
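Fusing amounts to folding the low-rank update back into each dense weight so no adapter is needed at inference time. A sketch using a hypothetical helper, consistent with the `loraForward` sketch above:

```swift
import MLX

// Fused weight: W' = W + scale * BᵀAᵀ, so that x W'ᵀ equals
// x Wᵀ + scale * (x A) B -- the adapter disappears into the weight.
func fuseWeight(weight: MLXArray, loraA: MLXArray, loraB: MLXArray, scale: Float) -> MLXArray {
    weight + matmul(loraB.T, loraA.T) * scale
}
```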
[^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
[^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314) for more details on QLoRA.
[^wikisql]: Refer to the [GitHub repo](https://github.com/salesforce/WikiSQL/tree/master) for more information about WikiSQL.

View File

@@ -12,6 +12,15 @@
52A776182B94B5EE00AA6E80 /* Qwen2.swift in Sources */ = {isa = PBXBuildFile; fileRef = 52A776172B94B5EE00AA6E80 /* Qwen2.swift */; };
81695B412BA373D300F260D8 /* MarkdownUI in Frameworks */ = {isa = PBXBuildFile; productRef = 81695B402BA373D300F260D8 /* MarkdownUI */; };
819BEFF82BAF8B4E0002CCEE /* DeviceStat.swift in Sources */ = {isa = PBXBuildFile; fileRef = 819BEFF62BAF8B4E0002CCEE /* DeviceStat.swift */; };
C3056BAE2BCD97B700A31D04 /* LoRATrainingExampleApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3056BAD2BCD97B700A31D04 /* LoRATrainingExampleApp.swift */; };
C3056BB02BCD97B700A31D04 /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3056BAF2BCD97B700A31D04 /* ContentView.swift */; };
C3056BB22BCD97B800A31D04 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C3056BB12BCD97B800A31D04 /* Assets.xcassets */; };
C3056BB62BCD97B800A31D04 /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C3056BB52BCD97B800A31D04 /* Preview Assets.xcassets */; };
C3056BBA2BCD981900A31D04 /* train.jsonl in Resources */ = {isa = PBXBuildFile; fileRef = C3056BA22BCD973400A31D04 /* train.jsonl */; };
C3056BBB2BCD981900A31D04 /* test.jsonl in Resources */ = {isa = PBXBuildFile; fileRef = C3056BA12BCD973400A31D04 /* test.jsonl */; };
C3056BBC2BCD981900A31D04 /* valid.jsonl in Resources */ = {isa = PBXBuildFile; fileRef = C3056BA32BCD973400A31D04 /* valid.jsonl */; };
C3056BBD2BCD984F00A31D04 /* LLM.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C38935C52B869C7A0037B833 /* LLM.framework */; };
C3056BBE2BCD984F00A31D04 /* LLM.framework in Embed Frameworks */ = {isa = PBXBuildFile; fileRef = C38935C52B869C7A0037B833 /* LLM.framework */; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; };
C3288D762B6D9313009FF608 /* LinearModelTraining.swift in Sources */ = {isa = PBXBuildFile; fileRef = C3288D752B6D9313009FF608 /* LinearModelTraining.swift */; };
C3288D7B2B6D9339009FF608 /* ArgumentParser in Frameworks */ = {isa = PBXBuildFile; productRef = C3288D7A2B6D9339009FF608 /* ArgumentParser */; };
C34E48F52B696F0B00FCB841 /* LLMTool.swift in Sources */ = {isa = PBXBuildFile; fileRef = C34E48F42B696F0B00FCB841 /* LLMTool.swift */; };
@@ -22,6 +31,11 @@
C34E49292B6A028100FCB841 /* ArgumentParser in Frameworks */ = {isa = PBXBuildFile; productRef = C34E49282B6A028100FCB841 /* ArgumentParser */; };
C34E492A2B6A028800FCB841 /* MNIST.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C34E490D2B69A92900FCB841 /* MNIST.framework */; };
C34E492B2B6A028800FCB841 /* MNIST.framework in Embed Frameworks */ = {isa = PBXBuildFile; fileRef = C34E490D2B69A92900FCB841 /* MNIST.framework */; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; };
C36BEFB02BBCBAC2002D4AFE /* Lora.swift in Sources */ = {isa = PBXBuildFile; fileRef = C36BEFAF2BBCBAC2002D4AFE /* Lora.swift */; };
C36BEFB22BBDE9D0002D4AFE /* MLXOptimizers in Frameworks */ = {isa = PBXBuildFile; productRef = C36BEFB12BBDE9D0002D4AFE /* MLXOptimizers */; };
C36BEFB52BBDEAD8002D4AFE /* LoraCommands.swift in Sources */ = {isa = PBXBuildFile; fileRef = C36BEFB32BBDEA69002D4AFE /* LoraCommands.swift */; };
C36BEFB82BBDED51002D4AFE /* Arguments.swift in Sources */ = {isa = PBXBuildFile; fileRef = C36BEFB62BBDECBC002D4AFE /* Arguments.swift */; };
C36BEFBB2BBF02CC002D4AFE /* Lora+Data.swift in Sources */ = {isa = PBXBuildFile; fileRef = C36BEFBA2BBF02CC002D4AFE /* Lora+Data.swift */; };
C382DE8A2B630889000F8F03 /* AsyncAlgorithms in Frameworks */ = {isa = PBXBuildFile; productRef = C382DE892B630889000F8F03 /* AsyncAlgorithms */; };
C38935C82B869C7A0037B833 /* LLM.h in Headers */ = {isa = PBXBuildFile; fileRef = C38935C72B869C7A0037B833 /* LLM.h */; settings = {ATTRIBUTES = (Public, ); }; };
C38935CC2B869C870037B833 /* Llama.swift in Sources */ = {isa = PBXBuildFile; fileRef = C34E48EE2B696E6500FCB841 /* Llama.swift */; };
@@ -69,6 +83,13 @@
/* End PBXBuildFile section */
/* Begin PBXContainerItemProxy section */
C3056BBF2BCD984F00A31D04 /* PBXContainerItemProxy */ = {
isa = PBXContainerItemProxy;
containerPortal = C39273682B60697700368D5D /* Project object */;
proxyType = 1;
remoteGlobalIDString = C38935C42B869C7A0037B833;
remoteInfo = LLM;
};
C34E492C2B6A028800FCB841 /* PBXContainerItemProxy */ = {
isa = PBXContainerItemProxy;
containerPortal = C39273682B60697700368D5D /* Project object */;
@@ -100,6 +121,17 @@
/* End PBXContainerItemProxy section */
/* Begin PBXCopyFilesBuildPhase section */
C3056BC12BCD984F00A31D04 /* Embed Frameworks */ = {
isa = PBXCopyFilesBuildPhase;
buildActionMask = 2147483647;
dstPath = "";
dstSubfolderSpec = 10;
files = (
C3056BBE2BCD984F00A31D04 /* LLM.framework in Embed Frameworks */,
);
name = "Embed Frameworks";
runOnlyForDeploymentPostprocessing = 0;
};
C3288D712B6D9313009FF608 /* CopyFiles */ = {
isa = PBXCopyFilesBuildPhase;
buildActionMask = 2147483647;
@@ -187,6 +219,17 @@
525C1E9C2B9A010F00B5C356 /* Starcoder2.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Starcoder2.swift; sourceTree = "<group>"; };
52A776172B94B5EE00AA6E80 /* Qwen2.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Qwen2.swift; sourceTree = "<group>"; };
819BEFF62BAF8B4E0002CCEE /* DeviceStat.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = DeviceStat.swift; sourceTree = "<group>"; };
C3056BA12BCD973400A31D04 /* test.jsonl */ = {isa = PBXFileReference; lastKnownFileType = text; path = test.jsonl; sourceTree = "<group>"; };
C3056BA22BCD973400A31D04 /* train.jsonl */ = {isa = PBXFileReference; lastKnownFileType = text; path = train.jsonl; sourceTree = "<group>"; };
C3056BA32BCD973400A31D04 /* valid.jsonl */ = {isa = PBXFileReference; lastKnownFileType = text; path = valid.jsonl; sourceTree = "<group>"; };
C3056BA42BCD973400A31D04 /* wikisql.py */ = {isa = PBXFileReference; lastKnownFileType = text.script.python; path = wikisql.py; sourceTree = "<group>"; };
C3056BAB2BCD97B700A31D04 /* LoRATrainingExample.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = LoRATrainingExample.app; sourceTree = BUILT_PRODUCTS_DIR; };
C3056BAD2BCD97B700A31D04 /* LoRATrainingExampleApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LoRATrainingExampleApp.swift; sourceTree = "<group>"; };
C3056BAF2BCD97B700A31D04 /* ContentView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ContentView.swift; sourceTree = "<group>"; };
C3056BB12BCD97B800A31D04 /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = "<group>"; };
C3056BB32BCD97B800A31D04 /* LoRATrainingExample.entitlements */ = {isa = PBXFileReference; lastKnownFileType = text.plist.entitlements; path = LoRATrainingExample.entitlements; sourceTree = "<group>"; };
C3056BB52BCD97B800A31D04 /* Preview Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = "Preview Assets.xcassets"; sourceTree = "<group>"; };
C3056BC42BCDAB8600A31D04 /* README.md */ = {isa = PBXFileReference; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = "<group>"; };
C325DE3F2B648CDB00628871 /* README.md */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = "<group>"; };
C3288D732B6D9313009FF608 /* LinearModelTraining */ = {isa = PBXFileReference; explicitFileType = "compiled.mach-o.executable"; includeInIndex = 0; path = LinearModelTraining; sourceTree = BUILT_PRODUCTS_DIR; };
C3288D752B6D9313009FF608 /* LinearModelTraining.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LinearModelTraining.swift; sourceTree = "<group>"; };
@@ -202,6 +245,10 @@
C34E49142B69C1E300FCB841 /* Files.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Files.swift; sourceTree = "<group>"; };
C34E49212B6A026F00FCB841 /* mnist-tool */ = {isa = PBXFileReference; explicitFileType = "compiled.mach-o.executable"; includeInIndex = 0; path = "mnist-tool"; sourceTree = BUILT_PRODUCTS_DIR; };
C34E49232B6A026F00FCB841 /* MNISTTool.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MNISTTool.swift; sourceTree = "<group>"; };
C36BEFAF2BBCBAC2002D4AFE /* Lora.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Lora.swift; sourceTree = "<group>"; };
C36BEFB32BBDEA69002D4AFE /* LoraCommands.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LoraCommands.swift; sourceTree = "<group>"; };
C36BEFB62BBDECBC002D4AFE /* Arguments.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Arguments.swift; sourceTree = "<group>"; };
C36BEFBA2BBF02CC002D4AFE /* Lora+Data.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = "Lora+Data.swift"; sourceTree = "<group>"; };
C38935C52B869C7A0037B833 /* LLM.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = LLM.framework; sourceTree = BUILT_PRODUCTS_DIR; };
C38935C72B869C7A0037B833 /* LLM.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = LLM.h; sourceTree = "<group>"; };
C38935DE2B869DD00037B833 /* Phi.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Phi.swift; sourceTree = "<group>"; };
@@ -237,6 +284,14 @@
/* End PBXFileReference section */
/* Begin PBXFrameworksBuildPhase section */
C3056BA82BCD97B700A31D04 /* Frameworks */ = {
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
C3056BBD2BCD984F00A31D04 /* LLM.framework in Frameworks */,
);
runOnlyForDeploymentPostprocessing = 0;
};
C3288D702B6D9313009FF608 /* Frameworks */ = {
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
@@ -273,6 +328,7 @@
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
C36BEFB22BBDE9D0002D4AFE /* MLXOptimizers in Frameworks */,
C38935D22B869CC40037B833 /* MLXNN in Frameworks */,
C38935D42B869CC40037B833 /* MLXRandom in Frameworks */,
C38935D62B869CC40037B833 /* Transformers in Frameworks */,
@@ -328,6 +384,46 @@
path = ViewModels;
sourceTree = "<group>";
};
C3056BA52BCD973400A31D04 /* lora */ = {
isa = PBXGroup;
children = (
C3056BA12BCD973400A31D04 /* test.jsonl */,
C3056BA22BCD973400A31D04 /* train.jsonl */,
C3056BA32BCD973400A31D04 /* valid.jsonl */,
C3056BA42BCD973400A31D04 /* wikisql.py */,
);
path = lora;
sourceTree = "<group>";
};
C3056BA62BCD973400A31D04 /* Data */ = {
isa = PBXGroup;
children = (
C3056BA52BCD973400A31D04 /* lora */,
);
path = Data;
sourceTree = "<group>";
};
C3056BAC2BCD97B700A31D04 /* LoRATrainingExample */ = {
isa = PBXGroup;
children = (
C3056BAD2BCD97B700A31D04 /* LoRATrainingExampleApp.swift */,
C3056BAF2BCD97B700A31D04 /* ContentView.swift */,
C3056BB12BCD97B800A31D04 /* Assets.xcassets */,
C3056BB32BCD97B800A31D04 /* LoRATrainingExample.entitlements */,
C3056BB42BCD97B800A31D04 /* Preview Content */,
C3056BC42BCDAB8600A31D04 /* README.md */,
);
path = LoRATrainingExample;
sourceTree = "<group>";
};
C3056BB42BCD97B800A31D04 /* Preview Content */ = {
isa = PBXGroup;
children = (
C3056BB52BCD97B800A31D04 /* Preview Assets.xcassets */,
);
path = "Preview Content";
sourceTree = "<group>";
};
C3288D742B6D9313009FF608 /* LinearModelTraining */ = {
isa = PBXGroup;
children = (
@@ -340,8 +436,10 @@
C34E48F32B696F0B00FCB841 /* llm-tool */ = {
isa = PBXGroup;
children = (
C34E48F42B696F0B00FCB841 /* LLMTool.swift */,
C34E48F92B69930300FCB841 /* README.md */,
C34E48F42B696F0B00FCB841 /* LLMTool.swift */,
C36BEFB32BBDEA69002D4AFE /* LoraCommands.swift */,
C36BEFB62BBDECBC002D4AFE /* Arguments.swift */,
);
path = "llm-tool";
sourceTree = "<group>";
@@ -370,6 +468,8 @@
C38935C62B869C7A0037B833 /* LLM */ = {
isa = PBXGroup;
children = (
C36BEFAF2BBCBAC2002D4AFE /* Lora.swift */,
C36BEFBA2BBF02CC002D4AFE /* Lora+Data.swift */,
525C1E9C2B9A010F00B5C356 /* Starcoder2.swift */,
C34E48EF2B696E6500FCB841 /* Configuration.swift */,
C3A8B3AB2B9283150002EFB8 /* Models.swift */,
@@ -393,6 +493,7 @@
children = (
C325DE3F2B648CDB00628871 /* README.md */,
F8D7023A2BB4E223003D7CF5 /* Package.swift */,
C3056BA62BCD973400A31D04 /* Data */,
C39273822B606A9200368D5D /* Libraries */,
C3A8B3AD2B9294E30002EFB8 /* Applications */,
C39273812B606A7400368D5D /* Tools */,
@@ -412,6 +513,7 @@
C38935C52B869C7A0037B833 /* LLM.framework */,
C3A8B3B22B9295090002EFB8 /* MNISTTrainer.app */,
C3A8B3DC2B92A29E0002EFB8 /* LLMEval.app */,
C3056BAB2BCD97B700A31D04 /* LoRATrainingExample.app */,
);
name = Products;
sourceTree = "<group>";
@@ -454,6 +556,7 @@
C3A8B3AD2B9294E30002EFB8 /* Applications */ = {
isa = PBXGroup;
children = (
C3056BAC2BCD97B700A31D04 /* LoRATrainingExample */,
C3A8B3EB2B92A2A90002EFB8 /* LLMEval */,
C3A8B3C12B92951E0002EFB8 /* MNISTTrainer */,
);
@@ -527,6 +630,27 @@
/* End PBXHeadersBuildPhase section */
/* Begin PBXNativeTarget section */
C3056BAA2BCD97B700A31D04 /* LoRATrainingExample */ = {
isa = PBXNativeTarget;
buildConfigurationList = C3056BB72BCD97B800A31D04 /* Build configuration list for PBXNativeTarget "LoRATrainingExample" */;
buildPhases = (
C3056BA72BCD97B700A31D04 /* Sources */,
C3056BA82BCD97B700A31D04 /* Frameworks */,
C3056BA92BCD97B700A31D04 /* Resources */,
C3056BC12BCD984F00A31D04 /* Embed Frameworks */,
);
buildRules = (
);
dependencies = (
C3056BC02BCD984F00A31D04 /* PBXTargetDependency */,
);
name = LoRATrainingExample;
packageProductDependencies = (
);
productName = LoRATrainingExample;
productReference = C3056BAB2BCD97B700A31D04 /* LoRATrainingExample.app */;
productType = "com.apple.product-type.application";
};
C3288D722B6D9313009FF608 /* LinearModelTraining */ = {
isa = PBXNativeTarget;
buildConfigurationList = C3288D792B6D9313009FF608 /* Build configuration list for PBXNativeTarget "LinearModelTraining" */;
@@ -617,6 +741,7 @@
C38935D32B869CC40037B833 /* MLXRandom */,
C38935D52B869CC40037B833 /* Transformers */,
C38935DC2B869CEC0037B833 /* AsyncAlgorithms */,
C36BEFB12BBDE9D0002D4AFE /* MLXOptimizers */,
);
productName = LLM;
productReference = C38935C52B869C7A0037B833 /* LLM.framework */;
@@ -716,9 +841,12 @@
isa = PBXProject;
attributes = {
BuildIndependentTargetsInParallel = 1;
LastSwiftUpdateCheck = 1520;
LastSwiftUpdateCheck = 1530;
LastUpgradeCheck = 1500;
TargetAttributes = {
C3056BAA2BCD97B700A31D04 = {
CreatedOnToolsVersion = 15.3;
};
C3288D722B6D9313009FF608 = {
CreatedOnToolsVersion = 15.0.1;
};
@@ -775,11 +903,24 @@
C3288D722B6D9313009FF608 /* LinearModelTraining */,
C3A8B3B12B9295090002EFB8 /* MNISTTrainer */,
C3A8B3DB2B92A29D0002EFB8 /* LLMEval */,
C3056BAA2BCD97B700A31D04 /* LoRATrainingExample */,
);
};
/* End PBXProject section */
/* Begin PBXResourcesBuildPhase section */
C3056BA92BCD97B700A31D04 /* Resources */ = {
isa = PBXResourcesBuildPhase;
buildActionMask = 2147483647;
files = (
C3056BBA2BCD981900A31D04 /* train.jsonl in Resources */,
C3056BBB2BCD981900A31D04 /* test.jsonl in Resources */,
C3056BBC2BCD981900A31D04 /* valid.jsonl in Resources */,
C3056BB62BCD97B800A31D04 /* Preview Assets.xcassets in Resources */,
C3056BB22BCD97B800A31D04 /* Assets.xcassets in Resources */,
);
runOnlyForDeploymentPostprocessing = 0;
};
C34E490B2B69A92900FCB841 /* Resources */ = {
isa = PBXResourcesBuildPhase;
buildActionMask = 2147483647;
@@ -815,6 +956,15 @@
/* End PBXResourcesBuildPhase section */
/* Begin PBXSourcesBuildPhase section */
C3056BA72BCD97B700A31D04 /* Sources */ = {
isa = PBXSourcesBuildPhase;
buildActionMask = 2147483647;
files = (
C3056BB02BCD97B700A31D04 /* ContentView.swift in Sources */,
C3056BAE2BCD97B700A31D04 /* LoRATrainingExampleApp.swift in Sources */,
);
runOnlyForDeploymentPostprocessing = 0;
};
C3288D6F2B6D9313009FF608 /* Sources */ = {
isa = PBXSourcesBuildPhase;
buildActionMask = 2147483647;
@@ -850,6 +1000,8 @@
C38935E32B86C0FE0037B833 /* Gemma.swift in Sources */,
C38935CD2B869C870037B833 /* Configuration.swift in Sources */,
525C1E9D2B9A011000B5C356 /* Starcoder2.swift in Sources */,
C36BEFBB2BBF02CC002D4AFE /* Lora+Data.swift in Sources */,
C36BEFB02BBCBAC2002D4AFE /* Lora.swift in Sources */,
C38935DF2B869DD00037B833 /* Phi.swift in Sources */,
C38935CE2B869C870037B833 /* Load.swift in Sources */,
C3E786AD2B8D4AF50004D037 /* Tokenizer.swift in Sources */,
@@ -872,7 +1024,9 @@
isa = PBXSourcesBuildPhase;
buildActionMask = 2147483647;
files = (
C36BEFB82BBDED51002D4AFE /* Arguments.swift in Sources */,
C34E48F52B696F0B00FCB841 /* LLMTool.swift in Sources */,
C36BEFB52BBDEAD8002D4AFE /* LoraCommands.swift in Sources */,
);
runOnlyForDeploymentPostprocessing = 0;
};
@@ -899,6 +1053,11 @@
/* End PBXSourcesBuildPhase section */
/* Begin PBXTargetDependency section */
C3056BC02BCD984F00A31D04 /* PBXTargetDependency */ = {
isa = PBXTargetDependency;
target = C38935C42B869C7A0037B833 /* LLM */;
targetProxy = C3056BBF2BCD984F00A31D04 /* PBXContainerItemProxy */;
};
C34E492D2B6A028800FCB841 /* PBXTargetDependency */ = {
isa = PBXTargetDependency;
target = C34E490C2B69A92900FCB841 /* MNIST */;
@@ -922,6 +1081,183 @@
/* End PBXTargetDependency section */
/* Begin XCBuildConfiguration section */
C3056BB82BCD97B800A31D04 /* Debug */ = {
isa = XCBuildConfiguration;
buildSettings = {
ALWAYS_SEARCH_USER_PATHS = NO;
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;
CLANG_ANALYZER_NONNULL = YES;
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
CLANG_ENABLE_MODULES = YES;
CLANG_ENABLE_OBJC_ARC = YES;
CLANG_ENABLE_OBJC_WEAK = YES;
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
CLANG_WARN_BOOL_CONVERSION = YES;
CLANG_WARN_COMMA = YES;
CLANG_WARN_CONSTANT_CONVERSION = YES;
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
CLANG_WARN_STRICT_PROTOTYPES = YES;
CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
CODE_SIGN_ENTITLEMENTS = Applications/LoRATrainingExample/LoRATrainingExample.entitlements;
CODE_SIGN_STYLE = Automatic;
COPY_PHASE_STRIP = NO;
CURRENT_PROJECT_VERSION = 1;
DEBUG_INFORMATION_FORMAT = dwarf;
DEVELOPMENT_ASSET_PATHS = "\"Applications/LoRATrainingExample/Preview Content\"";
DEVELOPMENT_TEAM = "";
ENABLE_PREVIEWS = YES;
ENABLE_STRICT_OBJC_MSGSEND = YES;
ENABLE_TESTABILITY = YES;
ENABLE_USER_SCRIPT_SANDBOXING = YES;
GCC_C_LANGUAGE_STANDARD = gnu17;
GCC_DYNAMIC_NO_PIC = NO;
GCC_NO_COMMON_BLOCKS = YES;
GCC_OPTIMIZATION_LEVEL = 0;
GCC_PREPROCESSOR_DEFINITIONS = (
"DEBUG=1",
"$(inherited)",
);
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
GCC_WARN_UNDECLARED_SELECTOR = YES;
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
GENERATE_INFOPLIST_FILE = YES;
"INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphoneos*]" = YES;
"INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphonesimulator*]" = YES;
"INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphoneos*]" = YES;
"INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphonesimulator*]" = YES;
"INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphoneos*]" = YES;
"INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphonesimulator*]" = YES;
"INFOPLIST_KEY_UIStatusBarStyle[sdk=iphoneos*]" = UIStatusBarStyleDefault;
"INFOPLIST_KEY_UIStatusBarStyle[sdk=iphonesimulator*]" = UIStatusBarStyleDefault;
INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
IPHONEOS_DEPLOYMENT_TARGET = 17.2;
LD_RUNPATH_SEARCH_PATHS = "@executable_path/Frameworks";
"LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks";
LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
MACOSX_DEPLOYMENT_TARGET = 14.2;
MARKETING_VERSION = 1.0;
MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE;
MTL_FAST_MATH = YES;
PRODUCT_BUNDLE_IDENTIFIER = mlx.LoRATrainingExample;
PRODUCT_NAME = "$(TARGET_NAME)";
SDKROOT = auto;
SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx xros xrsimulator";
SUPPORTS_MACCATALYST = NO;
SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)";
SWIFT_EMIT_LOC_STRINGS = YES;
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2,7";
};
name = Debug;
};
C3056BB92BCD97B800A31D04 /* Release */ = {
isa = XCBuildConfiguration;
buildSettings = {
ALWAYS_SEARCH_USER_PATHS = NO;
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;
CLANG_ANALYZER_NONNULL = YES;
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
CLANG_ENABLE_MODULES = YES;
CLANG_ENABLE_OBJC_ARC = YES;
CLANG_ENABLE_OBJC_WEAK = YES;
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
CLANG_WARN_BOOL_CONVERSION = YES;
CLANG_WARN_COMMA = YES;
CLANG_WARN_CONSTANT_CONVERSION = YES;
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
CLANG_WARN_STRICT_PROTOTYPES = YES;
CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
CODE_SIGN_ENTITLEMENTS = Applications/LoRATrainingExample/LoRATrainingExample.entitlements;
CODE_SIGN_STYLE = Automatic;
COPY_PHASE_STRIP = NO;
CURRENT_PROJECT_VERSION = 1;
DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
DEVELOPMENT_ASSET_PATHS = "\"Applications/LoRATrainingExample/Preview Content\"";
DEVELOPMENT_TEAM = "";
ENABLE_NS_ASSERTIONS = NO;
ENABLE_PREVIEWS = YES;
ENABLE_STRICT_OBJC_MSGSEND = YES;
ENABLE_USER_SCRIPT_SANDBOXING = YES;
GCC_C_LANGUAGE_STANDARD = gnu17;
GCC_NO_COMMON_BLOCKS = YES;
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
GCC_WARN_UNDECLARED_SELECTOR = YES;
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
GENERATE_INFOPLIST_FILE = YES;
"INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphoneos*]" = YES;
"INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphonesimulator*]" = YES;
"INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphoneos*]" = YES;
"INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphonesimulator*]" = YES;
"INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphoneos*]" = YES;
"INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphonesimulator*]" = YES;
"INFOPLIST_KEY_UIStatusBarStyle[sdk=iphoneos*]" = UIStatusBarStyleDefault;
"INFOPLIST_KEY_UIStatusBarStyle[sdk=iphonesimulator*]" = UIStatusBarStyleDefault;
INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
IPHONEOS_DEPLOYMENT_TARGET = 17.2;
LD_RUNPATH_SEARCH_PATHS = "@executable_path/Frameworks";
"LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks";
LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
MACOSX_DEPLOYMENT_TARGET = 14.2;
MARKETING_VERSION = 1.0;
MTL_ENABLE_DEBUG_INFO = NO;
MTL_FAST_MATH = YES;
PRODUCT_BUNDLE_IDENTIFIER = mlx.LoRATrainingExample;
PRODUCT_NAME = "$(TARGET_NAME)";
SDKROOT = auto;
SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx xros xrsimulator";
SUPPORTS_MACCATALYST = NO;
SWIFT_COMPILATION_MODE = wholemodule;
SWIFT_EMIT_LOC_STRINGS = YES;
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2,7";
};
name = Release;
};
C3288D772B6D9313009FF608 /* Debug */ = {
isa = XCBuildConfiguration;
buildSettings = {
@@ -2140,6 +2476,15 @@
/* End XCBuildConfiguration section */
/* Begin XCConfigurationList section */
C3056BB72BCD97B800A31D04 /* Build configuration list for PBXNativeTarget "LoRATrainingExample" */ = {
isa = XCConfigurationList;
buildConfigurations = (
C3056BB82BCD97B800A31D04 /* Debug */,
C3056BB92BCD97B800A31D04 /* Release */,
);
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
C3288D792B6D9313009FF608 /* Build configuration list for PBXNativeTarget "LinearModelTraining" */ = {
isa = XCConfigurationList;
buildConfigurations = (
@@ -2295,6 +2640,11 @@
package = C392736E2B60699100368D5D /* XCRemoteSwiftPackageReference "swift-argument-parser" */;
productName = ArgumentParser;
};
C36BEFB12BBDE9D0002D4AFE /* MLXOptimizers */ = {
isa = XCSwiftPackageProductDependency;
package = C3FBCB1F2B8520B00007E490 /* XCRemoteSwiftPackageReference "mlx-swift" */;
productName = MLXOptimizers;
};
C382DE892B630889000F8F03 /* AsyncAlgorithms */ = {
isa = XCSwiftPackageProductDependency;
package = C382DE882B630889000F8F03 /* XCRemoteSwiftPackageReference "swift-async-algorithms" */;

View File

@@ -16,7 +16,7 @@
"location" : "https://github.com/ml-explore/mlx-swift",
"state" : {
"branch" : "main",
"revision" : "b4d3e4bbbe41e6dc7c46d5ba075049ae7177961b"
"revision" : "cf2c5d20c8575b375cb0d97a06ae0199527b5f32"
}
},
{