Skip to content
Open
2 changes: 2 additions & 0 deletions cli/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
"benchmark_gpt": "npm run build && node dist/benchmark_gpt.js",
"train_gpt": "npm run build && node dist/train_gpt.js",
"hellaswag_gpt": "npm run build && node dist/hellaswag_gpt.js",
"eval_finetuned_gpt2": "npm run build && node dist/evaluate_finetuned_gpt2.js",
"finetune_gpt": "npm run build && node dist/finetune_gpt.js",
"build": "tsc --build",
"test": ": nothing"
},
Expand Down
13 changes: 11 additions & 2 deletions cli/src/args.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ export interface BenchmarkArguments {
roundDuration: number
batchSize: number
validationSplit: number
datasetPath?: string
validationDatasetPath?: string

// DP
epsilon?: number
Expand All @@ -36,11 +38,14 @@ export interface BenchmarkArguments {
maxShareValue?: number

save: boolean
saveModel: boolean
host: URL
}

type BenchmarkUnsafeArguments = Omit<BenchmarkArguments, 'provider'> & {
task: string
datasetPath?: string
validationDatasetPath?: string
help?: boolean
}

Expand All @@ -55,7 +60,10 @@ const unsafeArgs = parse<BenchmarkUnsafeArguments>(
roundDuration: { type: Number, alias: 'r', description: 'Round duration (in epochs)', defaultValue: 2 },
batchSize: { type: Number, alias: 'b', description: 'Training batch size', defaultValue: 10 },
validationSplit : { type: Number, alias: 'v', description: 'Validation dataset ratio', defaultValue: 0.2 },
datasetPath: { type: String, alias: 'd', description: 'Path to the dataset', optional: true },
validationDatasetPath: { type: String, alias: 'V', description: 'Path to the validation dataset', optional: true },
save: { type: Boolean, alias: 's', description: 'Save logs of benchmark', defaultValue: false },
saveModel: { type: Boolean, alias: 'm', description: 'Save trained model to disk', defaultValue: false },
host: {
type: (raw: string) => new URL(raw),
typeLabel: "URL",
Expand Down Expand Up @@ -89,18 +97,19 @@ const unsafeArgs = parse<BenchmarkUnsafeArguments>(

const supportedTasks = Map(
await Promise.all(
Set.of<TaskProvider<"image" | "tabular", Network>>(
Set.of<TaskProvider<"image" | "tabular" | "text", Network>>(
defaultTasks.cifar10,
defaultTasks.lusCovid,
defaultTasks.simpleFace,
defaultTasks.titanic,
defaultTasks.tinderDog,
defaultTasks.mnist,
defaultTasks.privacyrun,
).map(
async (t) =>
[(await t.getTask()).id, t] as [
string,
TaskProvider<"image" | "tabular", Network>,
TaskProvider<"image" | "tabular" | "text", Network>,
],
),
),
Expand Down
49 changes: 41 additions & 8 deletions cli/src/cli.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import fs from 'node:fs/promises'
import { createWriteStream } from "node:fs";
import path from "node:path";

import createDebug from "debug";
import type {
Dataset,
DataFormat,
Expand All @@ -17,25 +17,42 @@
} from "@epfml/discojs";
import { Disco, aggregator as aggregators, client as clients } from '@epfml/discojs'

import { loadText, saveModelToDisk } from "@epfml/discojs-node";
import { getTaskData } from './data.js'
import { args } from './args.js'
import { makeUserLogFile } from "./user_log.js";
import type { UserLogFile } from "./user_log.js";

const debug = createDebug("cli:main");

async function runUser<D extends DataType, N extends Network>(
task: Task<D, N>,
provider: TaskProvider<D, N>,

Check failure on line 30 in cli/src/cli.ts

View workflow job for this annotation

GitHub Actions / lint-most

'provider' is defined but never used. Allowed unused args must match /^_/u
url: URL,
data: Dataset<DataFormat.Raw[D]>,
validationData: Dataset<DataFormat.Raw[D]> | undefined,
userIndex: number,
numberOfUsers: number,
): Promise<List<SummaryLogs>> {
// cast as typescript isn't good with generics
debug(`Starting runUser for client ${userIndex}`);
const userStart = Date.now();

Check failure on line 38 in cli/src/cli.ts

View workflow job for this annotation

GitHub Actions / lint-most

'userStart' is assigned a value but never used. Allowed unused vars must match /^_/u
const trainingScheme = task.trainingInformation.scheme as N
const aggregator = aggregators.getAggregator(task)
const client = clients.getClient(trainingScheme, url, task, aggregator)
const disco = new Disco(task, client, { scheme: trainingScheme });

const disco = new Disco(task, client, { scheme: trainingScheme, preprocessOnce: true });

// For local training, load model from provider before training starts
// if (trainingScheme === "local") {
// debug(`Loading model for training client ${userIndex}...`);
// const modelStart = Date.now();
// console.log("Loading model for local training...");
// disco.trainer.model = await provider.getModel();
// console.log("Model loaded successfully");
// debug(`Model loading took ${Date.now() - modelStart}ms for client ${userIndex}`);
// }



const dir = path.join(".", `${args.testID}`);
await fs.mkdir(dir, { recursive: true });
const streamPath = path.join(dir, `client${userIndex}_local_log.jsonl`);
Expand All @@ -49,16 +66,25 @@
}

try{
for await (const log of disco.trainSummary(data)){
debug(`Starting training for client ${userIndex}`);
const trainStart = Date.now();
for await (const log of disco.trainSummary(data, validationData)){
finalLog.push(log);

if (jsonStream){
jsonStream.write(JSON.stringify(log) + "\n");
}
}
debug(`Training took ${Date.now() - trainStart}ms for client ${userIndex}`);

await new Promise((res, _) => setTimeout(() => res('timeout'), 1000)) // Wait for other peers to finish

// Save the trained model if requested
if (args.saveModel) {
const modelDir = path.join(".", `${args.testID}`, "models");
const modelFileName = `client${userIndex}_model.json`;
await saveModelToDisk(disco.trainer.model, modelDir, modelFileName);
console.log(`Model saved for client ${userIndex} at ${modelDir}/${modelFileName}`);
}
// saving the entire per-user logs
if (args.save) {
const finalPath = path.join(dir, `client${userIndex}_local_log.json`);
Expand Down Expand Up @@ -104,10 +130,17 @@
console.log({ args })

const dataSplits = await Promise.all(
Range(0, numberOfUsers).map(async i => getTaskData(task.id, i, numberOfUsers))
Range(0, numberOfUsers).map(async i => getTaskData(task.id, i, numberOfUsers, args.datasetPath))
)

let validationData: Dataset<DataFormat.Raw[D]> | undefined = undefined;
if (args.validationDatasetPath) {
// Assume text task for now
validationData = loadText(args.validationDatasetPath).cached() as Dataset<DataFormat.Raw[D]>;
}

const logs = await Promise.all(
dataSplits.map((data, i) => runUser(task, args.host, data as Dataset<DataFormat.Raw[D]>, i, numberOfUsers))
dataSplits.map((data, i) => runUser(task, provider, args.host, data as Dataset<DataFormat.Raw[D]>, validationData, i, numberOfUsers))
)

if (args.save) {
Expand Down
64 changes: 62 additions & 2 deletions cli/src/data.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,54 @@
import path from "node:path";
import { createReadStream } from "node:fs";
import { Dataset, processing } from "@epfml/discojs";
import {
DataFormat,
DataType,
Image,
Task,
Text,
} from "@epfml/discojs";
import { loadCSV, loadImage, loadImagesInDir } from "@epfml/discojs-node";
import { loadCSV, loadImage, loadImagesInDir, loadText } from "@epfml/discojs-node";
import { Repeat } from "immutable";

/**
 * Lazily streams a text file and yields only this client's shard of samples.
 *
 * Samples are delimited by the "<|endoftext|>" marker; each yielded sample
 * keeps its trailing delimiter and is whitespace-trimmed. Sample i is owned
 * by the client where i % totalClient === userIdx, so shards are disjoint.
 * Any text after the final delimiter is treated as one last sample.
 * NOTE(review): assumes Dataset's constructor accepts an async-generator
 * factory — confirm against the Dataset API.
 */
function loadShardedTextSamples(
  filePath: string,
  userIdx: number,
  totalClient: number,
): Dataset<Text> {
  const DELIMITER = "<|endoftext|>";

  // Round-robin ownership test for sample index `index`.
  const belongsToUser = (index: number): boolean =>
    index % totalClient === userIdx;

  return new Dataset(async function* () {
    const reader = createReadStream(filePath, { encoding: "utf8" });
    let pending = "";
    let seen = 0;

    for await (const piece of reader) {
      if (typeof piece !== "string") {
        throw new Error("Expected file stream to yield string");
      }

      pending += piece;

      // Emit every complete (delimiter-terminated) sample in the buffer.
      for (
        let cut = pending.indexOf(DELIMITER);
        cut !== -1;
        cut = pending.indexOf(DELIMITER)
      ) {
        const end = cut + DELIMITER.length;
        const candidate = pending.slice(0, end).trim();
        // Empty samples are skipped but still consume an index, keeping
        // shard assignment stable across clients.
        if (candidate !== "" && belongsToUser(seen)) {
          yield candidate;
        }
        seen++;
        pending = pending.slice(end);
      }
    }

    // Flush whatever remains after the last delimiter.
    const leftover = pending.trim();
    if (leftover !== "" && belongsToUser(seen)) {
      yield leftover;
    }
  });
}

async function loadSimpleFaceData(userIdx: number, totalClient: number): Promise<Dataset<DataFormat.Raw["image"]>> {
const folder = path.join("..", "datasets", "simple_face");

Expand Down Expand Up @@ -94,7 +134,10 @@ function loadData(dataName: string, split: number): Dataset<DataFormat.Raw["imag
export async function getTaskData<D extends DataType>(
taskID: Task.ID,
userIdx: number,
totalClient: number
totalClient: number,
datasetPath?: string,
isValidation?: boolean,
validationDatasetPath?: string
): Promise<Dataset<DataFormat.Raw[D]>> {
switch (taskID) {
case "simple_face": // remove
Expand All @@ -118,6 +161,23 @@ export async function getTaskData<D extends DataType>(
case "mnist_federated":
case "mnist":
return loadData("mnist", userIdx) as Dataset<DataFormat.Raw[D]>;
case "privacyrun": {
const filePath =
isValidation && validationDatasetPath
? validationDatasetPath
: datasetPath ?? "../datasets/med_mcq/train.txt";

// Keep validation shared, but shard training data across clients by MCQ sample.
if (isValidation) {
return loadText(filePath) as Dataset<DataFormat.Raw[D]>;
}

return loadShardedTextSamples(
filePath,
userIdx,
totalClient,
) as Dataset<DataFormat.Raw[D]>;
}
default:
throw new Error(`Data loader for ${taskID} not implemented.`);
}
Expand Down
Loading
Loading