Skip to content

Commit

Permalink
ability to use custom cacheManager (#109)
Browse files Browse the repository at this point in the history
* ability to use custom cacheManager

* v1.16.0

* correct method modifier
  • Loading branch information
ngxson authored Aug 19, 2024
1 parent 708f95f commit 3b20fed
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 32 deletions.
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@wllama/wllama",
"version": "1.15.0",
"version": "1.16.0",
"description": "Low-level WASM binding for llama.cpp",
"main": "index.js",
"type": "module",
Expand Down
40 changes: 21 additions & 19 deletions src/cache-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@ export interface CacheEntryMetadata {
/**
* Cache implementation using OPFS (Origin private file system)
*/
export const CacheManager = {
class CacheManager {
/**
* Convert a given URL into file name in cache.
*
* Format of the file name: `${hashSHA1(fullURL)}_${fileName}`
*/
async getNameFromURL(url: string): Promise<string> {
return await toFileName(url, '');
},
}

/**
* Write a new file to cache. This will overwrite existing file.
Expand All @@ -59,9 +59,9 @@ export const CacheManager = {
stream: ReadableStream,
metadata: CacheEntryMetadata
): Promise<void> {
CacheManager._writeMetadata(name, metadata); // no need await
this._writeMetadata(name, metadata); // no need await
return await opfsWrite(name, stream);
},
}

/**
* Open a file in cache for reading
Expand All @@ -71,7 +71,7 @@ export const CacheManager = {
*/
async open(name: string): Promise<ReadableStream | null> {
return await opfsOpen(name);
},
}

/**
* Get the size of a file in stored cache
Expand All @@ -83,14 +83,14 @@ export const CacheManager = {
*/
async getSize(name: string): Promise<number> {
return await opfsFileSize(name);
},
}

/**
* Get metadata of a cached file
*/
async getMetadata(name: string): Promise<CacheEntryMetadata | null> {
const stream = await opfsOpen(name, PREFIX_METADATA);
const cachedSize = await CacheManager.getSize(name);
const cachedSize = await this.getSize(name);
if (!stream) {
return cachedSize > 0
? // files created by older version of wllama doesn't have metadata, we will try to polyfill it
Expand All @@ -109,7 +109,7 @@ export const CacheManager = {
// worst case: metadata is somehow corrupted, we will re-download the model
return null;
}
},
}

/**
* List all files currently in cache
Expand Down Expand Up @@ -147,26 +147,26 @@ export const CacheManager = {
}
}
return result;
},
}

/**
* Clear all files currently in cache
*/
async clear(): Promise<void> {
await CacheManager.deleteMany(() => true);
},
await this.deleteMany(() => true);
}

/**
* Delete a single file in cache
*
* @param nameOrURL Can be either an URL or a name returned by `getNameFromURL()` or `list()`
*/
async delete(nameOrURL: string): Promise<void> {
const name2 = await CacheManager.getNameFromURL(nameOrURL);
await CacheManager.deleteMany(
const name2 = await this.getNameFromURL(nameOrURL);
await this.deleteMany(
(entry) => entry.name === nameOrURL || entry.name === name2
);
},
}

/**
* Delete multiple files in cache.
Expand All @@ -175,25 +175,27 @@ export const CacheManager = {
*/
async deleteMany(predicate: (e: CacheEntry) => boolean): Promise<void> {
const cacheDir = await getCacheDir();
const list = await CacheManager.list();
const list = await this.list();
for (const item of list) {
if (predicate(item)) {
cacheDir.removeEntry(item.name);
}
}
},
}

/**
* Internally used
*/
async _writeMetadata(
private async _writeMetadata(
name: string,
metadata: CacheEntryMetadata
): Promise<void> {
const blob = new Blob([JSON.stringify(metadata)], { type: 'text/plain' });
await opfsWrite(name, blob.stream(), PREFIX_METADATA);
},
};
}
}

export default CacheManager;

