forked from rectanglehq/Shapeshift
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathindex.ts
160 lines (140 loc) · 5.62 KB
/
index.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
// Library to map arbitrarily structured JSON objects using vector embeddings
import { CohereClient } from 'cohere-ai';
import OpenAI from "openai";
import axios from 'axios';
// Type aliases shared by the rest of this module.

/** One embedding vector: a plain list of numbers. */
type Embedding = Array<number>;

/** Any object addressed by string keys; the values are deliberately unconstrained. */
type ObjectWithStringKeys = Record<string, any>;

/** Optional tuning knobs accepted by the Shapeshift constructor. */
type ShapeshiftOptions = {
  /** Model identifier handed to the embedding provider. */
  embeddingModel?: string;
  /** Minimum cosine similarity a key pair must reach to count as a match. */
  similarityThreshold?: number;
};

/** Names of the embedding providers this module knows how to call. */
type EmbeddingClient =
  | 'cohere'
  | 'openai'
  | 'voyage';
/**
 * Copies values from a source object into the key layout of a target object by
 * comparing vector embeddings of their flattened, dot-separated key paths.
 *
 * Nested plain objects are flattened to paths such as `user.address.city`;
 * arrays and primitives are treated as leaf values. Because `.` is the path
 * separator, a key that itself contains a dot cannot be told apart from a
 * nested path once flattened.
 */
export class Shapeshift {
  private cohere: CohereClient | null = null;
  private openai: OpenAI | null = null;
  private voyageApiKey: string | null = null;
  private embeddingClient: EmbeddingClient;
  private embeddingModel: string;
  private similarityThreshold: number;

  /**
   * @param config.embeddingClient Which provider to use: 'cohere', 'openai' or 'voyage'.
   * @param config.apiKey API key / token for the chosen provider.
   * @param options.embeddingModel Overrides the provider's default model
   *   ('embed-english-v3.0', 'text-embedding-ada-002' or 'voyage-large-2').
   * @param options.similarityThreshold Minimum cosine similarity for a match; defaults to 0.5.
   * @throws Error if `embeddingClient` is not one of the supported providers.
   */
  constructor(
    { embeddingClient, apiKey }: { embeddingClient: EmbeddingClient, apiKey: string },
    options: ShapeshiftOptions = {}
  ) {
    this.embeddingClient = embeddingClient;

    // `||` is intentional for the model name: an empty string is never a valid model.
    switch (embeddingClient) {
      case 'cohere':
        this.cohere = new CohereClient({ token: apiKey });
        this.embeddingModel = options.embeddingModel || 'embed-english-v3.0';
        break;
      case 'openai':
        this.openai = new OpenAI({ apiKey });
        this.embeddingModel = options.embeddingModel || 'text-embedding-ada-002';
        break;
      case 'voyage':
        this.voyageApiKey = apiKey;
        this.embeddingModel = options.embeddingModel || 'voyage-large-2';
        break;
      default:
        throw new Error('Unsupported embedding client');
    }

    // `??` rather than `||` so that an explicit threshold of 0 is honoured.
    this.similarityThreshold = options.similarityThreshold ?? 0.5;
  }

  /**
   * Requests one embedding per input text from the configured provider.
   *
   * @param texts Strings to embed.
   * @returns Embeddings in the same order as `texts`; an empty array for empty input.
   * @throws Error if the client selected in the constructor is not initialised.
   */
  private async calculateEmbeddings(texts: string[]): Promise<Embedding[]> {
    // Nothing to embed: avoid sending an empty batch to the provider.
    if (texts.length === 0) {
      return [];
    }

    if (this.embeddingClient === 'cohere' && this.cohere) {
      const response = await this.cohere.embed({
        texts: texts,
        model: this.embeddingModel,
        inputType: 'classification',
      });
      return response.embeddings as Embedding[];
    }

    if (this.embeddingClient === 'openai' && this.openai) {
      // One batched request instead of one request per text.
      const response = await this.openai.embeddings.create({
        model: this.embeddingModel,
        input: texts,
        encoding_format: "float",
      });
      // Each item carries the index of the input it belongs to; sort on it so
      // the output order is guaranteed to match `texts`.
      return [...response.data]
        .sort((a: { index: number }, b: { index: number }) => a.index - b.index)
        .map((item: { embedding: Embedding }) => item.embedding);
    }

    if (this.embeddingClient === 'voyage' && this.voyageApiKey) {
      const response = await axios.post(
        'https://api.voyageai.com/v1/embeddings',
        {
          input: texts,
          model: this.embeddingModel,
        },
        {
          headers: {
            'Authorization': `Bearer ${this.voyageApiKey}`,
            'Content-Type': 'application/json',
          },
        }
      );
      return response.data.data.map((item: { embedding: Embedding }) => item.embedding);
    }

    throw new Error('Embedding client not properly initialized');
  }

