mirror of
https://github.com/usememos/memos.git
synced 2025-12-18 22:59:24 +08:00
- Replace fmt.Errorf with errors.Wrapf/New for forbidigo compliance - Add missing periods to comments for godot compliance - Fix goimports formatting issues in multiple files - Remove unused struct field writes in tests - Replace unused method receivers with underscore - Optimize single-case switch to if-else statement Signed-off-by: ChaoLiu <chaoliu719@gmail.com>
252 lines
6.3 KiB
Go
252 lines
6.3 KiB
Go
package ai
|
|
|
|
import (
|
|
"context"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/openai/openai-go/v2"
|
|
"github.com/openai/openai-go/v2/option"
|
|
"github.com/pkg/errors"
|
|
|
|
storepb "github.com/usememos/memos/proto/gen/store"
|
|
)
|
|
|
|
// Common AI errors.
var (
	// ErrConfigIncomplete indicates the configuration is missing required fields.
	ErrConfigIncomplete = errors.New("AI configuration incomplete - missing BaseURL, APIKey, or Model")
	// ErrEmptyRequest indicates a nil request or a request with no messages.
	ErrEmptyRequest = errors.New("chat request cannot be empty")
	// ErrInvalidMessage indicates a message with an unsupported role.
	ErrInvalidMessage = errors.New("message role must be 'system', 'user', or 'assistant'")
	// ErrEmptyContent indicates a message whose content is blank.
	ErrEmptyContent = errors.New("message content cannot be empty")
	// ErrAPICallFailed wraps failures returned by the upstream AI API.
	ErrAPICallFailed = errors.New("AI API call failed")
	// ErrEmptyResponse indicates the AI returned a blank completion.
	ErrEmptyResponse = errors.New("received empty response from AI")
	// ErrNoChoices indicates the AI returned zero response choices.
	ErrNoChoices = errors.New("AI returned no response choices")
)
|
|
|
|
// Config holds AI configuration.
type Config struct {
	Enabled        bool   // whether AI features are active
	BaseURL        string // base URL of the OpenAI-compatible endpoint
	APIKey         string // API key; optional for local services such as Ollama
	Model          string // model identifier sent with each request
	TimeoutSeconds int    // per-request timeout in seconds; non-positive values fall back to 10
}
|
|
|
|
// LoadConfigFromEnv loads AI configuration from environment variables
|
|
func LoadConfigFromEnv() *Config {
|
|
timeoutSeconds := 10 // default timeout
|
|
if timeoutStr := os.Getenv("AI_TIMEOUT_SECONDS"); timeoutStr != "" {
|
|
if timeout, err := strconv.Atoi(timeoutStr); err == nil && timeout > 0 {
|
|
timeoutSeconds = timeout
|
|
}
|
|
}
|
|
|
|
config := &Config{
|
|
BaseURL: os.Getenv("AI_BASE_URL"),
|
|
APIKey: os.Getenv("AI_API_KEY"),
|
|
Model: os.Getenv("AI_MODEL"),
|
|
TimeoutSeconds: timeoutSeconds,
|
|
}
|
|
|
|
// Enable AI if all required fields are provided (API Key is optional for local services like Ollama)
|
|
config.Enabled = config.BaseURL != "" && config.Model != ""
|
|
|
|
return config
|
|
}
|
|
|
|
// LoadConfigFromDatabase loads AI configuration from database settings
|
|
func LoadConfigFromDatabase(aiSetting *storepb.WorkspaceAISetting) *Config {
|
|
if aiSetting == nil {
|
|
return &Config{Enabled: false}
|
|
}
|
|
|
|
timeoutSeconds := int(aiSetting.TimeoutSeconds)
|
|
if timeoutSeconds <= 0 {
|
|
timeoutSeconds = 10 // default timeout
|
|
}
|
|
|
|
return &Config{
|
|
Enabled: aiSetting.EnableAi,
|
|
BaseURL: aiSetting.BaseUrl,
|
|
APIKey: aiSetting.ApiKey,
|
|
Model: aiSetting.Model,
|
|
TimeoutSeconds: timeoutSeconds,
|
|
}
|
|
}
|
|
|
|
// MergeWithEnv merges database config with environment variables
|
|
// Environment variables take precedence if they are set
|
|
func (c *Config) MergeWithEnv() *Config {
|
|
envConfig := LoadConfigFromEnv()
|
|
|
|
// Start with current config
|
|
merged := &Config{
|
|
Enabled: c.Enabled,
|
|
BaseURL: c.BaseURL,
|
|
APIKey: c.APIKey,
|
|
Model: c.Model,
|
|
TimeoutSeconds: c.TimeoutSeconds,
|
|
}
|
|
|
|
// Override with env vars if they are set
|
|
if envConfig.BaseURL != "" {
|
|
merged.BaseURL = envConfig.BaseURL
|
|
}
|
|
if envConfig.APIKey != "" {
|
|
merged.APIKey = envConfig.APIKey
|
|
}
|
|
if envConfig.Model != "" {
|
|
merged.Model = envConfig.Model
|
|
}
|
|
if os.Getenv("AI_TIMEOUT_SECONDS") != "" {
|
|
merged.TimeoutSeconds = envConfig.TimeoutSeconds
|
|
}
|
|
|
|
// Enable if all required fields are present (API Key is optional for local services like Ollama)
|
|
merged.Enabled = merged.BaseURL != "" && merged.Model != ""
|
|
|
|
return merged
|
|
}
|
|
|
|
// IsConfigured returns true if AI is properly configured
|
|
func (c *Config) IsConfigured() bool {
|
|
return c.Enabled && c.BaseURL != "" && c.Model != ""
|
|
}
|
|
|
|
// Client wraps OpenAI client with convenience methods.
type Client struct {
	client openai.Client // underlying OpenAI SDK client
	config *Config       // resolved configuration (model, timeout, credentials)
}
|
|
|
|
// NewClient creates a new AI client
|
|
func NewClient(config *Config) (*Client, error) {
|
|
if config == nil {
|
|
return nil, errors.New("config cannot be nil")
|
|
}
|
|
|
|
if !config.IsConfigured() {
|
|
return nil, ErrConfigIncomplete
|
|
}
|
|
|
|
var client openai.Client
|
|
if config.BaseURL != "" && config.BaseURL != "https://api.openai.com/v1" {
|
|
client = openai.NewClient(
|
|
option.WithAPIKey(config.APIKey),
|
|
option.WithBaseURL(config.BaseURL),
|
|
)
|
|
} else {
|
|
client = openai.NewClient(
|
|
option.WithAPIKey(config.APIKey),
|
|
)
|
|
}
|
|
|
|
return &Client{
|
|
client: client,
|
|
config: config,
|
|
}, nil
|
|
}
|
|
|
|
// ChatRequest represents a chat completion request.
type ChatRequest struct {
	Messages    []Message     // conversation messages, in order
	MaxTokens   int           // response token cap; 0 means the default (8192)
	Temperature float64       // sampling temperature; 0 means the default (0.3)
	Timeout     time.Duration // per-request timeout; 0 falls back to the client config
}
|
|
|
|
// Message represents a chat message.
type Message struct {
	Role    string // "system", "user", "assistant"
	Content string // message text; must be non-blank
}
|
|
|
|
// ChatResponse represents a chat completion response.
type ChatResponse struct {
	Content string // trimmed text of the first response choice
}
|
|
|
|
// Chat performs a chat completion
|
|
func (c *Client) Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error) {
|
|
if req == nil {
|
|
return nil, ErrEmptyRequest
|
|
}
|
|
|
|
if len(req.Messages) == 0 {
|
|
return nil, ErrEmptyRequest
|
|
}
|
|
|
|
// Validate messages
|
|
for i, msg := range req.Messages {
|
|
if msg.Role != "system" && msg.Role != "user" && msg.Role != "assistant" {
|
|
return nil, errors.Wrapf(ErrInvalidMessage, "message %d", i)
|
|
}
|
|
if strings.TrimSpace(msg.Content) == "" {
|
|
return nil, errors.Wrapf(ErrEmptyContent, "message %d", i)
|
|
}
|
|
}
|
|
|
|
// Set defaults
|
|
if req.MaxTokens == 0 {
|
|
req.MaxTokens = 8192
|
|
}
|
|
if req.Temperature == 0 {
|
|
req.Temperature = 0.3
|
|
}
|
|
if req.Timeout == 0 {
|
|
// Use timeout from config if available
|
|
if c.config.TimeoutSeconds > 0 {
|
|
req.Timeout = time.Duration(c.config.TimeoutSeconds) * time.Second
|
|
} else {
|
|
req.Timeout = 10 * time.Second
|
|
}
|
|
}
|
|
|
|
model := c.config.Model
|
|
if model == "" {
|
|
model = "gpt-4o" // Default model
|
|
}
|
|
|
|
// Convert messages
|
|
messages := make([]openai.ChatCompletionMessageParamUnion, 0, len(req.Messages))
|
|
for _, msg := range req.Messages {
|
|
switch msg.Role {
|
|
case "system":
|
|
messages = append(messages, openai.SystemMessage(msg.Content))
|
|
case "user":
|
|
messages = append(messages, openai.UserMessage(msg.Content))
|
|
case "assistant":
|
|
messages = append(messages, openai.AssistantMessage(msg.Content))
|
|
}
|
|
}
|
|
|
|
// Create timeout context
|
|
timeoutCtx, cancel := context.WithTimeout(ctx, req.Timeout)
|
|
defer cancel()
|
|
|
|
// Make API call
|
|
completion, err := c.client.Chat.Completions.New(timeoutCtx, openai.ChatCompletionNewParams{
|
|
Messages: messages,
|
|
Model: model,
|
|
MaxTokens: openai.Int(int64(req.MaxTokens)),
|
|
Temperature: openai.Float(req.Temperature),
|
|
})
|
|
if err != nil {
|
|
return nil, errors.Wrapf(ErrAPICallFailed, "%v", err)
|
|
}
|
|
|
|
if len(completion.Choices) == 0 {
|
|
return nil, ErrNoChoices
|
|
}
|
|
|
|
response := strings.TrimSpace(completion.Choices[0].Message.Content)
|
|
if response == "" {
|
|
return nil, ErrEmptyResponse
|
|
}
|
|
|
|
return &ChatResponse{
|
|
Content: response,
|
|
}, nil
|
|
}
|