/**
* Write to OPFS file from ReadableStream
Expand Down
5 changes: 5 additions & 0 deletions src/downloader/multi-downloads.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import CacheManager from '../cache-manager';
import { GGUFRemoteBlob } from './remote-blob';

type ProgressCallback = (opts: { loaded: number; total: number }) => any;
Expand Down Expand Up @@ -26,11 +27,13 @@ export class MultiDownloads {
private totalBytes: number = 0;
private allowOffline: boolean;
private noTEE: boolean;
private cacheManager: CacheManager;

constructor(
logger: any,
urls: string[],
maxParallel: number,
cacheManager: CacheManager,
opts: {
progressCallback?: ProgressCallback;
useCache: boolean;
Expand All @@ -55,6 +58,7 @@ export class MultiDownloads {
this.useCache = opts.useCache;
this.allowOffline = opts.allowOffline;
this.noTEE = !!opts.noTEE;
this.cacheManager = cacheManager;
}

async run(): Promise<Blob[]> {
Expand All @@ -67,6 +71,7 @@ export class MultiDownloads {
startSignal: task.signalStart,
allowOffline: this.allowOffline,
noTEE: this.noTEE,
cacheManager: this.cacheManager,
progressCallback: ({ loaded }) => {
task.loaded = loaded;
this.updateProgress(task);
Expand Down
22 changes: 14 additions & 8 deletions src/downloader/remote-blob.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
// Adapted from https://github.com/huggingface/huggingface.js/blob/main/packages/hub/src/utils/WebBlob.ts

import {
import CacheManager, {
CacheEntryMetadata,
CacheManager,
POLYFILL_ETAG,
} from '../cache-manager';

Expand All @@ -21,6 +20,7 @@ interface GGUFRemoteBlobCreateOptions {
progressCallback?: ProgressCallback;
startSignal?: Promise<void>;
allowOffline: boolean;
cacheManager: CacheManager;
/**
* Should we skip TEE the output stream?
* Set to true if we only download model to cache, without reading it
Expand All @@ -39,6 +39,7 @@ export class GGUFRemoteBlob extends Blob {
url: string,
opts: GGUFRemoteBlobCreateOptions
): Promise<Blob> {
const { cacheManager } = opts;
const customFetch = opts?.fetch ?? fetch;
const cacheKey = url;
let remoteFile: CacheEntryMetadata;
Expand All @@ -54,7 +55,7 @@ export class GGUFRemoteBlob extends Blob {
} catch (err) {
// connection error (i.e. offline)
if (opts.allowOffline) {
const cachedMeta = await CacheManager.getMetadata(cacheKey);
const cachedMeta = await cacheManager.getMetadata(cacheKey);
if (cachedMeta) {
remoteFile = cachedMeta;
} else {
Expand All @@ -67,14 +68,14 @@ export class GGUFRemoteBlob extends Blob {
}
}

const cachedFileSize = await CacheManager.getSize(cacheKey);
const cachedFile = await CacheManager.getMetadata(cacheKey);
const cachedFileSize = await cacheManager.getSize(cacheKey);
const cachedFile = await cacheManager.getMetadata(cacheKey);
const skipCache = opts?.useCache === false;

// migrate from old version: if metadata is polyfilled, we save the new metadata
const metadataPolyfilled = cachedFile?.etag === POLYFILL_ETAG;
if (metadataPolyfilled) {
await CacheManager._writeMetadata(cacheKey, remoteFile);
await cacheManager._writeMetadata(cacheKey, remoteFile);

Check failure on line 78 in src/downloader/remote-blob.ts

View workflow job for this annotation

GitHub Actions / build

Property '_writeMetadata' is private and only accessible within class 'CacheManager'.
}

const cachedFileValid =
Expand All @@ -84,7 +85,7 @@ export class GGUFRemoteBlob extends Blob {
remoteFile.originalSize === cachedFileSize);
if (cachedFileValid && !skipCache) {
opts?.logger?.debug(`Using cached file ${cacheKey}`);
const cachedFile = await CacheManager.open(cacheKey);
const cachedFile = await cacheManager.open(cacheKey);
(opts?.startSignal ?? Promise.resolve()).then(() => {
opts?.progressCallback?.({
loaded: cachedFileSize,
Expand All @@ -102,6 +103,7 @@ export class GGUFRemoteBlob extends Blob {
progressCallback: () => {}, // unused
etag: remoteFile.etag,
noTEE: opts.noTEE,
cacheManager: cacheManager,
}
);
} else {
Expand All @@ -127,11 +129,13 @@ export class GGUFRemoteBlob extends Blob {
startSignal: opts?.startSignal,
etag: remoteFile.etag,
noTEE: opts.noTEE,
cacheManager: cacheManager,
}
);
}
}

private cacheManager: CacheManager;
private url: string;
private etag: string;
private start: number;
Expand All @@ -156,6 +160,7 @@ export class GGUFRemoteBlob extends Blob {
startSignal?: Promise<void>;
etag: string;
noTEE: boolean;
cacheManager: CacheManager;
}
) {
super([]);
Expand All @@ -175,6 +180,7 @@ export class GGUFRemoteBlob extends Blob {
this.startSignal = additionals.startSignal;
this.etag = additionals.etag;
this.noTEE = additionals.noTEE;
this.cacheManager = additionals.cacheManager;
}

override get size(): number {
Expand Down Expand Up @@ -233,7 +239,7 @@ export class GGUFRemoteBlob extends Blob {
.then((response) => {
const [src0, src1] = response.body!.tee();
src0.pipeThrough(stream);
CacheManager.write(this.url, src1, {
this.cacheManager.write(this.url, src1, {
originalSize: this.end,
originalURL: this.url,
etag: this.etag,
Expand Down
14 changes: 10 additions & 4 deletions src/wllama.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import {
maybeSortFileByName,
padDigits,
} from './utils';
import { CacheManager } from './cache-manager';
import CacheManager from './cache-manager';
import { MultiDownloads } from './downloader/multi-downloads';

export interface WllamaConfig {
Expand All @@ -25,6 +25,10 @@ export interface WllamaConfig {
warn: typeof console.warn;
error: typeof console.error;
};
/**
* Custom cache manager (only for advanced usage)
*/
cacheManager?: CacheManager;
}

