package llm import ( "bytes" "context" "encoding/json" "fmt" "io" "log" "net/http" "github.com/paramah/ai_devs4/s01e03/internal/domain" ) // LMStudioProvider implements domain.LLMProvider for local LM Studio type LMStudioProvider struct { baseURL string model string client *http.Client verbose bool } // NewLMStudioProvider creates a new LM Studio provider func NewLMStudioProvider(baseURL, model string, verbose bool) *LMStudioProvider { return &LMStudioProvider{ baseURL: baseURL, model: model, client: &http.Client{}, verbose: verbose, } } type lmStudioMessage struct { Role string `json:"role"` Content string `json:"content"` ToolCallID string `json:"tool_call_id,omitempty"` ToolCalls []lmStudioToolCall `json:"tool_calls,omitempty"` } type lmStudioToolCall struct { ID string `json:"id"` Type string `json:"type"` Function struct { Name string `json:"name"` Arguments string `json:"arguments"` } `json:"function"` } type lmStudioRequest struct { Model string `json:"model"` Messages []lmStudioMessage `json:"messages"` Tools []domain.Tool `json:"tools,omitempty"` Temperature float64 `json:"temperature,omitempty"` } type lmStudioResponse struct { Choices []struct { Message struct { Content string `json:"content"` ToolCalls []lmStudioToolCall `json:"tool_calls,omitempty"` } `json:"message"` } `json:"choices"` Error json.RawMessage `json:"error,omitempty"` } // Complete sends a request to LM Studio with function calling support func (p *LMStudioProvider) Complete(ctx context.Context, request domain.LLMRequest) (*domain.LLMResponse, error) { // Convert domain messages to LM Studio format messages := make([]lmStudioMessage, len(request.Messages)) for i, msg := range request.Messages { // Convert tool calls if present var toolCalls []lmStudioToolCall if len(msg.ToolCalls) > 0 { toolCalls = make([]lmStudioToolCall, len(msg.ToolCalls)) for j, tc := range msg.ToolCalls { toolCalls[j] = lmStudioToolCall{ ID: tc.ID, Type: tc.Type, Function: struct { Name string `json:"name"` Arguments string `json:"arguments"` }{ Name: tc.Function.Name, Arguments: tc.Function.Arguments, }, } } } messages[i] = lmStudioMessage{ Role: msg.Role, Content: msg.Content, ToolCallID: msg.ToolCallID, ToolCalls: toolCalls, } } reqBody := lmStudioRequest{ Model: p.model, Messages: messages, Tools: request.Tools, Temperature: 0.7, } jsonData, err := json.Marshal(reqBody) if err != nil { return nil, fmt.Errorf("marshaling request: %w", err) } if p.verbose { var prettyJSON bytes.Buffer if err := json.Indent(&prettyJSON, jsonData, "", " "); err == nil { log.Printf("\n========== LM STUDIO REQUEST ==========\nURL: %s/v1/chat/completions\nBody:\n%s\n=======================================\n", p.baseURL, prettyJSON.String()) } } url := p.baseURL + "/v1/chat/completions" req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonData)) if err != nil { return nil, fmt.Errorf("creating request: %w", err) } req.Header.Set("Content-Type", "application/json") resp, err := p.client.Do(req) if err != nil { return nil, fmt.Errorf("sending request: %w", err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("reading response: %w", err) } if p.verbose { var prettyJSON bytes.Buffer if err := json.Indent(&prettyJSON, body, "", " "); err == nil { log.Printf("\n========== LM STUDIO RESPONSE ==========\nStatus: %d\nBody:\n%s\n========================================\n", resp.StatusCode, prettyJSON.String()) } } // Check HTTP status if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body)) } var apiResp lmStudioResponse if err := json.Unmarshal(body, &apiResp); err != nil { return nil, fmt.Errorf("unmarshaling response: %w\nResponse body: %s", err, string(body)) } // Check for error in response if len(apiResp.Error) > 0 { var errStr string if err := json.Unmarshal(apiResp.Error, &errStr); err == nil { return nil, fmt.Errorf("API error: %s", errStr) } var errObj struct { Message string `json:"message"` } if err := json.Unmarshal(apiResp.Error, &errObj); err == nil { return nil, fmt.Errorf("API error: %s", errObj.Message) } return nil, fmt.Errorf("API error: %s", string(apiResp.Error)) } if len(apiResp.Choices) == 0 { return nil, fmt.Errorf("no choices in response") } // Convert tool calls to domain format toolCalls := make([]domain.ToolCall, len(apiResp.Choices[0].Message.ToolCalls)) for i, tc := range apiResp.Choices[0].Message.ToolCalls { toolCalls[i] = domain.ToolCall{ ID: tc.ID, Type: tc.Type, Function: struct { Name string `json:"name"` Arguments string `json:"arguments"` }{ Name: tc.Function.Name, Arguments: tc.Function.Arguments, }, } } return &domain.LLMResponse{ Content: apiResp.Choices[0].Message.Content, ToolCalls: toolCalls, }, nil }