Files
2026-03-12 19:17:00 +01:00

193 lines
5.1 KiB
Go

package llm
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"github.com/paramah/ai_devs4/s01e03/internal/domain"
)
// LMStudioProvider implements domain.LLMProvider for a local LM Studio
// server, talking to its OpenAI-style /v1/chat/completions endpoint.
type LMStudioProvider struct {
	baseURL string       // server root URL, stored without trailing slashes
	model   string       // model identifier sent with every request
	client  *http.Client // HTTP client reused for every request
	verbose bool         // when true, requests and responses are logged
}

// NewLMStudioProvider creates a provider for the LM Studio server at baseURL
// that requests completions from the given model. When verbose is true,
// every request and response body is logged.
//
// Trailing slashes are stripped from baseURL so that endpoint paths such as
// "/v1/chat/completions" can be appended without producing "//".
//
// The HTTP client has no client-level timeout; request lifetimes are expected
// to be bounded by the context the caller passes per request.
func NewLMStudioProvider(baseURL, model string, verbose bool) *LMStudioProvider {
	// "http://host:1234/" + "/v1/..." would yield a double slash, which not
	// every server routes to the intended handler.
	for len(baseURL) > 0 && baseURL[len(baseURL)-1] == '/' {
		baseURL = baseURL[:len(baseURL)-1]
	}
	return &LMStudioProvider{
		baseURL: baseURL,
		model:   model,
		client:  &http.Client{},
		verbose: verbose,
	}
}
// lmStudioMessage is one chat message in the JSON wire format sent to the
// chat completions endpoint. Field order and tags define the serialized
// payload.
type lmStudioMessage struct {
// Role is copied unchanged from the domain message.
Role string `json:"role"`
// Content is always serialized (no omitempty), even when it is "".
Content string `json:"content"`
// ToolCallID is omitted when empty; presumably set only on tool-result
// messages to reference the call they answer (OpenAI convention) — confirm.
ToolCallID string `json:"tool_call_id,omitempty"`
// ToolCalls is omitted when nil or empty.
ToolCalls []lmStudioToolCall `json:"tool_calls,omitempty"`
}
// lmStudioToolCall is a tool (function) call in the chat completions wire
// format. It is used both in outgoing request messages and in the message of
// a response choice.
type lmStudioToolCall struct {
// ID identifies the call.
ID string `json:"id"`
// Type is the call kind as sent/received (OpenAI-style servers use
// "function" — not enforced here).
Type string `json:"type"`
// Function holds the tool name and its arguments. This is an anonymous
// struct type: its identity is determined by these exact field names, types
// and tags, so any struct literal assigned to it elsewhere must match them.
Function struct {
Name string `json:"name"`
// Arguments is passed through as an opaque string and never decoded here
// (OpenAI-style servers put JSON-encoded text in it).
Arguments string `json:"arguments"`
} `json:"function"`
}
// lmStudioRequest is the JSON body POSTed to the chat completions endpoint.
type lmStudioRequest struct {
Model string `json:"model"`
Messages []lmStudioMessage `json:"messages"`
// Tools is omitted from the payload when nil or empty.
Tools []domain.Tool `json:"tools,omitempty"`
// NOTE(review): because of omitempty, a temperature of exactly 0 is dropped
// from the payload rather than sent, so the server's own default would
// apply. Switch to *float64 if 0 ever needs to be sent explicitly.
Temperature float64 `json:"temperature,omitempty"`
}
// lmStudioResponse is the subset of the chat completions response that is
// decoded: the list of choices and an optional error field.
type lmStudioResponse struct {
Choices []struct {
Message struct {
// A JSON null content leaves this as "" (encoding/json does not set a
// non-pointer string from null).
Content string `json:"content"`
ToolCalls []lmStudioToolCall `json:"tool_calls,omitempty"`
} `json:"message"`
} `json:"choices"`
// Error is kept undecoded because its shape varies (a plain string or an
// object with a "message" field). Note: encoding/json stores a JSON null
// here as the literal bytes `null`, so a non-empty Error is not by itself
// proof that the server reported an error.
Error json.RawMessage `json:"error,omitempty"`
}
// Complete sends request to LM Studio's chat completions endpoint
// (POST {baseURL}/v1/chat/completions) with function-calling support and
// returns the content and tool calls of the first choice.
//
// Domain messages and their tool calls are copied field by field into the
// wire format; request.Tools is forwarded unchanged and the sampling
// temperature is fixed at 0.7. The HTTP round trip is bound to ctx. When the
// provider is verbose, the outgoing body and the response are logged
// (pretty-printed when they are valid JSON, verbatim otherwise).
//
// It returns an error when the request cannot be encoded, built, sent or
// read; when the server answers with a status other than 200; when the body
// is not valid JSON; when the body carries a non-null "error" field; or when
// it contains no choices. On success the returned ToolCalls slice is
// non-nil (possibly empty).
func (p *LMStudioProvider) Complete(ctx context.Context, request domain.LLMRequest) (*domain.LLMResponse, error) {
	// prettyJSON formats b for verbose logging, falling back to the raw text
	// when b is not JSON (for example an HTML error page).
	prettyJSON := func(b []byte) string {
		var buf bytes.Buffer
		if err := json.Indent(&buf, b, "", " "); err != nil {
			return string(b)
		}
		return buf.String()
	}

	messages := make([]lmStudioMessage, len(request.Messages))
	for i, msg := range request.Messages {
		// Left nil when there are no tool calls so omitempty drops the field.
		var toolCalls []lmStudioToolCall
		if len(msg.ToolCalls) > 0 {
			toolCalls = make([]lmStudioToolCall, len(msg.ToolCalls))
			for j, tc := range msg.ToolCalls {
				toolCalls[j].ID = tc.ID
				toolCalls[j].Type = tc.Type
				toolCalls[j].Function.Name = tc.Function.Name
				toolCalls[j].Function.Arguments = tc.Function.Arguments
			}
		}
		messages[i] = lmStudioMessage{
			Role:       msg.Role,
			Content:    msg.Content,
			ToolCallID: msg.ToolCallID,
			ToolCalls:  toolCalls,
		}
	}

	jsonData, err := json.Marshal(lmStudioRequest{
		Model:       p.model,
		Messages:    messages,
		Tools:       request.Tools,
		Temperature: 0.7,
	})
	if err != nil {
		return nil, fmt.Errorf("marshaling request: %w", err)
	}

	url := p.baseURL + "/v1/chat/completions"
	if p.verbose {
		log.Printf("\n========== LM STUDIO REQUEST ==========\nURL: %s\nBody:\n%s\n=======================================\n", url, prettyJSON(jsonData))
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonData))
	if err != nil {
		return nil, fmt.Errorf("creating request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := p.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("sending request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("reading response: %w", err)
	}
	if p.verbose {
		log.Printf("\n========== LM STUDIO RESPONSE ==========\nStatus: %d\nBody:\n%s\n========================================\n", resp.StatusCode, prettyJSON(body))
	}

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))
	}

	var apiResp lmStudioResponse
	if err := json.Unmarshal(body, &apiResp); err != nil {
		return nil, fmt.Errorf("unmarshaling response: %w\nResponse body: %s", err, string(body))
	}

	// encoding/json hands a JSON null to json.RawMessage verbatim, so
	// `"error": null` arrives as the bytes "null"; only a non-null value
	// means the server reported a failure.
	if rawErr := bytes.TrimSpace(apiResp.Error); len(rawErr) > 0 && !bytes.Equal(rawErr, []byte("null")) {
		// The error may be a plain string...
		var errStr string
		if err := json.Unmarshal(rawErr, &errStr); err == nil && errStr != "" {
			return nil, fmt.Errorf("API error: %s", errStr)
		}
		// ...or an object with a "message" field.
		var errObj struct {
			Message string `json:"message"`
		}
		if err := json.Unmarshal(rawErr, &errObj); err == nil && errObj.Message != "" {
			return nil, fmt.Errorf("API error: %s", errObj.Message)
		}
		// Unknown or empty shape: surface the raw JSON rather than "".
		return nil, fmt.Errorf("API error: %s", string(rawErr))
	}

	if len(apiResp.Choices) == 0 {
		return nil, fmt.Errorf("no choices in response")
	}

	reply := apiResp.Choices[0].Message
	toolCalls := make([]domain.ToolCall, len(reply.ToolCalls))
	for i, tc := range reply.ToolCalls {
		toolCalls[i].ID = tc.ID
		toolCalls[i].Type = tc.Type
		toolCalls[i].Function.Name = tc.Function.Name
		toolCalls[i].Function.Arguments = tc.Function.Arguments
	}

	return &domain.LLMResponse{
		Content:   reply.Content,
		ToolCalls: toolCalls,
	}, nil
}