export interface AssetsPathConfig {
Expand Down Expand Up @@ -154,7 +158,7 @@ export const LoggerWithoutDebug = {

export class Wllama {
// The CacheManager singleton, can be accessed by user
public cacheManager = CacheManager;
public cacheManager: CacheManager;

private proxy: ProxyToWorker = null as any;
private config: WllamaConfig;
Expand All @@ -178,6 +182,7 @@ export class Wllama {
if (!pathConfig) throw new Error('AssetsPathConfig is required');
this.pathConfig = pathConfig;
this.config = wllamaConfig;
this.cacheManager = wllamaConfig.cacheManager ?? new CacheManager();
}

private logger() {
Expand Down Expand Up @@ -364,6 +369,7 @@ export class Wllama {
this.logger(),
this.parseModelUrl(modelUrl),
config.parallelDownloads ?? 3,
this.cacheManager,
{
progressCallback: config.progressCallback,
useCache: true,
Expand Down Expand Up @@ -404,6 +410,7 @@ export class Wllama {
this.logger(),
this.parseModelUrl(modelUrl),
config.parallelDownloads ?? 3,
this.cacheManager,
{
progressCallback: config.progressCallback,
useCache: !skipCache,
Expand Down Expand Up @@ -923,7 +930,6 @@ export class Wllama {
return await this.proxy.wllamaDebug();
}


///// Prompt cache utils /////
private async getCachedToken(): Promise<number[]> {
this.checkModelLoaded();
Expand All @@ -944,7 +950,7 @@ export class Wllama {
}
}
const nDiscard = cachedTokens.length - nKeep;
this.logger().debug(`Cache nKeep=${nKeep} nDiscard=${nDiscard}`)
this.logger().debug(`Cache nKeep=${nKeep} nDiscard=${nDiscard}`);
await this.kvRemove(nKeep, nDiscard);
return seq.slice(nKeep, seq.length);
}
Expand Down

0 comments on commit 3b20fed

Please sign in to comment.