const { BaseMemory } = require('./baseMemoryClass.js')
class LocalMemory extends BaseMemory {
/**
* @param {Function} embeddingProvider - Function to retrieve embeddings
*/
constructor(embeddingProvider) {
super()
/**
* @type {any[]}
*/
this.docs = []
this.embs = null
this.embedding_provider = embeddingProvider
}
/**
* Adds a document to memory along with its key (or uses the document as the key)
* Retrieves the embedding using the embeddingProvider
* Stores the embedding in the embs ndarray and the document in the docs array
* @param {string} doc - Document to add
* @param {string|null} [key] - Key for the document (optional)
*/
async add(doc, key = null) {
if (!key) {
key = doc
}
// @ts-ignore
const emb = await this.embedding_provider.get(key)
if (emb.data[0].embedding) {
const embedding = emb.data[0].embedding
if (this.embs === null) {
this.embs = [embedding]
} else {
this.embs.push(embedding)
}
this.docs.push(doc)
} else {
throw Error('Error getting embedding from provider')
}
}
/**
* This function takes a query and a value k, retrieves embeddings for the query, calculates scores
* for stored embeddings, and returns the top k documents based on the scores.
* @param {string} query - The query parameter is the input query for which the function will retrieve the top
* k most similar documents.
* @param {number} k - `k` is a positive integer representing the number of top results to return. The
* function will return the `k` documents with the highest scores based on their dot product with the
* provided query embedding.
* @returns {Promise<string[]|any[]>} The function `get` returns an array of documents that are most similar to the input
* query, based on their embeddings. The number of documents returned is determined by the value of
* the `k` parameter. If there are no embeddings stored or they are not in the expected format, an
* empty array is returned. If there is an error getting the embedding from the provider, an error is
* thrown.
*/
async get(query, k) {
if (this.embs === null || !Array.isArray(this.embs)) {
return []
}
// @ts-ignore
const emb = await this.embedding_provider.get(query)
if (emb.data[0].embedding && Array.isArray(emb.data[0].embedding)) {
const scores = this.embs.map((storedEmb) => {
if (Array.isArray(storedEmb) && storedEmb.length === 1536) {
return dotProduct(storedEmb, emb.data[0].embedding)
} else {
return 0 // or any default value if the storedEmb is invalid
}
})
const sortedIdxs = scores
.map((score, idx) => [score, idx])
.sort((a, b) => b[0] - a[0])
const uniqueIdxs = new Set()
const topKIdxs = []
for (const [score, idx] of sortedIdxs) {
if (uniqueIdxs.size >= k) {
break
}
if (!uniqueIdxs.has(idx)) {
uniqueIdxs.add(idx)
topKIdxs.push(idx)
}
}
const results = topKIdxs.map((idx) => this.docs[idx])
return results
} else {
throw Error('Error getting embedding from provider')
}
}
/**
* Serializes the embs array to an object representation
* @returns {Object|null} - Serialized embs array or null if embs is empty
*/
_serializeEmbs() {
if (this.embs === null || this.embs.length === 0) {
return null
}
const embSize = this.embs[0].length
const data = this.embs.flat()
return {
dtype: 'float32',
data: Array.from(data),
shape: [this.embs.length, embSize],
}
}
/**
* Returns the configuration of the LocalMemory instance as an object
* @returns {Object} - Configuration object
*/
// config() {
// const cfg = super.config();
// cfg.docs = this.docs;
// cfg.embs = this._serializeEmbs();
// cfg.embeddingProvider = this.embeddingProvider.config();
// return cfg;
// }
/**
* Creates a LocalMemory instance from a configuration object
* @param {Object} config - Configuration object
* @returns {LocalMemory} - LocalMemory instance
*/
// static fromConfig(config) {
// const provider = embeddingProviderFromConfig(config.embeddingProvider);
// const obj = new LocalMemory(provider);
// obj.docs = config.docs;
// const embs = config.embs;
// if (embs !== null) {
// const embSize = embs.shape[1];
// obj.embs = new Array(embs.shape[0]);
// for (let i = 0; i < embs.shape[0]; i++) {
// obj.embs[i] = embs.data.slice(i * embSize, (i + 1) * embSize);
// }
// }
// return obj;
// }
/**
* Clears the memory
*/
clear() {
this.docs = []
this.embs = null
}
}
/**
* Computes the dot product between two arrays
* @param {Array<number>} arr1 - First array
* @param {Array<number>} arr2 - Second array
* @returns {number} - Dot product
*/
function dotProduct(arr1, arr2) {
return arr1.reduce((acc, val, i) => acc + val * arr2[i], 0)
}
module.exports = {
LocalMemory,
}