  /**
   * Cosine similarity of two vectors.
   *
   * @returns dot(a, b) / (|a| * |b|), or 0 when either vector has zero
   *   magnitude (where the quotient would otherwise be NaN).
   */
  private cosineSimilarity(vecA: Embedding, vecB: Embedding): number {
    let dotProduct = 0;
    let squaredNormA = 0;
    for (let i = 0; i < vecA.length; i++) {
      dotProduct += vecA[i] * vecB[i];
      squaredNormA += vecA[i] * vecA[i];
    }
    let squaredNormB = 0;
    for (let i = 0; i < vecB.length; i++) {
      squaredNormB += vecB[i] * vecB[i];
    }

    const denominator = Math.sqrt(squaredNormA) * Math.sqrt(squaredNormB);
    return denominator === 0 ? 0 : dotProduct / denominator;
  }

  /**
   * Finds the target embedding most similar to `sourceEmbedding`.
   *
   * @returns The winning index together with its similarity, or null when there
   *   are no targets or the best similarity is below `similarityThreshold`.
   */
  private findBestMatch(
    sourceEmbedding: Embedding,
    targetEmbeddings: Embedding[]
  ): { index: number; similarity: number } | null {
    let bestSimilarity = -Infinity;
    let bestIndex = -1;
    for (let i = 0; i < targetEmbeddings.length; i++) {
      const similarity = this.cosineSimilarity(sourceEmbedding, targetEmbeddings[i]);
      // Strict `>` keeps the first of several equally similar targets.
      if (similarity > bestSimilarity) {
        bestSimilarity = similarity;
        bestIndex = i;
      }
    }
    if (bestIndex === -1 || bestSimilarity < this.similarityThreshold) {
      return null;
    }
    return { index: bestIndex, similarity: bestSimilarity };
  }

  /**
   * Index of the target embedding closest to `sourceEmbedding`, or null when
   * no target reaches `similarityThreshold`.
   */
  private findClosestMatch(sourceEmbedding: Embedding, targetEmbeddings: Embedding[]): number | null {
    const match = this.findBestMatch(sourceEmbedding, targetEmbeddings);
    return match === null ? null : match.index;
  }

  /**
   * Flattens nested objects into a single level keyed by dot-separated paths.
   * Non-null, non-array objects are recursed into; everything else (including
   * arrays and null) is stored as a leaf. An object with no enumerable own
   * keys therefore contributes no entries.
   *
   * @param obj Object to flatten.
   * @param prefix Path accumulated so far (used by the recursion).
   */
  private flattenObject(obj: ObjectWithStringKeys, prefix = ''): ObjectWithStringKeys {
    const pathPrefix = prefix.length ? prefix + '.' : '';
    const flat: ObjectWithStringKeys = {};
    for (const [key, value] of Object.entries(obj)) {
      const path = pathPrefix + key;
      if (typeof value === 'object' && value !== null && !Array.isArray(value)) {
        Object.assign(flat, this.flattenObject(value, path));
      } else {
        flat[path] = value;
      }
    }
    return flat;
  }

