Skip to content

Commit 08e3dfd

Browse files
committed
[refactor] Standardize model configuration property names to snake_case across the codebase
- Renamed `modelPath` to `model_path` and `dType` to `dtype` in multiple files for consistency.
- Updated related schemas, model samples, and tests to reflect the new naming convention.
- Introduced a new `pooling` property in relevant schemas to specify the pooling strategy for models.
- Ensured all references to model configuration are aligned with the updated naming for improved clarity and maintainability.
1 parent b52d595 commit 08e3dfd

23 files changed

+228
-172
lines changed

docs/developers/01_getting_started.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,8 @@ await modelRepo.addModel({
154154
provider: HF_TRANSFORMERS_ONNX,
155155
provider_config: {
156156
pipeline: "text2text-generation",
157-
modelPath: "Xenova/LaMini-Flan-T5-783M"
157+
pooling: "mean",
158+
model_path: "Xenova/LaMini-Flan-T5-783M"
158159
});
159160

160161
// Job queue for the provider

packages/ai-provider/src/hf-transformers/common/HFT_JobRunFns.ts

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -90,12 +90,12 @@ export function clearPipelineCache(): void {
9090

9191
/**
9292
* Generate a cache key for a pipeline that includes all configuration options
93-
* that affect pipeline creation (modelPath, pipeline, dType, device)
93+
* that affect pipeline creation (model_path, pipeline, dtype, device)
9494
*/
9595
function getPipelineCacheKey(model: HfTransformersOnnxModelConfig): string {
96-
const dType = model.provider_config.dType || "q8";
96+
const dtype = model.provider_config.dtype || "q8";
9797
const device = model.provider_config.device || "";
98-
return `${model.provider_config.modelPath}:${model.provider_config.pipeline}:${dType}:${device}`;
98+
return `${model.provider_config.model_path}:${model.provider_config.pipeline}:${dtype}:${device}`;
9999
}
100100

101101
/**
@@ -379,9 +379,9 @@ const getPipeline = async (
379379
};
380380

381381
const pipelineOptions: PretrainedModelOptions = {
382-
dtype: model.provider_config.dType || "q8",
383-
...(model.provider_config.useExternalDataFormat
384-
? { use_external_data_format: model.provider_config.useExternalDataFormat }
382+
dtype: model.provider_config.dtype || "q8",
383+
...(model.provider_config.use_external_data_format
384+
? { use_external_data_format: model.provider_config.use_external_data_format }
385385
: {}),
386386
...(model.provider_config.device ? { device: model.provider_config.device as any } : {}),
387387
...options,
@@ -412,7 +412,7 @@ const getPipeline = async (
412412
});
413413

414414
// Race between pipeline creation and abort
415-
const pipelinePromise = pipeline(pipelineType, model.provider_config.modelPath, pipelineOptions);
415+
const pipelinePromise = pipeline(pipelineType, model.provider_config.model_path, pipelineOptions);
416416

417417
try {
418418
const result = await (abortSignal
@@ -471,8 +471,8 @@ export const HFT_Unload: AiProviderRunFn<
471471
}
472472

473473
// Delete model cache entries
474-
const modelPath = model!.provider_config.modelPath;
475-
await deleteModelCache(modelPath);
474+
const model_path = model!.provider_config.model_path;
475+
await deleteModelCache(model_path);
476476
onProgress(100, "Model cache deleted");
477477

478478
return {
@@ -482,12 +482,12 @@ export const HFT_Unload: AiProviderRunFn<
482482

483483
/**
484484
* Deletes all cache entries for a given model path
485-
* @param modelPath - The model path to delete from cache
485+
* @param model_path - The model path to delete from cache
486486
*/
487-
const deleteModelCache = async (modelPath: string): Promise<void> => {
487+
const deleteModelCache = async (model_path: string): Promise<void> => {
488488
const cache = await caches.open(HTF_CACHE_NAME);
489489
const keys = await cache.keys();
490-
const prefix = `/${modelPath}/`;
490+
const prefix = `/${model_path}/`;
491491

492492
// Collect all matching requests first
493493
const requestsToDelete: Request[] = [];
@@ -534,20 +534,20 @@ export const HFT_TextEmbedding: AiProviderRunFn<
534534

535535
// Generate the embedding
536536
const hfVector = await generateEmbedding(input.text, {
537-
pooling: "mean",
537+
pooling: model?.provider_config.pooling || "mean",
538538
normalize: model?.provider_config.normalize,
539539
...(signal ? { abort_signal: signal } : {}),
540540
});
541541

542542
// Validate the embedding dimensions
543-
if (hfVector.size !== model?.provider_config.nativeDimensions) {
543+
if (hfVector.size !== model?.provider_config.native_dimensions) {
544544
console.warn(
545-
`HuggingFace Embedding vector length does not match model dimensions v${hfVector.size} != m${model?.provider_config.nativeDimensions}`,
545+
`HuggingFace Embedding vector length does not match model dimensions v${hfVector.size} != m${model?.provider_config.native_dimensions}`,
546546
input,
547547
hfVector
548548
);
549549
throw new Error(
550-
`HuggingFace Embedding vector length does not match model dimensions v${hfVector.size} != m${model?.provider_config.nativeDimensions}`
550+
`HuggingFace Embedding vector length does not match model dimensions v${hfVector.size} != m${model?.provider_config.native_dimensions}`
551551
);
552552
}
553553

packages/ai-provider/src/hf-transformers/common/HFT_ModelSchema.ts

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,11 @@ export const HfTransformersOnnxModelSchema = {
2525
description: "Pipeline type for the ONNX model.",
2626
default: "text-generation",
2727
},
28-
modelPath: {
28+
model_path: {
2929
type: "string",
3030
description: "Filesystem path or URI for the ONNX model.",
3131
},
32-
dType: {
32+
dtype: {
3333
type: "string",
3434
enum: Object.values(QuantizationDataType),
3535
description: "Data type for the ONNX model.",
@@ -41,42 +41,50 @@ export const HfTransformersOnnxModelSchema = {
4141
description: "High-level device selection.",
4242
default: "webgpu",
4343
},
44-
executionProviders: {
44+
execution_providers: {
4545
type: "array",
4646
items: { type: "string" },
4747
description: "Raw ONNX Runtime execution provider identifiers.",
4848
"x-ui-hidden": true,
4949
},
50-
intraOpNumThreads: {
50+
intra_op_num_threads: {
5151
type: "integer",
5252
minimum: 1,
5353
},
54-
interOpNumThreads: {
54+
inter_op_num_threads: {
5555
type: "integer",
5656
minimum: 1,
5757
},
58-
useExternalDataFormat: {
58+
use_external_data_format: {
5959
type: "boolean",
6060
description: "Whether the model uses external data format.",
6161
},
62-
nativeDimensions: {
62+
native_dimensions: {
6363
type: "integer",
6464
description: "The native dimensions of the model.",
6565
},
66+
pooling: {
67+
type: "string",
68+
enum: ["mean", "last_token", "cls"],
69+
description: "The pooling strategy to use for the model.",
70+
default: "mean",
71+
},
6672
normalize: {
6773
type: "boolean",
6874
description: "Whether the model uses normalization.",
75+
default: true,
6976
},
70-
languageStyle: {
77+
language_style: {
7178
type: "string",
7279
description: "The language style of the model.",
7380
},
7481
mrl: {
7582
type: "boolean",
7683
description: "Whether the model uses matryoshka.",
84+
default: false,
7785
},
7886
},
79-
required: ["modelPath", "pipeline"],
87+
required: ["model_path", "pipeline"],
8088
additionalProperties: false,
8189
if: {
8290
properties: {
@@ -86,7 +94,7 @@ export const HfTransformersOnnxModelSchema = {
8694
},
8795
},
8896
then: {
89-
required: ["nativeDimensions"],
97+
required: ["native_dimensions"],
9098
},
9199
},
92100
},

packages/ai-provider/src/tf-mediapipe/common/TFMP_JobRunFns.ts

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,10 @@ const getWasmTask = async (
8686
onProgress: (progress: number, message?: string, details?: any) => void,
8787
signal: AbortSignal
8888
): Promise<TFMPWasmFileset> => {
89-
const taskEngine = model.provider_config.taskEngine;
89+
const task_engine = model.provider_config.task_engine;
9090

91-
if (wasm_tasks.has(taskEngine)) {
92-
return wasm_tasks.get(taskEngine)!;
91+
if (wasm_tasks.has(task_engine)) {
92+
return wasm_tasks.get(task_engine)!;
9393
}
9494

9595
if (signal.aborted) {
@@ -100,7 +100,7 @@ const getWasmTask = async (
100100

101101
let wasmFileset: TFMPWasmFileset;
102102

103-
switch (taskEngine) {
103+
switch (task_engine) {
104104
case "text":
105105
wasmFileset = await FilesetResolver.forTextTasks(
106106
"https://cdn.jsdelivr.net/npm/@mediapipe/tasks-text@latest/wasm"
@@ -125,7 +125,7 @@ const getWasmTask = async (
125125
throw new PermanentJobError("Invalid task engine");
126126
}
127127

128-
wasm_tasks.set(taskEngine, wasmFileset);
128+
wasm_tasks.set(task_engine, wasmFileset);
129129
return wasmFileset;
130130
};
131131

@@ -160,7 +160,7 @@ type TaskInstance =
160160
interface CachedModelTask {
161161
readonly task: TaskInstance;
162162
readonly options: Record<string, unknown>;
163-
readonly taskEngine: string;
163+
readonly task_engine: string;
164164
}
165165

166166
const modelTaskCache = new Map<string, CachedModelTask[]>();
@@ -219,11 +219,11 @@ const getModelTask = async <T extends TaskType>(
219219
signal: AbortSignal,
220220
TaskType: T
221221
): Promise<InferTaskInstance<T>> => {
222-
const modelPath = model.provider_config.modelPath;
223-
const taskEngine = model.provider_config.taskEngine;
222+
const model_path = model.provider_config.model_path;
223+
const task_engine = model.provider_config.task_engine;
224224

225225
// Check if we have a cached instance with matching options
226-
const cachedTasks = modelTaskCache.get(modelPath);
226+
const cachedTasks = modelTaskCache.get(model_path);
227227
if (cachedTasks) {
228228
const matchedTask = cachedTasks.find((cached) => optionsMatch(cached.options, options));
229229
if (matchedTask) {
@@ -239,20 +239,20 @@ const getModelTask = async <T extends TaskType>(
239239
// Create new model instance
240240
const task = await TaskType.createFromOptions(wasmFileset, {
241241
baseOptions: {
242-
modelAssetPath: modelPath,
242+
modelAssetPath: model_path,
243243
},
244244
...options,
245245
});
246246

247247
// Cache the task with its options and task engine
248-
const cachedTask: CachedModelTask = { task, options, taskEngine };
249-
if (!modelTaskCache.has(modelPath)) {
250-
modelTaskCache.set(modelPath, []);
248+
const cachedTask: CachedModelTask = { task, options, task_engine };
249+
if (!modelTaskCache.has(model_path)) {
250+
modelTaskCache.set(model_path, []);
251251
}
252-
modelTaskCache.get(modelPath)!.push(cachedTask);
252+
modelTaskCache.get(model_path)!.push(cachedTask);
253253

254254
// Increment WASM reference count for this cached task
255-
wasm_reference_counts.set(taskEngine, (wasm_reference_counts.get(taskEngine) || 0) + 1);
255+
wasm_reference_counts.set(task_engine, (wasm_reference_counts.get(task_engine) || 0) + 1);
256256

257257
return task as any;
258258
};
@@ -314,8 +314,8 @@ export const TFMP_Download: AiProviderRunFn<
314314
onProgress(0.9, "Pipeline loaded");
315315
task.close(); // Close the task to release the resources, but it is still in the browser cache
316316
// Decrease reference count for WASM fileset for this cached task since this is a fake model cache entry
317-
const taskEngine = model?.provider_config.taskEngine;
318-
wasm_reference_counts.set(taskEngine, wasm_reference_counts.get(taskEngine)! - 1);
317+
const task_engine = model?.provider_config.task_engine;
318+
wasm_reference_counts.set(task_engine, wasm_reference_counts.get(task_engine)! - 1);
319319

320320
return {
321321
model: input.model,
@@ -435,31 +435,31 @@ export const TFMP_Unload: AiProviderRunFn<
435435
UnloadModelTaskExecuteOutput,
436436
TFMPModelConfig
437437
> = async (input, model, onProgress, signal) => {
438-
const modelPath = model!.provider_config.modelPath;
438+
const model_path = model!.provider_config.model_path;
439439
onProgress(10, "Unloading model");
440440
// Dispose of all cached model tasks if they exist
441-
if (modelTaskCache.has(modelPath)) {
442-
const cachedTasks = modelTaskCache.get(modelPath)!;
441+
if (modelTaskCache.has(model_path)) {
442+
const cachedTasks = modelTaskCache.get(model_path)!;
443443

444444
for (const cachedTask of cachedTasks) {
445445
const task = cachedTask.task;
446446
if ("close" in task && typeof task.close === "function") task.close();
447447

448448
// Decrease reference count for WASM fileset for this cached task
449-
const taskEngine = cachedTask.taskEngine;
450-
const currentCount = wasm_reference_counts.get(taskEngine) || 0;
449+
const task_engine = cachedTask.task_engine;
450+
const currentCount = wasm_reference_counts.get(task_engine) || 0;
451451
const newCount = currentCount - 1;
452452

453453
if (newCount <= 0) {
454454
// No more models using this WASM fileset, unload it
455-
wasm_tasks.delete(taskEngine);
456-
wasm_reference_counts.delete(taskEngine);
455+
wasm_tasks.delete(task_engine);
456+
wasm_reference_counts.delete(task_engine);
457457
} else {
458-
wasm_reference_counts.set(taskEngine, newCount);
458+
wasm_reference_counts.set(task_engine, newCount);
459459
}
460460
}
461461

462-
modelTaskCache.delete(modelPath);
462+
modelTaskCache.delete(model_path);
463463
}
464464

465465
return {

packages/ai-provider/src/tf-mediapipe/common/TFMP_ModelSchema.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@ export const TFMPModelSchema = {
1919
type: "object",
2020
description: "TensorFlow MediaPipe-specific options.",
2121
properties: {
22-
modelPath: {
22+
model_path: {
2323
type: "string",
2424
description: "Filesystem path or URI for the ONNX model.",
2525
},
26-
taskEngine: {
26+
task_engine: {
2727
type: "string",
2828
enum: ["text", "audio", "vision", "genai"],
2929
description: "Task engine for the MediaPipe model.",
@@ -34,7 +34,7 @@ export const TFMPModelSchema = {
3434
description: "Pipeline task type for the MediaPipe model.",
3535
},
3636
},
37-
required: ["modelPath", "taskEngine", "pipeline"],
37+
required: ["model_path", "task_engine", "pipeline"],
3838
additionalProperties: false,
3939
},
4040
},

packages/ai/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ await modelRepo.addModel({
5252
provider: HF_TRANSFORMERS_ONNX,
5353
provider_config: {
5454
pipeline: "text2text-generation",
55-
modelPath: "Xenova/LaMini-Flan-T5-783M"
55+
model_path: "Xenova/LaMini-Flan-T5-783M"
5656
});
5757

5858
// 3. Register provider functions (inline, same thread)

packages/tasks/package.json

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,27 @@
55
"description": "Pre-built task implementations for Workglow, including common AI operations and utility functions.",
66
"scripts": {
77
"watch": "concurrently -c 'auto' 'bun:watch-*'",
8-
"watch-code": "bun build --watch --no-clear-screen --target=browser --sourcemap=external --packages=external --outdir ./dist ./src/index.ts",
8+
"watch-browser": "bun build --watch --no-clear-screen --target=browser --sourcemap=external --packages=external --outdir ./dist ./src/browser.ts",
9+
"watch-node": "bun build --watch --no-clear-screen --target=node --sourcemap=external --packages=external --outdir ./dist ./src/node.ts",
10+
"watch-bun": "bun build --watch --no-clear-screen --target=bun --sourcemap=external --packages=external --outdir ./dist ./src/bun.ts",
911
"watch-types": "tsc --watch --preserveWatchOutput",
10-
"build-package": "bun run build-clean && concurrently -c 'auto' -n 'code,types' 'bun run build-code' 'bun run build-types'",
12+
"build-package": "bun run build-clean && concurrently -c 'auto' -n 'browser,node,bun,types' 'bun run build-browser' 'bun run build-node' 'bun run build-bun' 'bun run build-types'",
1113
"build-clean": "rm -fr dist/*",
12-
"build-code": "bun build --target=browser --sourcemap=external --packages=external --outdir ./dist ./src/index.ts",
14+
"build-browser": "bun build --target=browser --sourcemap=external --packages=external --outdir ./dist ./src/browser.ts",
15+
"build-node": "bun build --target=node --sourcemap=external --packages=external --outdir ./dist ./src/node.ts",
16+
"build-bun": "bun build --target=bun --sourcemap=external --packages=external --outdir ./dist ./src/bun.ts",
1317
"build-types": "rm -f tsconfig.tsbuildinfo && tsc",
1418
"lint": "eslint . --ext ts,tsx --report-unused-disable-directives --max-warnings 0",
1519
"test": "bun test",
16-
"prepare": "node -e \"const pkg=require('./package.json');pkg.exports['.'].bun='./dist/index.js';pkg.exports['.'].types='./dist/types.d.ts';require('fs').writeFileSync('package.json',JSON.stringify(pkg,null,2))\""
20+
"prepare": "node -e \"const pkg=require('./package.json');pkg.exports['.'].bun='./dist/bun.js';pkg.exports['.'].types='./dist/types.d.ts';require('fs').writeFileSync('package.json',JSON.stringify(pkg,null,2))\""
1721
},
1822
"exports": {
1923
".": {
20-
"bun": "./dist/index.js",
24+
"react-native": "./dist/browser.js",
25+
"browser": "./dist/browser.js",
26+
"bun": "./dist/bun.js",
2127
"types": "./dist/types.d.ts",
22-
"import": "./dist/index.js"
28+
"import": "./dist/node.js"
2329
}
2430
},
2531
"files": [

packages/tasks/src/browser.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
/**
2+
* @license
3+
* Copyright 2025 Steven Roussey <sroussey@gmail.com>
4+
* SPDX-License-Identifier: Apache-2.0
5+
*/
6+
7+
export * from "./common";
8+
export * from "./task/FileLoaderTask";

packages/tasks/src/bun.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
/**
2+
* @license
3+
* Copyright 2025 Steven Roussey <sroussey@gmail.com>
4+
* SPDX-License-Identifier: Apache-2.0
5+
*/
6+
7+
export * from "./common";
8+
export * from "./task/FileLoaderTask.server";

0 commit comments

Comments (0)