set up more reasonable context window and dimension sizes

Commit 697d348286, parent 572a03a3f7
7 changed files with 787 additions and 174 deletions
@@ -26,4 +26,5 @@ INSERT INTO options (name, value, isSynced, utcDateModified) VALUES ('aiSystemPr
 INSERT INTO options (name, value, isSynced, utcDateModified) VALUES ('embeddingsDefaultProvider', 'openai', 1, strftime('%Y-%m-%dT%H:%M:%fZ', 'now'));
 INSERT INTO options (name, value, isSynced, utcDateModified) VALUES ('enableAutomaticIndexing', 'true', 1, strftime('%Y-%m-%dT%H:%M:%fZ', 'now'));
 INSERT INTO options (name, value, isSynced, utcDateModified) VALUES ('embeddingSimilarityThreshold', '0.65', 1, strftime('%Y-%m-%dT%H:%M:%fZ', 'now'));
 INSERT INTO options (name, value, isSynced, utcDateModified) VALUES ('maxNotesPerLlmQuery', '10', 1, strftime('%Y-%m-%dT%H:%M:%fZ', 'now'));
+INSERT INTO options (name, value, isSynced, utcDateModified) VALUES ('embeddingBatchSize', '10', 1, strftime('%Y-%m-%dT%H:%M:%fZ', 'now'));

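The new embeddingBatchSize row seeds a user-tunable default of 10. As an illustration (this sketch is not part of the diff), reading it back through the options service mirrors the getBatchSize() logic added to the base provider later in this commit; the helper name and fallback handling are assumptions:

import options from "../../options.js"; // path as used inside the embeddings services

// Hypothetical helper: resolve the effective embedding batch size, falling back to 10
// (the value seeded by the migration above) when the option is missing or invalid.
async function resolveEmbeddingBatchSize(fallback = 10): Promise<number> {
    const raw = await options.getOption('embeddingBatchSize');
    const parsed = raw ? parseInt(raw, 10) : NaN;
    return Number.isFinite(parsed) && parsed > 0 ? parsed : fallback;
}
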
@@ -41,6 +41,37 @@ export const LLM_CONSTANTS = {
         }
     },
 
+    // Model-specific embedding dimensions for Ollama models
+    OLLAMA_MODEL_DIMENSIONS: {
+        "llama3": 4096,
+        "llama3.1": 4096,
+        "mistral": 4096,
+        "nomic": 768,
+        "mxbai": 1024,
+        "nomic-embed-text": 768,
+        "mxbai-embed-large": 1024,
+        "default": 384
+    },
+
+    // Model-specific context windows for Ollama models
+    OLLAMA_MODEL_CONTEXT_WINDOWS: {
+        "llama3": 8192,
+        "mistral": 8192,
+        "nomic": 32768,
+        "mxbai": 32768,
+        "nomic-embed-text": 32768,
+        "mxbai-embed-large": 32768,
+        "default": 4096
+    },
+
+    // Batch size configuration
+    BATCH_SIZE: {
+        OPENAI: 10,     // OpenAI can handle larger batches efficiently
+        ANTHROPIC: 5,   // More conservative for Anthropic
+        OLLAMA: 1,      // Ollama processes one at a time
+        DEFAULT: 5      // Conservative default
+    },
+
     // Chunking parameters
     CHUNKING: {
         DEFAULT_SIZE: 1500,

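The three new tables are keyed by base model name, so Ollama tags such as "llama3:8b" have to be reduced to "llama3" before lookup (the providers below do this with modelName.split(':')[0]). An illustrative sketch, not part of the diff; the import path is assumed:

import { LLM_CONSTANTS } from "./routes/api/llm.js"; // illustrative path

// Resolve defaults for an Ollama model tag, e.g. "llama3:8b" -> { dimension: 4096, contextWindow: 8192 }.
function ollamaDefaults(modelName: string) {
    const base = modelName.split(':')[0];
    const dims = LLM_CONSTANTS.OLLAMA_MODEL_DIMENSIONS as Record<string, number>;
    const ctxs = LLM_CONSTANTS.OLLAMA_MODEL_CONTEXT_WINDOWS as Record<string, number>;
    return {
        dimension: dims[base] ?? dims.default,
        contextWindow: ctxs[base] ?? ctxs.default
    };
}
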
@@ -1,22 +1,212 @@
 import type { EmbeddingProvider, EmbeddingConfig, NoteEmbeddingContext } from './embeddings_interface.js';
+import log from "../../log.js";
+import { LLM_CONSTANTS } from "../../../routes/api/llm.js";
+import options from "../../options.js";
 
 /**
  * Base class that implements common functionality for embedding providers
  */
 export abstract class BaseEmbeddingProvider implements EmbeddingProvider {
-    abstract name: string;
+    name: string = "base";
     protected config: EmbeddingConfig;
+    protected apiKey?: string;
+    protected baseUrl: string;
+    protected modelInfoCache = new Map<string, any>();
 
     constructor(config: EmbeddingConfig) {
         this.config = config;
+        this.apiKey = config.apiKey;
+        this.baseUrl = config.baseUrl || "";
     }
 
     getConfig(): EmbeddingConfig {
-        return this.config;
+        return { ...this.config };
     }
 
+    getDimension(): number {
+        return this.config.dimension;
+    }
+
+    async initialize(): Promise<void> {
+        // Default implementation does nothing
+        return;
+    }
+
+    /**
+     * Generate embeddings for a single text
+     */
     abstract generateEmbeddings(text: string): Promise<Float32Array>;
-    abstract generateBatchEmbeddings(texts: string[]): Promise<Float32Array[]>;
+
+    /**
+     * Get the appropriate batch size for this provider
+     * Override in provider implementations if needed
+     */
+    protected async getBatchSize(): Promise<number> {
+        // Try to get the user-configured batch size
+        let configuredBatchSize: number | null = null;
+
+        try {
+            const batchSizeStr = await options.getOption('embeddingBatchSize');
+            if (batchSizeStr) {
+                configuredBatchSize = parseInt(batchSizeStr, 10);
+            }
+        } catch (error) {
+            log.error(`Error getting batch size from options: ${error}`);
+        }
+
+        // If user has configured a specific batch size, use that
+        if (configuredBatchSize && !isNaN(configuredBatchSize) && configuredBatchSize > 0) {
+            return configuredBatchSize;
+        }
+
+        // Otherwise use the provider-specific default from constants
+        return this.config.batchSize ||
+               LLM_CONSTANTS.BATCH_SIZE[this.name.toUpperCase() as keyof typeof LLM_CONSTANTS.BATCH_SIZE] ||
+               LLM_CONSTANTS.BATCH_SIZE.DEFAULT;
+    }
+
+    /**
+     * Process a batch of texts with adaptive handling
+     * This method will try to process the batch and reduce batch size if encountering errors
+     */
+    protected async processWithAdaptiveBatch<T>(
+        items: T[],
+        processFn: (batch: T[]) => Promise<any[]>,
+        isBatchSizeError: (error: any) => boolean
+    ): Promise<any[]> {
+        const results: any[] = [];
+        const failures: { index: number, error: string }[] = [];
+        let currentBatchSize = await this.getBatchSize();
+        let lastError: Error | null = null;
+
+        // Process items in batches
+        for (let i = 0; i < items.length;) {
+            const batch = items.slice(i, i + currentBatchSize);
+
+            try {
+                // Process the current batch
+                const batchResults = await processFn(batch);
+                results.push(...batchResults);
+                i += batch.length;
+            }
+            catch (error: any) {
+                lastError = error;
+                const errorMessage = error.message || 'Unknown error';
+
+                // Check if this is a batch size related error
+                if (isBatchSizeError(error) && currentBatchSize > 1) {
+                    // Reduce batch size and retry
+                    const newBatchSize = Math.max(1, Math.floor(currentBatchSize / 2));
+                    console.warn(`Batch size error detected, reducing batch size from ${currentBatchSize} to ${newBatchSize}: ${errorMessage}`);
+                    currentBatchSize = newBatchSize;
+                }
+                else if (currentBatchSize === 1) {
+                    // If we're already at batch size 1, we can't reduce further, so log the error and skip this item
+                    log.error(`Error processing item at index ${i} with batch size 1: ${errorMessage}`);
+                    failures.push({ index: i, error: errorMessage });
+                    i++; // Move to the next item
+                }
+                else {
+                    // For other errors, retry with a smaller batch size as a precaution
+                    const newBatchSize = Math.max(1, Math.floor(currentBatchSize / 2));
+                    console.warn(`Error processing batch, reducing batch size from ${currentBatchSize} to ${newBatchSize} as a precaution: ${errorMessage}`);
+                    currentBatchSize = newBatchSize;
+                }
+            }
+        }
+
+        // If all items failed and we have a last error, throw it
+        if (results.length === 0 && failures.length > 0 && lastError) {
+            throw lastError;
+        }
+
+        // If some items failed but others succeeded, log the summary
+        if (failures.length > 0) {
+            console.warn(`Processed ${results.length} items successfully, but ${failures.length} items failed`);
+        }
+
+        return results;
+    }
+
+    /**
+     * Detect if an error is related to batch size limits
+     * Override in provider-specific implementations
+     */
+    protected isBatchSizeError(error: any): boolean {
+        const errorMessage = error?.message || '';
+        const batchSizeErrorPatterns = [
+            'batch size', 'too many items', 'too many inputs',
+            'input too large', 'payload too large', 'context length',
+            'token limit', 'rate limit', 'request too large'
+        ];
+
+        return batchSizeErrorPatterns.some(pattern =>
+            errorMessage.toLowerCase().includes(pattern.toLowerCase())
+        );
+    }
+
+    /**
+     * Generate embeddings for multiple texts
+     * Default implementation processes texts one by one
+     */
+    async generateBatchEmbeddings(texts: string[]): Promise<Float32Array[]> {
+        if (texts.length === 0) {
+            return [];
+        }
+
+        try {
+            return await this.processWithAdaptiveBatch(
+                texts,
+                async (batch) => {
+                    const batchResults = await Promise.all(
+                        batch.map(text => this.generateEmbeddings(text))
+                    );
+                    return batchResults;
+                },
+                this.isBatchSizeError
+            );
+        }
+        catch (error: any) {
+            const errorMessage = error.message || "Unknown error";
+            log.error(`Batch embedding error for provider ${this.name}: ${errorMessage}`);
+            throw new Error(`${this.name} batch embedding error: ${errorMessage}`);
+        }
+    }
+
+    /**
+     * Generate embeddings for a note with its context
+     */
+    async generateNoteEmbeddings(context: NoteEmbeddingContext): Promise<Float32Array> {
+        const text = [context.title || "", context.content || ""].filter(Boolean).join(" ");
+        return this.generateEmbeddings(text);
+    }
+
+    /**
+     * Generate embeddings for multiple notes with their contexts
+     */
+    async generateBatchNoteEmbeddings(contexts: NoteEmbeddingContext[]): Promise<Float32Array[]> {
+        if (contexts.length === 0) {
+            return [];
+        }
+
+        try {
+            return await this.processWithAdaptiveBatch(
+                contexts,
+                async (batch) => {
+                    const batchResults = await Promise.all(
+                        batch.map(context => this.generateNoteEmbeddings(context))
+                    );
+                    return batchResults;
+                },
+                this.isBatchSizeError
+            );
+        }
+        catch (error: any) {
+            const errorMessage = error.message || "Unknown error";
+            log.error(`Batch note embedding error for provider ${this.name}: ${errorMessage}`);
+            throw new Error(`${this.name} batch note embedding error: ${errorMessage}`);
+        }
+    }
 
     /**
      * Cleans and normalizes text for embeddings by removing excessive whitespace
@@ -157,20 +347,4 @@ export abstract class BaseEmbeddingProvider implements EmbeddingProvider {
 
         return result;
     }
-
-    /**
-     * Default implementation that converts note context to text and generates embeddings
-     */
-    async generateNoteEmbeddings(context: NoteEmbeddingContext): Promise<Float32Array> {
-        const text = this.generateNoteContextText(context);
-        return this.generateEmbeddings(text);
-    }
-
-    /**
-     * Default implementation that processes notes in batch
-     */
-    async generateBatchNoteEmbeddings(contexts: NoteEmbeddingContext[]): Promise<Float32Array[]> {
-        const texts = contexts.map(ctx => this.generateNoteContextText(ctx));
-        return this.generateBatchEmbeddings(texts);
-    }
 }

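processWithAdaptiveBatch() halves the batch size whenever the error matcher fires and only gives up on an item once it has failed at batch size 1. The following sketch is illustrative and not part of the diff: FakeProvider and its fixed dimension exist only to show how a subclass inherits this behaviour through generateBatchEmbeddings():

import { BaseEmbeddingProvider } from "./base_embeddings.js";          // illustrative path
import type { EmbeddingConfig } from "./embeddings_interface.js";      // illustrative path

// Hypothetical provider: anything over 1000 characters pretends to blow the context window.
class FakeProvider extends BaseEmbeddingProvider {
    name = "fake";
    async generateEmbeddings(text: string): Promise<Float32Array> {
        if (text.length > 1000) {
            throw new Error("context length exceeded"); // matched by the base isBatchSizeError()
        }
        return new Float32Array(4); // made-up dimension
    }
}

// Fields beyond model/dimension are omitted; the real EmbeddingConfig may require more.
const provider = new FakeProvider({ model: "fake", dimension: 4 } as EmbeddingConfig);

// The two-item batch fails at sizes 5 and 2, the short text succeeds at size 1,
// and the oversized text is skipped after failing alone, so one vector comes back.
const vectors = await provider.generateBatchEmbeddings(["short text", "x".repeat(2000)]);
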
@@ -36,6 +36,14 @@ export interface NoteEmbeddingContext {
     templateTitles?: string[];
 }
 
+/**
+ * Information about an embedding model's capabilities
+ */
+export interface EmbeddingModelInfo {
+    dimension: number;
+    contextWindow: number;
+}
+
 /**
  * Configuration for how embeddings should be generated
  */
@@ -46,6 +54,8 @@ export interface EmbeddingConfig {
     normalize?: boolean;
     batchSize?: number;
     contextWindowSize?: number;
+    apiKey?: string;
+    baseUrl?: string;
 }
 
 /**

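With apiKey and baseUrl folded into EmbeddingConfig, the provider-specific config interfaces removed below are no longer needed. An illustrative value, not part of the diff; fields other than those visible in these hunks may also be required by the full interface:

import type { EmbeddingConfig, EmbeddingModelInfo } from "./embeddings_interface.js"; // illustrative path

const ollamaConfig = {
    model: "nomic-embed-text",
    dimension: 768,                      // placeholder; overwritten once the provider detects the real value
    baseUrl: "http://localhost:11434",   // assumed local Ollama endpoint
    batchSize: 1
} as EmbeddingConfig;

// Shape returned by the providers' getModelInfo() probes:
const info: EmbeddingModelInfo = { dimension: 768, contextWindow: 32768 };
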
@@ -1,25 +1,117 @@
-import { BaseEmbeddingProvider } from "../base_embeddings.js";
-import type { EmbeddingConfig } from "../embeddings_interface.js";
 import axios from "axios";
 import log from "../../../log.js";
+import { BaseEmbeddingProvider } from "../base_embeddings.js";
+import type { EmbeddingConfig, EmbeddingModelInfo } from "../embeddings_interface.js";
+import { LLM_CONSTANTS } from "../../../../routes/api/llm.js";
 
-interface AnthropicEmbeddingConfig extends EmbeddingConfig {
-    apiKey: string;
-    baseUrl: string;
-}
+// Anthropic model context window sizes - as of current API version
+const ANTHROPIC_MODEL_CONTEXT_WINDOWS: Record<string, number> = {
+    "claude-3-opus-20240229": 200000,
+    "claude-3-sonnet-20240229": 180000,
+    "claude-3-haiku-20240307": 48000,
+    "claude-2.1": 200000,
+    "claude-2.0": 100000,
+    "claude-instant-1.2": 100000,
+    "default": 100000
+};
 
 /**
- * Anthropic (Claude) embedding provider implementation
+ * Anthropic embedding provider implementation
  */
 export class AnthropicEmbeddingProvider extends BaseEmbeddingProvider {
     name = "anthropic";
-    private apiKey: string;
-    private baseUrl: string;
 
-    constructor(config: AnthropicEmbeddingConfig) {
+    constructor(config: EmbeddingConfig) {
         super(config);
-        this.apiKey = config.apiKey;
-        this.baseUrl = config.baseUrl;
+    }
+
+    /**
+     * Initialize the provider by detecting model capabilities
+     */
+    async initialize(): Promise<void> {
+        const modelName = this.config.model || "claude-3-haiku-20240307";
+        try {
+            // Detect model capabilities
+            const modelInfo = await this.getModelInfo(modelName);
+
+            // Update the config dimension
+            this.config.dimension = modelInfo.dimension;
+
+            log.info(`Anthropic model ${modelName} initialized with dimension ${this.config.dimension} and context window ${modelInfo.contextWindow}`);
+        } catch (error: any) {
+            log.error(`Error initializing Anthropic provider: ${error.message}`);
+        }
+    }
+
+    /**
+     * Try to determine Anthropic model capabilities
+     * Note: Anthropic doesn't have a public endpoint for model metadata, so we use a combination
+     * of known values and detection by test embeddings
+     */
+    private async fetchModelCapabilities(modelName: string): Promise<EmbeddingModelInfo | null> {
+        // Anthropic doesn't have a model info endpoint, but we can look up known context sizes
+        // and detect embedding dimensions by making a test request
+
+        try {
+            // Get context window size from our local registry of known models
+            const modelBase = Object.keys(ANTHROPIC_MODEL_CONTEXT_WINDOWS).find(
+                model => modelName.startsWith(model)
+            ) || "default";
+
+            const contextWindow = ANTHROPIC_MODEL_CONTEXT_WINDOWS[modelBase];
+
+            // For embedding dimension, we'll return null and let getModelInfo detect it
+            return {
+                dimension: 0, // Will be detected by test embedding
+                contextWindow
+            };
+        } catch (error) {
+            log.info(`Could not determine capabilities for Anthropic model ${modelName}: ${error}`);
+            return null;
+        }
+    }
+
+    /**
+     * Get model information including embedding dimensions
+     */
+    async getModelInfo(modelName: string): Promise<EmbeddingModelInfo> {
+        // Check cache first
+        if (this.modelInfoCache.has(modelName)) {
+            return this.modelInfoCache.get(modelName);
+        }
+
+        // Try to determine model capabilities
+        const capabilities = await this.fetchModelCapabilities(modelName);
+        const contextWindow = capabilities?.contextWindow || LLM_CONSTANTS.CONTEXT_WINDOW.ANTHROPIC;
+
+        // For Anthropic, we need to detect embedding dimension with a test call
+        try {
+            // Detect dimension with a test embedding
+            const testEmbedding = await this.generateEmbeddings("Test");
+            const dimension = testEmbedding.length;
+
+            const modelInfo: EmbeddingModelInfo = {
+                dimension,
+                contextWindow
+            };
+
+            this.modelInfoCache.set(modelName, modelInfo);
+            this.config.dimension = dimension;
+
+            log.info(`Detected Anthropic model ${modelName} with dimension ${dimension} (context: ${contextWindow})`);
+            return modelInfo;
+        } catch (error: any) {
+            // If detection fails, use defaults
+            const dimension = LLM_CONSTANTS.EMBEDDING_DIMENSIONS.ANTHROPIC.DEFAULT;
+
+            log.info(`Using default parameters for Anthropic model ${modelName}: dimension ${dimension}, context ${contextWindow}`);
+
+            const modelInfo: EmbeddingModelInfo = { dimension, contextWindow };
+            this.modelInfoCache.set(modelName, modelInfo);
+            this.config.dimension = dimension;
+
+            return modelInfo;
+        }
     }
 
     /**
@@ -27,11 +119,23 @@ export class AnthropicEmbeddingProvider extends BaseEmbeddingProvider {
      */
     async generateEmbeddings(text: string): Promise<Float32Array> {
         try {
+            if (!text.trim()) {
+                return new Float32Array(this.config.dimension);
+            }
+
+            // Get model info to check context window
+            const modelName = this.config.model || "claude-3-haiku-20240307";
+            const modelInfo = await this.getModelInfo(modelName);
+
+            // Trim text if it might exceed context window (rough character estimate)
+            const charLimit = modelInfo.contextWindow * 4; // Rough estimate: avg 4 chars per token
+            const trimmedText = text.length > charLimit ? text.substring(0, charLimit) : text;
+
             const response = await axios.post(
                 `${this.baseUrl}/embeddings`,
                 {
-                    model: this.config.model || "claude-3-haiku-20240307",
-                    input: text,
+                    model: modelName,
+                    input: trimmedText,
                     encoding_format: "float"
                 },
                 {
@@ -44,8 +148,7 @@ export class AnthropicEmbeddingProvider extends BaseEmbeddingProvider {
             );
 
             if (response.data && response.data.embedding) {
-                const embedding = response.data.embedding;
-                return new Float32Array(embedding);
+                return new Float32Array(response.data.embedding);
             } else {
                 throw new Error("Unexpected response structure from Anthropic API");
             }
@@ -56,23 +159,60 @@ export class AnthropicEmbeddingProvider extends BaseEmbeddingProvider {
         }
     }
 
+    /**
+     * More specific implementation of batch size error detection for Anthropic
+     */
+    protected isBatchSizeError(error: any): boolean {
+        const errorMessage = error?.message || error?.response?.data?.error?.message || '';
+        const anthropicBatchSizeErrorPatterns = [
+            'batch size', 'too many inputs', 'context length exceeded',
+            'token limit', 'rate limit', 'limit exceeded',
+            'too long', 'request too large', 'content too large'
+        ];
+
+        return anthropicBatchSizeErrorPatterns.some(pattern =>
+            errorMessage.toLowerCase().includes(pattern.toLowerCase())
+        );
+    }
+
     /**
      * Generate embeddings for multiple texts in a single batch
      *
      * Note: Anthropic doesn't currently support batch embedding, so we process each text individually
+     * but using the adaptive batch processor to handle errors and retries
      */
     async generateBatchEmbeddings(texts: string[]): Promise<Float32Array[]> {
         if (texts.length === 0) {
             return [];
         }
 
-        const results: Float32Array[] = [];
-
-        for (const text of texts) {
-            const embedding = await this.generateEmbeddings(text);
-            results.push(embedding);
-        }
-
-        return results;
+        try {
+            return await this.processWithAdaptiveBatch(
+                texts,
+                async (batch) => {
+                    const results: Float32Array[] = [];
+
+                    // For Anthropic, we have to process one at a time
+                    for (const text of batch) {
+                        // Skip empty texts
+                        if (!text.trim()) {
+                            results.push(new Float32Array(this.config.dimension));
+                            continue;
+                        }
+
+                        const embedding = await this.generateEmbeddings(text);
+                        results.push(embedding);
+                    }
+
+                    return results;
+                },
+                this.isBatchSizeError
+            );
+        }
+        catch (error: any) {
+            const errorMessage = error.message || "Unknown error";
+            log.error(`Anthropic batch embedding error: ${errorMessage}`);
+            throw new Error(`Anthropic batch embedding error: ${errorMessage}`);
        }
     }
 }

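How this provider is meant to be wired follows from its constructor and initialize(). An illustrative sketch, not part of the diff; the key handling, base URL, and the assumption that the configured endpoint serves the /embeddings route used by the request code above are placeholders:

import { AnthropicEmbeddingProvider } from "./providers/anthropic.js"; // illustrative path
import type { EmbeddingConfig } from "./embeddings_interface.js";      // illustrative path

const provider = new AnthropicEmbeddingProvider({
    model: "claude-3-haiku-20240307",
    dimension: 1024,                          // placeholder; initialize() replaces it with the detected value
    apiKey: process.env.ANTHROPIC_API_KEY,    // assumed to be supplied via the environment
    baseUrl: "https://api.anthropic.com/v1"   // assumed base URL; the class only appends /embeddings
} as EmbeddingConfig);

await provider.initialize();                   // resolves the context window and caches EmbeddingModelInfo
const vector = await provider.generateEmbeddings("Hello Trilium");
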
@@ -1,30 +1,17 @@
-import { BaseEmbeddingProvider } from "../base_embeddings.js";
-import type { EmbeddingConfig } from "../embeddings_interface.js";
 import axios from "axios";
 import log from "../../../log.js";
-
-interface OllamaEmbeddingConfig extends EmbeddingConfig {
-    baseUrl: string;
-}
-
-// Model-specific embedding dimensions
-interface EmbeddingModelInfo {
-    dimension: number;
-    contextWindow: number;
-}
+import { BaseEmbeddingProvider } from "../base_embeddings.js";
+import type { EmbeddingConfig, EmbeddingModelInfo } from "../embeddings_interface.js";
+import { LLM_CONSTANTS } from "../../../../routes/api/llm.js";
 
 /**
  * Ollama embedding provider implementation
  */
 export class OllamaEmbeddingProvider extends BaseEmbeddingProvider {
     name = "ollama";
-    private baseUrl: string;
-    // Cache for model dimensions to avoid repeated API calls
-    private modelInfoCache = new Map<string, EmbeddingModelInfo>();
 
-    constructor(config: OllamaEmbeddingConfig) {
+    constructor(config: EmbeddingConfig) {
         super(config);
-        this.baseUrl = config.baseUrl;
     }
 
     /**
@@ -33,97 +20,148 @@ export class OllamaEmbeddingProvider extends BaseEmbeddingProvider {
     async initialize(): Promise<void> {
         const modelName = this.config.model || "llama3";
         try {
-            await this.getModelInfo(modelName);
-            log.info(`Ollama embedding provider initialized with model ${modelName}`);
+            // Detect model capabilities
+            const modelInfo = await this.getModelInfo(modelName);
+
+            // Update the config dimension
+            this.config.dimension = modelInfo.dimension;
+
+            log.info(`Ollama model ${modelName} initialized with dimension ${this.config.dimension} and context window ${modelInfo.contextWindow}`);
         } catch (error: any) {
-            log.error(`Failed to initialize Ollama embedding provider: ${error.message}`);
-            // Still continue with default dimensions
+            log.error(`Error initializing Ollama provider: ${error.message}`);
         }
     }
 
     /**
-     * Get model information including embedding dimensions
+     * Fetch detailed model information from Ollama API
+     * @param modelName The name of the model to fetch information for
      */
-    async getModelInfo(modelName: string): Promise<EmbeddingModelInfo> {
-        // Check cache first
-        if (this.modelInfoCache.has(modelName)) {
-            return this.modelInfoCache.get(modelName)!;
-        }
-
-        // Default dimensions for common embedding models
-        const defaultDimensions: Record<string, number> = {
-            "nomic-embed-text": 768,
-            "mxbai-embed-large": 1024,
-            "llama3": 4096,
-            "all-minilm": 384,
-            "default": 4096
-        };
-
-        // Default context windows
-        const defaultContextWindows: Record<string, number> = {
-            "nomic-embed-text": 8192,
-            "mxbai-embed-large": 8192,
-            "llama3": 8192,
-            "all-minilm": 4096,
-            "default": 4096
-        };
-
+    private async fetchModelCapabilities(modelName: string): Promise<EmbeddingModelInfo | null> {
         try {
-            // Try to detect if this is an embedding model
-            const testResponse = await axios.post(
-                `${this.baseUrl}/api/embeddings`,
-                {
-                    model: modelName,
-                    prompt: "Test"
-                },
+            // First try the /api/show endpoint which has detailed model information
+            const showResponse = await axios.get(
+                `${this.baseUrl}/api/show`,
                 {
+                    params: { name: modelName },
                     headers: { "Content-Type": "application/json" },
                     timeout: 10000
                 }
             );
 
-            let dimension = 0;
-            let contextWindow = 0;
-
-            if (testResponse.data && Array.isArray(testResponse.data.embedding)) {
-                dimension = testResponse.data.embedding.length;
-
-                // Set context window based on model name if we have it
-                const baseModelName = modelName.split(':')[0];
-                contextWindow = defaultContextWindows[baseModelName] || defaultContextWindows.default;
-
-                log.info(`Detected Ollama model ${modelName} with dimension ${dimension}`);
-            } else {
-                throw new Error("Could not detect embedding dimensions");
+            if (showResponse.data && showResponse.data.parameters) {
+                const params = showResponse.data.parameters;
+                // Extract context length from parameters (different models might use different parameter names)
+                const contextWindow = params.context_length ||
+                                     params.num_ctx ||
+                                     params.context_window ||
+                                     (LLM_CONSTANTS.OLLAMA_MODEL_CONTEXT_WINDOWS as Record<string, number>).default;
+
+                // Some models might provide embedding dimensions
+                const embeddingDimension = params.embedding_length || params.dim || null;
+
+                log.info(`Fetched Ollama model info from API for ${modelName}: context window ${contextWindow}`);
+
+                return {
+                    dimension: embeddingDimension || 0, // We'll detect this separately if not provided
+                    contextWindow: contextWindow
+                };
             }
+        } catch (error: any) {
+            log.info(`Could not fetch model info from Ollama show API: ${error.message}. Will try embedding test.`);
+            // We'll fall back to embedding test if this fails
+        }
+
+        return null;
+    }
+
+    /**
+     * Get model information by probing the API
+     */
+    async getModelInfo(modelName: string): Promise<EmbeddingModelInfo> {
+        // Check cache first
+        if (this.modelInfoCache.has(modelName)) {
+            return this.modelInfoCache.get(modelName);
+        }
+
+        // Try to fetch model capabilities from API
+        const apiModelInfo = await this.fetchModelCapabilities(modelName);
+        if (apiModelInfo) {
+            // If we have context window but no embedding dimension, we need to detect the dimension
+            if (apiModelInfo.contextWindow && !apiModelInfo.dimension) {
+                try {
+                    // Detect dimension with a test embedding
+                    const dimension = await this.detectEmbeddingDimension(modelName);
+                    apiModelInfo.dimension = dimension;
+                } catch (error) {
+                    // If dimension detection fails, fall back to defaults
+                    const baseModelName = modelName.split(':')[0];
+                    apiModelInfo.dimension = (LLM_CONSTANTS.OLLAMA_MODEL_DIMENSIONS as Record<string, number>)[baseModelName] ||
+                                           (LLM_CONSTANTS.OLLAMA_MODEL_DIMENSIONS as Record<string, number>).default;
+                }
+            }
+
+            // Cache and return the API-provided info
+            this.modelInfoCache.set(modelName, apiModelInfo);
+            this.config.dimension = apiModelInfo.dimension;
+            return apiModelInfo;
+        }
+
+        // If API info fetch fails, fall back to test embedding
+        try {
+            const dimension = await this.detectEmbeddingDimension(modelName);
+            const baseModelName = modelName.split(':')[0];
+            const contextWindow = (LLM_CONSTANTS.OLLAMA_MODEL_CONTEXT_WINDOWS as Record<string, number>)[baseModelName] ||
+                                (LLM_CONSTANTS.OLLAMA_MODEL_CONTEXT_WINDOWS as Record<string, number>).default;
 
             const modelInfo: EmbeddingModelInfo = { dimension, contextWindow };
             this.modelInfoCache.set(modelName, modelInfo);
-
-            // Update the provider config dimension
             this.config.dimension = dimension;
 
+            log.info(`Detected Ollama model ${modelName} with dimension ${dimension} (context: ${contextWindow})`);
             return modelInfo;
         } catch (error: any) {
             log.error(`Error detecting Ollama model capabilities: ${error.message}`);
 
-            // If detection fails, use defaults based on model name
+            // If all detection fails, use defaults based on model name
             const baseModelName = modelName.split(':')[0];
-            const dimension = defaultDimensions[baseModelName] || defaultDimensions.default;
-            const contextWindow = defaultContextWindows[baseModelName] || defaultContextWindows.default;
+            const dimension = (LLM_CONSTANTS.OLLAMA_MODEL_DIMENSIONS as Record<string, number>)[baseModelName] ||
+                            (LLM_CONSTANTS.OLLAMA_MODEL_DIMENSIONS as Record<string, number>).default;
+            const contextWindow = (LLM_CONSTANTS.OLLAMA_MODEL_CONTEXT_WINDOWS as Record<string, number>)[baseModelName] ||
+                                (LLM_CONSTANTS.OLLAMA_MODEL_CONTEXT_WINDOWS as Record<string, number>).default;
 
-            log.info(`Using default dimension ${dimension} for model ${modelName}`);
+            log.info(`Using default parameters for model ${modelName}: dimension ${dimension}, context ${contextWindow}`);
 
             const modelInfo: EmbeddingModelInfo = { dimension, contextWindow };
             this.modelInfoCache.set(modelName, modelInfo);
-
-            // Update the provider config dimension
             this.config.dimension = dimension;
 
             return modelInfo;
         }
     }
 
+    /**
+     * Detect embedding dimension by making a test API call
+     */
+    private async detectEmbeddingDimension(modelName: string): Promise<number> {
+        const testResponse = await axios.post(
+            `${this.baseUrl}/api/embeddings`,
+            {
+                model: modelName,
+                prompt: "Test"
+            },
+            {
+                headers: { "Content-Type": "application/json" },
+                timeout: 10000
+            }
+        );
+
+        if (testResponse.data && Array.isArray(testResponse.data.embedding)) {
+            return testResponse.data.embedding.length;
+        } else {
+            throw new Error("Could not detect embedding dimensions");
+        }
+    }
+
     /**
      * Get the current embedding dimension
      */
@@ -136,6 +174,10 @@ export class OllamaEmbeddingProvider extends BaseEmbeddingProvider {
      */
     async generateEmbeddings(text: string): Promise<Float32Array> {
         try {
+            if (!text.trim()) {
+                return new Float32Array(this.config.dimension);
+            }
+
             const modelName = this.config.model || "llama3";
 
             // Ensure we have model info
@@ -173,29 +215,60 @@ export class OllamaEmbeddingProvider extends BaseEmbeddingProvider {
         }
     }
 
+    /**
+     * More specific implementation of batch size error detection for Ollama
+     */
+    protected isBatchSizeError(error: any): boolean {
+        const errorMessage = error?.message || '';
+        const ollamaBatchSizeErrorPatterns = [
+            'context length', 'token limit', 'out of memory',
+            'too large', 'overloaded', 'prompt too long',
+            'too many tokens', 'maximum size'
+        ];
+
+        return ollamaBatchSizeErrorPatterns.some(pattern =>
+            errorMessage.toLowerCase().includes(pattern.toLowerCase())
+        );
+    }
+
     /**
      * Generate embeddings for multiple texts
     *
     * Note: Ollama API doesn't support batch embedding, so we process them sequentially
+     * but using the adaptive batch processor to handle rate limits and retries
     */
     async generateBatchEmbeddings(texts: string[]): Promise<Float32Array[]> {
         if (texts.length === 0) {
             return [];
         }
 
-        const results: Float32Array[] = [];
-
-        for (const text of texts) {
-            try {
-                const embedding = await this.generateEmbeddings(text);
-                results.push(embedding);
-            } catch (error: any) {
-                const errorMessage = error.response?.data?.error?.message || error.message || "Unknown error";
-                log.error(`Ollama batch embedding error: ${errorMessage}`);
-                throw new Error(`Ollama batch embedding error: ${errorMessage}`);
-            }
-        }
-
-        return results;
+        try {
+            return await this.processWithAdaptiveBatch(
+                texts,
+                async (batch) => {
+                    const results: Float32Array[] = [];
+
+                    // For Ollama, we have to process one at a time
+                    for (const text of batch) {
+                        // Skip empty texts
+                        if (!text.trim()) {
+                            results.push(new Float32Array(this.config.dimension));
+                            continue;
+                        }
+
+                        const embedding = await this.generateEmbeddings(text);
+                        results.push(embedding);
+                    }
+
+                    return results;
+                },
+                this.isBatchSizeError
+            );
+        }
+        catch (error: any) {
+            const errorMessage = error.message || "Unknown error";
+            log.error(`Ollama batch embedding error: ${errorMessage}`);
+            throw new Error(`Ollama batch embedding error: ${errorMessage}`);
        }
     }
 }

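detectEmbeddingDimension() simply embeds the string "Test" and measures the vector length. The same probe can be run standalone against a local Ollama instance; this sketch is illustrative and not part of the diff, with the base URL and model name as placeholders:

import axios from "axios";

// Ask a local Ollama server for one embedding and report its length (e.g. 768 for nomic-embed-text).
async function probeDimension(baseUrl = "http://localhost:11434", model = "nomic-embed-text"): Promise<number> {
    const res = await axios.post(
        `${baseUrl}/api/embeddings`,
        { model, prompt: "Test" },
        { headers: { "Content-Type": "application/json" }, timeout: 10000 }
    );
    if (!Array.isArray(res.data?.embedding)) {
        throw new Error("Could not detect embedding dimensions");
    }
    return res.data.embedding.length;
}
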
|  | @ -1,25 +1,165 @@ | ||||||
| import { BaseEmbeddingProvider } from "../base_embeddings.js"; |  | ||||||
| import type { EmbeddingConfig } from "../embeddings_interface.js"; |  | ||||||
| import axios from "axios"; | import axios from "axios"; | ||||||
| import log from "../../../log.js"; | import log from "../../../log.js"; | ||||||
| 
 | import { BaseEmbeddingProvider } from "../base_embeddings.js"; | ||||||
| interface OpenAIEmbeddingConfig extends EmbeddingConfig { | import type { EmbeddingConfig, EmbeddingModelInfo } from "../embeddings_interface.js"; | ||||||
|     apiKey: string; | import { LLM_CONSTANTS } from "../../../../routes/api/llm.js"; | ||||||
|     baseUrl: string; |  | ||||||
| } |  | ||||||
| 
 | 
 | ||||||
| /** | /** | ||||||
|  * OpenAI embedding provider implementation |  * OpenAI embedding provider implementation | ||||||
|  */ |  */ | ||||||
| export class OpenAIEmbeddingProvider extends BaseEmbeddingProvider { | export class OpenAIEmbeddingProvider extends BaseEmbeddingProvider { | ||||||
|     name = "openai"; |     name = "openai"; | ||||||
|     private apiKey: string; |  | ||||||
|     private baseUrl: string; |  | ||||||
| 
 | 
 | ||||||
|     constructor(config: OpenAIEmbeddingConfig) { |     constructor(config: EmbeddingConfig) { | ||||||
|         super(config); |         super(config); | ||||||
|         this.apiKey = config.apiKey; |     } | ||||||
|         this.baseUrl = config.baseUrl; | 
 | ||||||
|  |     /** | ||||||
|  |      * Initialize the provider by detecting model capabilities | ||||||
|  |      */ | ||||||
|  |     async initialize(): Promise<void> { | ||||||
|  |         const modelName = this.config.model || "text-embedding-3-small"; | ||||||
|  |         try { | ||||||
|  |             // Detect model capabilities
 | ||||||
|  |             const modelInfo = await this.getModelInfo(modelName); | ||||||
|  | 
 | ||||||
|  |             // Update the config dimension
 | ||||||
|  |             this.config.dimension = modelInfo.dimension; | ||||||
|  | 
 | ||||||
|  |             log.info(`OpenAI model ${modelName} initialized with dimension ${this.config.dimension} and context window ${modelInfo.contextWindow}`); | ||||||
|  |         } catch (error: any) { | ||||||
|  |             log.error(`Error initializing OpenAI provider: ${error.message}`); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     /** | ||||||
|  |      * Fetch model information from the OpenAI API | ||||||
|  |      */ | ||||||
|  |     private async fetchModelCapabilities(modelName: string): Promise<EmbeddingModelInfo | null> { | ||||||
|  |         if (!this.apiKey) { | ||||||
|  |             return null; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         try { | ||||||
|  |             // First try to get model details from the models API
 | ||||||
|  |             const response = await axios.get( | ||||||
|  |                 `${this.baseUrl}/models/${modelName}`, | ||||||
|  |                 { | ||||||
|  |                     headers: { | ||||||
|  |                         "Authorization": `Bearer ${this.apiKey}`, | ||||||
|  |                         "Content-Type": "application/json" | ||||||
|  |                     }, | ||||||
|  |                     timeout: 10000 | ||||||
|  |                 } | ||||||
|  |             ); | ||||||
|  | 
 | ||||||
|  |             if (response.data) { | ||||||
|  |                 // Different model families may have different ways of exposing context window
 | ||||||
|  |                 let contextWindow = 0; | ||||||
|  |                 let dimension = 0; | ||||||
|  | 
 | ||||||
|  |                 // Extract context window if available
 | ||||||
|  |                 if (response.data.context_window) { | ||||||
|  |                     contextWindow = response.data.context_window; | ||||||
|  |                 } else if (response.data.limits && response.data.limits.context_window) { | ||||||
|  |                     contextWindow = response.data.limits.context_window; | ||||||
|  |                 } else if (response.data.limits && response.data.limits.context_length) { | ||||||
|  |                     contextWindow = response.data.limits.context_length; | ||||||
|  |                 } | ||||||
|  | 
 | ||||||
|  |                 // Extract embedding dimensions if available
 | ||||||
|  |                 if (response.data.dimensions) { | ||||||
|  |                     dimension = response.data.dimensions; | ||||||
|  |                 } else if (response.data.embedding_dimension) { | ||||||
|  |                     dimension = response.data.embedding_dimension; | ||||||
|  |                 } | ||||||
|  | 
 | ||||||
|  |                 // If we didn't get all the info, use defaults for missing values
 | ||||||
|  |                 if (!contextWindow) { | ||||||
|  |                     // Set default context window based on model name patterns
 | ||||||
|  |                     if (modelName.includes('ada') || modelName.includes('embedding-ada')) { | ||||||
|  |                         contextWindow = LLM_CONSTANTS.CONTEXT_WINDOW.OPENAI; | ||||||
|  |                     } else if (modelName.includes('davinci')) { | ||||||
|  |                         contextWindow = 8192; | ||||||
|  |                     } else if (modelName.includes('embedding-3')) { | ||||||
|  |                         contextWindow = 8191; | ||||||
|  |                     } else { | ||||||
|  |                         contextWindow = LLM_CONSTANTS.CONTEXT_WINDOW.OPENAI; | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  | 
 | ||||||
|  |                 if (!dimension) { | ||||||
|  |                     // Set default dimensions based on model name patterns
 | ||||||
|  |                     if (modelName.includes('ada') || modelName.includes('embedding-ada')) { | ||||||
|  |                         dimension = LLM_CONSTANTS.EMBEDDING_DIMENSIONS.OPENAI.ADA; | ||||||
|  |                     } else if (modelName.includes('embedding-3-small')) { | ||||||
|  |                         dimension = 1536; | ||||||
|  |                     } else if (modelName.includes('embedding-3-large')) { | ||||||
|  |                         dimension = 3072; | ||||||
|  |                     } else { | ||||||
|  |                         dimension = LLM_CONSTANTS.EMBEDDING_DIMENSIONS.OPENAI.DEFAULT; | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  | 
 | ||||||
|  |                 log.info(`Fetched OpenAI model info for ${modelName}: context window ${contextWindow}, dimension ${dimension}`); | ||||||
|  | 
 | ||||||
|  |                 return { | ||||||
|  |                     dimension, | ||||||
|  |                     contextWindow | ||||||
|  |                 }; | ||||||
|  |             } | ||||||
|  |         } catch (error: any) { | ||||||
|  |             log.info(`Could not fetch model info from OpenAI API: ${error.message}. Will try embedding test.`); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         return null; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     /** | ||||||
|  |      * Get model information including embedding dimensions | ||||||
|  |      */ | ||||||
|  |     async getModelInfo(modelName: string): Promise<EmbeddingModelInfo> { | ||||||
|  |         // Check cache first
 | ||||||
|  |         if (this.modelInfoCache.has(modelName)) { | ||||||
|  |             return this.modelInfoCache.get(modelName); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // Try to fetch model capabilities from API
 | ||||||
|  |         const apiModelInfo = await this.fetchModelCapabilities(modelName); | ||||||
|  |         if (apiModelInfo) { | ||||||
|  |             // Cache and return the API-provided info
 | ||||||
|  |             this.modelInfoCache.set(modelName, apiModelInfo); | ||||||
|  |             this.config.dimension = apiModelInfo.dimension; | ||||||
|  |             return apiModelInfo; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // If API info fetch fails, try to detect embedding dimension with a test call
 | ||||||
|  |         try { | ||||||
|  |             const testEmbedding = await this.generateEmbeddings("Test"); | ||||||
|  |             const dimension = testEmbedding.length; | ||||||
|  | 
 | ||||||
|  |             // Use default context window
 | ||||||
|  |             let contextWindow = LLM_CONSTANTS.CONTEXT_WINDOW.OPENAI; | ||||||
|  | 
 | ||||||
|  |             const modelInfo: EmbeddingModelInfo = { dimension, contextWindow }; | ||||||
|  |             this.modelInfoCache.set(modelName, modelInfo); | ||||||
|  |             this.config.dimension = dimension; | ||||||
|  | 
 | ||||||
|  |             log.info(`Detected OpenAI model ${modelName} with dimension ${dimension} (context: ${contextWindow})`); | ||||||
|  |             return modelInfo; | ||||||
|  |         } catch (error: any) { | ||||||
|  |             // If detection fails, use defaults
 | ||||||
|  |             const dimension = LLM_CONSTANTS.EMBEDDING_DIMENSIONS.OPENAI.DEFAULT; | ||||||
|  |             const contextWindow = LLM_CONSTANTS.CONTEXT_WINDOW.OPENAI; | ||||||
|  | 
 | ||||||
|  |             log.info(`Using default parameters for OpenAI model ${modelName}: dimension ${dimension}, context ${contextWindow}`); | ||||||
|  | 
 | ||||||
|  |             const modelInfo: EmbeddingModelInfo = { dimension, contextWindow }; | ||||||
|  |             this.modelInfoCache.set(modelName, modelInfo); | ||||||
|  |             this.config.dimension = dimension; | ||||||
|  | 
 | ||||||
|  |             return modelInfo; | ||||||
|  |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
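The method above resolves model info through a fixed fallback chain: in-memory cache, then the models API, then a one-off test embedding, then the LLM_CONSTANTS defaults. A minimal usage sketch follows; the surrounding wiring is assumed for illustration, and only getModelInfo() and its { dimension, contextWindow } result come from this file:

// Hypothetical caller, assuming an already-constructed provider instance.
async function logModelInfo(provider: OpenAIEmbeddingProvider) {
    const info = await provider.getModelInfo("text-embedding-3-small");
    // Both fields are always populated: API probe first, then a test embedding, then defaults.
    console.log(`dimension=${info.dimension}, context=${info.contextWindow}`);
}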
|     /** |     /** | ||||||
|  | @ -27,6 +167,10 @@ export class OpenAIEmbeddingProvider extends BaseEmbeddingProvider { | ||||||
|      */ |      */ | ||||||
|     async generateEmbeddings(text: string): Promise<Float32Array> { |     async generateEmbeddings(text: string): Promise<Float32Array> { | ||||||
|         try { |         try { | ||||||
|  |             if (!text.trim()) { | ||||||
|  |                 return new Float32Array(this.config.dimension); | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|             const response = await axios.post( |             const response = await axios.post( | ||||||
|                 `${this.baseUrl}/embeddings`, |                 `${this.baseUrl}/embeddings`, | ||||||
|                 { |                 { | ||||||
|  | @ -43,8 +187,7 @@ export class OpenAIEmbeddingProvider extends BaseEmbeddingProvider { | ||||||
|             ); |             ); | ||||||
| 
 | 
 | ||||||
|             if (response.data && response.data.data && response.data.data[0] && response.data.data[0].embedding) { |             if (response.data && response.data.data && response.data.data[0] && response.data.data[0].embedding) { | ||||||
|                 const embedding = response.data.data[0].embedding; |                 return new Float32Array(response.data.data[0].embedding); | ||||||
|                 return new Float32Array(embedding); |  | ||||||
|             } else { |             } else { | ||||||
|                 throw new Error("Unexpected response structure from OpenAI API"); |                 throw new Error("Unexpected response structure from OpenAI API"); | ||||||
|             } |             } | ||||||
|  | @ -55,53 +198,94 @@ export class OpenAIEmbeddingProvider extends BaseEmbeddingProvider { | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     /** | ||||||
|  |      * OpenAI-specific implementation of batch size error detection | ||||||
|  |      */ | ||||||
|  |     protected isBatchSizeError(error: any): boolean { | ||||||
|  |         const errorMessage = error?.message || error?.response?.data?.error?.message || ''; | ||||||
|  |         const openAIBatchSizeErrorPatterns = [ | ||||||
|  |             'batch size', 'too many inputs', 'context length exceeded', | ||||||
|  |             'maximum context length', 'token limit', 'rate limit exceeded', | ||||||
|  |             'tokens in the messages', 'reduce the length', 'too long' | ||||||
|  |         ]; | ||||||
|  | 
 | ||||||
|  |         return openAIBatchSizeErrorPatterns.some(pattern => | ||||||
|  |             errorMessage.toLowerCase().includes(pattern.toLowerCase()) | ||||||
|  |         ); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
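The detection above is purely substring-based, so messages about context length or rate limits are treated the same way as genuine batch-size failures. A standalone sketch of the same idea, using made-up messages rather than real OpenAI error payloads:

// Illustrative only: the strings below are invented, not actual OpenAI responses.
const samples = [
    "This model's maximum context length is 8192 tokens",  // matches "maximum context length"
    "Rate limit exceeded for requests",                     // matches "rate limit exceeded"
    "Invalid API key"                                       // no match
];
const patterns = ["batch size", "maximum context length", "rate limit exceeded"];
const flags = samples.map(msg => patterns.some(p => msg.toLowerCase().includes(p)));
// flags === [true, true, false]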
|  |     /** | ||||||
|  |      * Custom implementation for batched OpenAI embeddings | ||||||
|  |      */ | ||||||
|  |     async generateBatchEmbeddingsWithAPI(texts: string[]): Promise<Float32Array[]> { | ||||||
|  |         if (texts.length === 0) { | ||||||
|  |             return []; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         const response = await axios.post( | ||||||
|  |             `${this.baseUrl}/embeddings`, | ||||||
|  |             { | ||||||
|  |                 input: texts, | ||||||
|  |                 model: this.config.model || "text-embedding-3-small", | ||||||
|  |                 encoding_format: "float" | ||||||
|  |             }, | ||||||
|  |             { | ||||||
|  |                 headers: { | ||||||
|  |                     "Content-Type": "application/json", | ||||||
|  |                     "Authorization": `Bearer ${this.apiKey}` | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         ); | ||||||
|  | 
 | ||||||
|  |         if (response.data && response.data.data) { | ||||||
|  |             // Sort the embeddings by index to ensure they match the input order
 | ||||||
|  |             const sortedEmbeddings = response.data.data | ||||||
|  |                 .sort((a: any, b: any) => a.index - b.index) | ||||||
|  |                 .map((item: any) => new Float32Array(item.embedding)); | ||||||
|  | 
 | ||||||
|  |             return sortedEmbeddings; | ||||||
|  |         } else { | ||||||
|  |             throw new Error("Unexpected response structure from OpenAI API"); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
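The embeddings endpoint returns one item per input, each carrying an index field, so sorting by index realigns the results with the input order before converting to Float32Array. A small illustration with dummy data:

// Illustrative only: dummy response items arriving out of order.
const data = [
    { index: 1, embedding: [0.2, 0.4] },
    { index: 0, embedding: [0.1, 0.3] }
];
const ordered = data
    .sort((a, b) => a.index - b.index)
    .map(item => new Float32Array(item.embedding));
// ordered[0] now corresponds to texts[0], ordered[1] to texts[1]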
|     /** |     /** | ||||||
|      * Generate embeddings for multiple texts in a single batch |      * Generate embeddings for multiple texts in a single batch | ||||||
|  |      * The OpenAI API supports batch embedding requests, so we implement a custom batched version | ||||||
|      */ |      */ | ||||||
|     async generateBatchEmbeddings(texts: string[]): Promise<Float32Array[]> { |     async generateBatchEmbeddings(texts: string[]): Promise<Float32Array[]> { | ||||||
|         if (texts.length === 0) { |         if (texts.length === 0) { | ||||||
|             return []; |             return []; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         const batchSize = this.config.batchSize || 10; |         try { | ||||||
|         const results: Float32Array[] = []; |             return await this.processWithAdaptiveBatch( | ||||||
|  |                 texts, | ||||||
|  |                 async (batch) => { | ||||||
|  |                     // Filter out empty texts and use the API batch functionality
 | ||||||
|  |                     const filteredBatch = batch.filter(text => text.trim().length > 0); | ||||||
| 
 | 
 | ||||||
|         // Process in batches to avoid API limits
 |                     if (filteredBatch.length === 0) { | ||||||
|         for (let i = 0; i < texts.length; i += batchSize) { |                         // If all texts are empty after filtering, return empty embeddings
 | ||||||
|             const batch = texts.slice(i, i + batchSize); |                         return batch.map(() => new Float32Array(this.config.dimension)); | ||||||
|             try { |  | ||||||
|                 const response = await axios.post( |  | ||||||
|                     `${this.baseUrl}/embeddings`, |  | ||||||
|                     { |  | ||||||
|                         input: batch, |  | ||||||
|                         model: this.config.model || "text-embedding-3-small", |  | ||||||
|                         encoding_format: "float" |  | ||||||
|                     }, |  | ||||||
|                     { |  | ||||||
|                         headers: { |  | ||||||
|                             "Content-Type": "application/json", |  | ||||||
|                             "Authorization": `Bearer ${this.apiKey}` |  | ||||||
|                         } |  | ||||||
|                     } |                     } | ||||||
|                 ); |  | ||||||
| 
 | 
 | ||||||
|                 if (response.data && response.data.data) { |                     if (filteredBatch.length === 1) { | ||||||
|                     // Sort the embeddings by index to ensure they match the input order
 |                         // If only one text, use the single embedding endpoint
 | ||||||
|                     const sortedEmbeddings = response.data.data |                         const embedding = await this.generateEmbeddings(filteredBatch[0]); | ||||||
|                         .sort((a: any, b: any) => a.index - b.index) |                         return [embedding]; | ||||||
|                         .map((item: any) => new Float32Array(item.embedding)); |                     } | ||||||
| 
 | 
 | ||||||
|                     results.push(...sortedEmbeddings); |                     // Use the batch API endpoint
 | ||||||
|                 } else { |                     return this.generateBatchEmbeddingsWithAPI(filteredBatch); | ||||||
|                     throw new Error("Unexpected response structure from OpenAI API"); |                 }, | ||||||
|                 } |                 this.isBatchSizeError | ||||||
|             } catch (error: any) { |             ); | ||||||
|                 const errorMessage = error.response?.data?.error?.message || error.message || "Unknown error"; |         } | ||||||
|                 log.error(`OpenAI batch embedding error: ${errorMessage}`); |         catch (error: any) { | ||||||
|                 throw new Error(`OpenAI batch embedding error: ${errorMessage}`); |             const errorMessage = error.message || "Unknown error"; | ||||||
|             } |             log.error(`OpenAI batch embedding error: ${errorMessage}`); | ||||||
|  |             throw new Error(`OpenAI batch embedding error: ${errorMessage}`); | ||||||
|         } |         } | ||||||
| 
 |  | ||||||
|         return results; |  | ||||||
|     } |     } | ||||||
| } | } | ||||||
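generateBatchEmbeddings delegates chunking and retrying to processWithAdaptiveBatch on the base provider, which is not part of this file. The sketch below shows what such a helper is assumed to do, shrinking the batch and retrying whenever the supplied classifier flags a batch-size error; it illustrates the contract only, not the actual BaseEmbeddingProvider code:

// Assumed shape of the base-class helper (illustrative sketch, not the real implementation).
async function processWithAdaptiveBatch<T, R>(
    items: T[],
    process: (batch: T[]) => Promise<R[]>,
    isBatchSizeError: (err: unknown) => boolean,
    initialBatchSize = 10
): Promise<R[]> {
    const results: R[] = [];
    let batchSize = initialBatchSize;
    let i = 0;
    while (i < items.length) {
        const batch = items.slice(i, i + batchSize);
        try {
            results.push(...await process(batch));
            i += batch.length;
        } catch (err) {
            if (isBatchSizeError(err) && batchSize > 1) {
                // Halve the batch and retry the same slice.
                batchSize = Math.max(1, Math.floor(batchSize / 2));
            } else {
                throw err;
            }
        }
    }
    return results;
}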
|  |  | ||||||