  /**
   * Rebuilds a nested object from dot-separated paths.
   *
   * Only own properties are followed while walking a path, so segments such as
   * `__proto__` or `constructor` become ordinary own keys of the result rather
   * than reaching into `Object.prototype`. If an intermediate segment already
   * holds a primitive, it is replaced by a fresh object so the deeper path can
   * be stored.
   */
  private unflattenObject(obj: ObjectWithStringKeys): ObjectWithStringKeys {
    const hasOwn = (target: object, key: string): boolean =>
      Object.prototype.hasOwnProperty.call(target, key);

    // Plain assignment to `__proto__` would invoke the inherited setter, so
    // that one key is defined explicitly as an own data property.
    const setOwn = (target: ObjectWithStringKeys, key: string, value: unknown): void => {
      if (key === '__proto__') {
        Object.defineProperty(target, key, {
          value,
          writable: true,
          enumerable: true,
          configurable: true,
        });
      } else {
        target[key] = value;
      }
    };

    const result: ObjectWithStringKeys = {};
    for (const [path, value] of Object.entries(obj)) {
      const segments = path.split('.');
      let cursor: ObjectWithStringKeys = result;

      // Walk (and create as needed) every container above the leaf.
      for (let depth = 0; depth < segments.length - 1; depth++) {
        const segment = segments[depth];
        const existing: unknown = hasOwn(cursor, segment) ? cursor[segment] : undefined;
        if (typeof existing === 'object' && existing !== null) {
          cursor = existing as ObjectWithStringKeys;
        } else {
          const child: ObjectWithStringKeys = {};
          setOwn(cursor, segment, child);
          cursor = child;
        }
      }

      setOwn(cursor, segments[segments.length - 1], value);
    }
    return result;
  }

  /**
   * Builds an object shaped like `targetObj` and filled with values from `sourceObj`.
   *
   * Both objects are flattened to key paths, the paths are embedded, and each
   * source path is assigned to its most similar target path provided the
   * similarity reaches the configured threshold. Unmatched source values are
   * dropped, and target paths that nothing matched are absent from the result.
   * When several source paths resolve to the same target path, the one with the
   * highest similarity wins; on an exact tie the later source path wins.
   *
   * @param sourceObj Object providing the values.
   * @param targetObj Object providing the desired key structure (its values are ignored).
   * @returns A new nested object keyed by the matched target paths.
   * @throws Error if the provider returns a different number of embeddings than keys sent.
   */
  async shapeshift<T extends ObjectWithStringKeys, U extends ObjectWithStringKeys>(
    sourceObj: T,
    targetObj: U
  ): Promise<U> {
    const flattenedSourceObj = this.flattenObject(sourceObj);
    const sourceKeys = Object.keys(flattenedSourceObj);
    const targetKeys = Object.keys(this.flattenObject(targetObj));

    // With nothing to map from or to, skip the provider round-trips entirely.
    if (sourceKeys.length === 0 || targetKeys.length === 0) {
      return {} as U;
    }

    // The two batches are independent, so request them concurrently.
    const [sourceEmbeddings, targetEmbeddings] = await Promise.all([
      this.calculateEmbeddings(sourceKeys),
      this.calculateEmbeddings(targetKeys),
    ]);
    if (
      sourceEmbeddings.length !== sourceKeys.length ||
      targetEmbeddings.length !== targetKeys.length
    ) {
      throw new Error('Embedding provider returned an unexpected number of embeddings');
    }

    const flattenedResult: ObjectWithStringKeys = {};
    const bestSimilarityByTarget = new Map<string, number>();
    for (let i = 0; i < sourceKeys.length; i++) {
      const match = this.findBestMatch(sourceEmbeddings[i], targetEmbeddings);
      if (match === null) {
        continue;
      }
      const targetKey = targetKeys[match.index];
      const previous = bestSimilarityByTarget.get(targetKey);
      // `>=` so that equally good candidates keep the historical last-one-wins order.
      if (previous === undefined || match.similarity >= previous) {
        bestSimilarityByTarget.set(targetKey, match.similarity);
        flattenedResult[targetKey] = flattenedSourceObj[sourceKeys[i]];
      }
    }

    return this.unflattenObject(flattenedResult) as U;
  }
}