Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,10 @@ export * from './shared/types';
// The factory will handle environment-specific provider selection

export type {
LiteLLMAuthType,
LiteLLMOutputMode,
LiteLLMProviderConfig,
OAuth2Config,
} from './node/litellm';
// Export LiteLLM provider for direct instantiation
export { LiteLLMProvider } from './node/litellm';
Expand Down
362 changes: 362 additions & 0 deletions src/node/litellm.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,368 @@ describe('LiteLLMProvider', () => {
});
});

describe('OAuth2 Authentication', () => {
test('should throw error when authType is oauth2 but no config provided', () => {
expect(
() =>
new LiteLLMProvider({
authType: 'oauth2',
}),
).toThrow('OAuth2 configuration required');
});

test('should accept OAuth2 configuration from constructor', () => {
const provider = new LiteLLMProvider({
authType: 'oauth2',
oauth2: {
tokenUrl: 'https://auth.example.com/token',
clientId: 'test-client',
clientSecret: 'test-secret',
},
});

expect(provider.name).toBe('litellm');
});

test('should load OAuth2 configuration from environment variables', () => {
process.env.HAVE_OCR_LITELLM_AUTH_TYPE = 'oauth2';
process.env.HAVE_OCR_LITELLM_OAUTH2_TOKEN_URL =
'https://auth.example.com/token';
process.env.HAVE_OCR_LITELLM_OAUTH2_CLIENT_ID = 'env-client';
process.env.HAVE_OCR_LITELLM_OAUTH2_CLIENT_SECRET = 'env-secret';
process.env.HAVE_OCR_LITELLM_OAUTH2_SCOPES = 'read,write';

const provider = new LiteLLMProvider();
expect(provider.name).toBe('litellm');
});

test('should prefer constructor OAuth2 config over environment variables', () => {
process.env.HAVE_OCR_LITELLM_AUTH_TYPE = 'oauth2';
process.env.HAVE_OCR_LITELLM_OAUTH2_TOKEN_URL =
'https://env.example.com/token';
process.env.HAVE_OCR_LITELLM_OAUTH2_CLIENT_ID = 'env-client';
process.env.HAVE_OCR_LITELLM_OAUTH2_CLIENT_SECRET = 'env-secret';

// Constructor config should take precedence
const provider = new LiteLLMProvider({
authType: 'oauth2',
oauth2: {
tokenUrl: 'https://constructor.example.com/token',
clientId: 'constructor-client',
clientSecret: 'constructor-secret',
},
});

expect(provider.name).toBe('litellm');
});

test('should report OAuth2 configured in dependency check', async () => {
// Mock fetch for OAuth2 token
const mockFetch = vi.spyOn(globalThis, 'fetch').mockResolvedValueOnce({
ok: true,
json: () =>
Promise.resolve({
// biome-ignore lint/style/useNamingConvention: OAuth2 standard
access_token: 'check-deps-token',
// biome-ignore lint/style/useNamingConvention: OAuth2 standard
expires_in: 3600,
}),
} as Response);

const provider = new LiteLLMProvider({
authType: 'oauth2',
oauth2: {
tokenUrl: 'https://auth.example.com/token',
clientId: 'test-client',
clientSecret: 'test-secret',
},
});

const deps = await provider.checkDependencies();
expect(deps.available).toBe(true);
expect(deps.details.authType).toBe('oauth2');
expect(deps.details.oauth2Configured).toBe(true);

mockFetch.mockRestore();
});

test('should report API key auth type in dependency check', async () => {
const provider = new LiteLLMProvider({
apiKey: 'test-key',
});

const deps = await provider.checkDependencies();
expect(deps.available).toBe(true);
expect(deps.details.authType).toBe('api_key');
expect(deps.details.apiKey).toBe(true);
});

test('should fetch OAuth2 token and use for API calls', async () => {
// Mock fetch for OAuth2 token
const mockFetch = vi.spyOn(globalThis, 'fetch').mockResolvedValueOnce({
ok: true,
json: () =>
Promise.resolve({
// biome-ignore lint/style/useNamingConvention: OAuth2 standard
access_token: 'mock-oauth-token',
// biome-ignore lint/style/useNamingConvention: OAuth2 standard
expires_in: 3600,
// biome-ignore lint/style/useNamingConvention: OAuth2 standard
token_type: 'Bearer',
}),
} as Response);

const { getAI } = await import('@happyvertical/ai');
vi.mocked(getAI).mockResolvedValue({
chat: vi.fn().mockResolvedValue({
content: 'OAuth2 extracted text',
}),
} as any);

const provider = new LiteLLMProvider({
authType: 'oauth2',
oauth2: {
tokenUrl: 'https://auth.example.com/token',
clientId: 'test-client',
clientSecret: 'test-secret',
},
});

const pngBuffer = Buffer.from([
0x89,
0x50,
0x4e,
0x47,
0x0d,
0x0a,
0x1a,
0x0a,
...Array(100).fill(0),
]);

const result = await provider.performOCR([{ data: pngBuffer }]);

expect(result.text).toBe('OAuth2 extracted text');
expect(mockFetch).toHaveBeenCalledWith(
'https://auth.example.com/token',
expect.objectContaining({
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
}),
);

// Verify getAI was called with the OAuth2 token
expect(getAI).toHaveBeenCalledWith(
expect.objectContaining({
apiKey: 'mock-oauth-token',
}),
);

mockFetch.mockRestore();
});

test('should cache OAuth2 token and reuse it', async () => {
const mockFetch = vi.spyOn(globalThis, 'fetch').mockResolvedValue({
ok: true,
json: () =>
Promise.resolve({
// biome-ignore lint/style/useNamingConvention: OAuth2 standard
access_token: 'cached-token',
// biome-ignore lint/style/useNamingConvention: OAuth2 standard
expires_in: 3600,
// biome-ignore lint/style/useNamingConvention: OAuth2 standard
token_type: 'Bearer',
}),
} as Response);

const { getAI } = await import('@happyvertical/ai');
vi.mocked(getAI).mockResolvedValue({
chat: vi.fn().mockResolvedValue({
content: 'Cached token text',
}),
} as any);

const provider = new LiteLLMProvider({
authType: 'oauth2',
oauth2: {
tokenUrl: 'https://auth.example.com/token',
clientId: 'test-client',
clientSecret: 'test-secret',
},
});

const pngBuffer = Buffer.from([
0x89,
0x50,
0x4e,
0x47,
0x0d,
0x0a,
0x1a,
0x0a,
...Array(100).fill(0),
]);

// First call - should fetch token
await provider.performOCR([{ data: pngBuffer }]);
expect(mockFetch).toHaveBeenCalledTimes(1);

// Second call - should use cached token (fetch not called again)
await provider.performOCR([{ data: pngBuffer }]);
// Note: getAI is called again but fetch should not be called
// since the token is cached. However, our implementation
// creates a new AI client each time for OAuth2 to ensure fresh tokens.
// The token fetch itself should be cached.
expect(mockFetch).toHaveBeenCalledTimes(1);

mockFetch.mockRestore();
});

test('should handle OAuth2 token fetch failure', async () => {
const mockFetch = vi.spyOn(globalThis, 'fetch').mockResolvedValueOnce({
ok: false,
status: 401,
statusText: 'Unauthorized',
text: () => Promise.resolve('Invalid credentials'),
} as Response);

const provider = new LiteLLMProvider({
authType: 'oauth2',
oauth2: {
tokenUrl: 'https://auth.example.com/token',
clientId: 'bad-client',
clientSecret: 'bad-secret',
},
});

const pngBuffer = Buffer.from([
0x89,
0x50,
0x4e,
0x47,
0x0d,
0x0a,
0x1a,
0x0a,
...Array(100).fill(0),
]);

await expect(provider.performOCR([{ data: pngBuffer }])).rejects.toThrow(
'OAuth2 token request failed',
);

mockFetch.mockRestore();
});

test('should include scopes in OAuth2 token request', async () => {
const mockFetch = vi.spyOn(globalThis, 'fetch').mockResolvedValueOnce({
ok: true,
json: () =>
Promise.resolve({
// biome-ignore lint/style/useNamingConvention: OAuth2 standard
access_token: 'scoped-token',
// biome-ignore lint/style/useNamingConvention: OAuth2 standard
expires_in: 3600,
}),
} as Response);

const { getAI } = await import('@happyvertical/ai');
vi.mocked(getAI).mockResolvedValue({
chat: vi.fn().mockResolvedValue({
content: 'Scoped text',
}),
} as any);

const provider = new LiteLLMProvider({
authType: 'oauth2',
oauth2: {
tokenUrl: 'https://auth.example.com/token',
clientId: 'test-client',
clientSecret: 'test-secret',
scopes: ['ocr.read', 'ocr.write'],
},
});

const pngBuffer = Buffer.from([
0x89,
0x50,
0x4e,
0x47,
0x0d,
0x0a,
0x1a,
0x0a,
...Array(100).fill(0),
]);

await provider.performOCR([{ data: pngBuffer }]);

// Check that fetch was called with scope parameter
expect(mockFetch).toHaveBeenCalledWith(
'https://auth.example.com/token',
expect.objectContaining({
body: expect.stringContaining('scope=ocr.read+ocr.write'),
}),
);

mockFetch.mockRestore();
});

test('should clear cached token on cleanup', async () => {
const mockFetch = vi.spyOn(globalThis, 'fetch').mockResolvedValue({
ok: true,
json: () =>
Promise.resolve({
// biome-ignore lint/style/useNamingConvention: OAuth2 standard
access_token: 'cleanup-token',
// biome-ignore lint/style/useNamingConvention: OAuth2 standard
expires_in: 3600,
}),
} as Response);

const { getAI } = await import('@happyvertical/ai');
vi.mocked(getAI).mockResolvedValue({
chat: vi.fn().mockResolvedValue({
content: 'Cleanup text',
}),
} as any);

const provider = new LiteLLMProvider({
authType: 'oauth2',
oauth2: {
tokenUrl: 'https://auth.example.com/token',
clientId: 'test-client',
clientSecret: 'test-secret',
},
});

const pngBuffer = Buffer.from([
0x89,
0x50,
0x4e,
0x47,
0x0d,
0x0a,
0x1a,
0x0a,
...Array(100).fill(0),
]);

// First call - fetch token
await provider.performOCR([{ data: pngBuffer }]);
expect(mockFetch).toHaveBeenCalledTimes(1);

// Cleanup should clear cached token
await provider.cleanup();

// Next call should fetch a new token
await provider.performOCR([{ data: pngBuffer }]);
expect(mockFetch).toHaveBeenCalledTimes(2);

mockFetch.mockRestore();
});
});

describe('Image Format Detection', () => {
test('should detect PNG format', async () => {
const provider = new LiteLLMProvider({ apiKey: 'test-key' });
Expand Down
Loading