initial commit
This commit is contained in:
192
internal/infrastructure/llm/lmstudio.go
Normal file
192
internal/infrastructure/llm/lmstudio.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/paramah/ai_devs4/s01e03/internal/domain"
|
||||
)
|
||||
|
||||
// LMStudioProvider implements domain.LLMProvider for local LM Studio
|
||||
type LMStudioProvider struct {
|
||||
baseURL string
|
||||
model string
|
||||
client *http.Client
|
||||
verbose bool
|
||||
}
|
||||
|
||||
// NewLMStudioProvider creates a new LM Studio provider
|
||||
func NewLMStudioProvider(baseURL, model string, verbose bool) *LMStudioProvider {
|
||||
return &LMStudioProvider{
|
||||
baseURL: baseURL,
|
||||
model: model,
|
||||
client: &http.Client{},
|
||||
verbose: verbose,
|
||||
}
|
||||
}
|
||||
|
||||
type lmStudioMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
ToolCalls []lmStudioToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
type lmStudioToolCall struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Function struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
} `json:"function"`
|
||||
}
|
||||
|
||||
type lmStudioRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []lmStudioMessage `json:"messages"`
|
||||
Tools []domain.Tool `json:"tools,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
}
|
||||
|
||||
type lmStudioResponse struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
ToolCalls []lmStudioToolCall `json:"tool_calls,omitempty"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
Error json.RawMessage `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// Complete sends a request to LM Studio with function calling support
|
||||
func (p *LMStudioProvider) Complete(ctx context.Context, request domain.LLMRequest) (*domain.LLMResponse, error) {
|
||||
// Convert domain messages to LM Studio format
|
||||
messages := make([]lmStudioMessage, len(request.Messages))
|
||||
for i, msg := range request.Messages {
|
||||
// Convert tool calls if present
|
||||
var toolCalls []lmStudioToolCall
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
toolCalls = make([]lmStudioToolCall, len(msg.ToolCalls))
|
||||
for j, tc := range msg.ToolCalls {
|
||||
toolCalls[j] = lmStudioToolCall{
|
||||
ID: tc.ID,
|
||||
Type: tc.Type,
|
||||
Function: struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
}{
|
||||
Name: tc.Function.Name,
|
||||
Arguments: tc.Function.Arguments,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
messages[i] = lmStudioMessage{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
ToolCallID: msg.ToolCallID,
|
||||
ToolCalls: toolCalls,
|
||||
}
|
||||
}
|
||||
|
||||
reqBody := lmStudioRequest{
|
||||
Model: p.model,
|
||||
Messages: messages,
|
||||
Tools: request.Tools,
|
||||
Temperature: 0.7,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshaling request: %w", err)
|
||||
}
|
||||
|
||||
if p.verbose {
|
||||
var prettyJSON bytes.Buffer
|
||||
if err := json.Indent(&prettyJSON, jsonData, "", " "); err == nil {
|
||||
log.Printf("\n========== LM STUDIO REQUEST ==========\nURL: %s/v1/chat/completions\nBody:\n%s\n=======================================\n", p.baseURL, prettyJSON.String())
|
||||
}
|
||||
}
|
||||
|
||||
url := p.baseURL + "/v1/chat/completions"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sending request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading response: %w", err)
|
||||
}
|
||||
|
||||
if p.verbose {
|
||||
var prettyJSON bytes.Buffer
|
||||
if err := json.Indent(&prettyJSON, body, "", " "); err == nil {
|
||||
log.Printf("\n========== LM STUDIO RESPONSE ==========\nStatus: %d\nBody:\n%s\n========================================\n", resp.StatusCode, prettyJSON.String())
|
||||
}
|
||||
}
|
||||
|
||||
// Check HTTP status
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var apiResp lmStudioResponse
|
||||
if err := json.Unmarshal(body, &apiResp); err != nil {
|
||||
return nil, fmt.Errorf("unmarshaling response: %w\nResponse body: %s", err, string(body))
|
||||
}
|
||||
|
||||
// Check for error in response
|
||||
if len(apiResp.Error) > 0 {
|
||||
var errStr string
|
||||
if err := json.Unmarshal(apiResp.Error, &errStr); err == nil {
|
||||
return nil, fmt.Errorf("API error: %s", errStr)
|
||||
}
|
||||
var errObj struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
if err := json.Unmarshal(apiResp.Error, &errObj); err == nil {
|
||||
return nil, fmt.Errorf("API error: %s", errObj.Message)
|
||||
}
|
||||
return nil, fmt.Errorf("API error: %s", string(apiResp.Error))
|
||||
}
|
||||
|
||||
if len(apiResp.Choices) == 0 {
|
||||
return nil, fmt.Errorf("no choices in response")
|
||||
}
|
||||
|
||||
// Convert tool calls to domain format
|
||||
toolCalls := make([]domain.ToolCall, len(apiResp.Choices[0].Message.ToolCalls))
|
||||
for i, tc := range apiResp.Choices[0].Message.ToolCalls {
|
||||
toolCalls[i] = domain.ToolCall{
|
||||
ID: tc.ID,
|
||||
Type: tc.Type,
|
||||
Function: struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
}{
|
||||
Name: tc.Function.Name,
|
||||
Arguments: tc.Function.Arguments,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return &domain.LLMResponse{
|
||||
Content: apiResp.Choices[0].Message.Content,
|
||||
ToolCalls: toolCalls,
|
||||
}, nil
|
||||
}
|
||||
178
internal/infrastructure/llm/openrouter.go
Normal file
178
internal/infrastructure/llm/openrouter.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/paramah/ai_devs4/s01e03/internal/domain"
|
||||
)
|
||||
|
||||
// OpenRouterProvider implements domain.LLMProvider for OpenRouter API
|
||||
type OpenRouterProvider struct {
|
||||
apiKey string
|
||||
model string
|
||||
baseURL string
|
||||
client *http.Client
|
||||
verbose bool
|
||||
}
|
||||
|
||||
// NewOpenRouterProvider creates a new OpenRouter provider
|
||||
func NewOpenRouterProvider(apiKey, model string, verbose bool) *OpenRouterProvider {
|
||||
return &OpenRouterProvider{
|
||||
apiKey: apiKey,
|
||||
model: model,
|
||||
baseURL: "https://openrouter.ai/api/v1/chat/completions",
|
||||
client: &http.Client{},
|
||||
verbose: verbose,
|
||||
}
|
||||
}
|
||||
|
||||
type openRouterMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
ToolCalls []openRouterToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
type openRouterToolCall struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Function struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
} `json:"function"`
|
||||
}
|
||||
|
||||
type openRouterRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []openRouterMessage `json:"messages"`
|
||||
Tools []domain.Tool `json:"tools,omitempty"`
|
||||
}
|
||||
|
||||
type openRouterResponse struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
ToolCalls []openRouterToolCall `json:"tool_calls,omitempty"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
Error *struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// Complete sends a request to OpenRouter API with function calling support
|
||||
func (p *OpenRouterProvider) Complete(ctx context.Context, request domain.LLMRequest) (*domain.LLMResponse, error) {
|
||||
// Convert domain messages to OpenRouter format
|
||||
messages := make([]openRouterMessage, len(request.Messages))
|
||||
for i, msg := range request.Messages {
|
||||
// Convert tool calls if present
|
||||
var toolCalls []openRouterToolCall
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
toolCalls = make([]openRouterToolCall, len(msg.ToolCalls))
|
||||
for j, tc := range msg.ToolCalls {
|
||||
toolCalls[j] = openRouterToolCall{
|
||||
ID: tc.ID,
|
||||
Type: tc.Type,
|
||||
Function: struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
}{
|
||||
Name: tc.Function.Name,
|
||||
Arguments: tc.Function.Arguments,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
messages[i] = openRouterMessage{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
ToolCallID: msg.ToolCallID,
|
||||
ToolCalls: toolCalls,
|
||||
}
|
||||
}
|
||||
|
||||
reqBody := openRouterRequest{
|
||||
Model: p.model,
|
||||
Messages: messages,
|
||||
Tools: request.Tools,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshaling request: %w", err)
|
||||
}
|
||||
|
||||
if p.verbose {
|
||||
var prettyJSON bytes.Buffer
|
||||
if err := json.Indent(&prettyJSON, jsonData, "", " "); err == nil {
|
||||
log.Printf("\n========== LLM REQUEST ==========\nURL: %s\nBody:\n%s\n================================\n", p.baseURL, prettyJSON.String())
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sending request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading response: %w", err)
|
||||
}
|
||||
|
||||
if p.verbose {
|
||||
var prettyJSON bytes.Buffer
|
||||
if err := json.Indent(&prettyJSON, body, "", " "); err == nil {
|
||||
log.Printf("\n========== LLM RESPONSE ==========\nStatus: %d\nBody:\n%s\n==================================\n", resp.StatusCode, prettyJSON.String())
|
||||
}
|
||||
}
|
||||
|
||||
var apiResp openRouterResponse
|
||||
if err := json.Unmarshal(body, &apiResp); err != nil {
|
||||
return nil, fmt.Errorf("unmarshaling response: %w", err)
|
||||
}
|
||||
|
||||
if apiResp.Error != nil {
|
||||
return nil, fmt.Errorf("API error: %s", apiResp.Error.Message)
|
||||
}
|
||||
|
||||
if len(apiResp.Choices) == 0 {
|
||||
return nil, fmt.Errorf("no choices in response")
|
||||
}
|
||||
|
||||
// Convert tool calls to domain format
|
||||
toolCalls := make([]domain.ToolCall, len(apiResp.Choices[0].Message.ToolCalls))
|
||||
for i, tc := range apiResp.Choices[0].Message.ToolCalls {
|
||||
toolCalls[i] = domain.ToolCall{
|
||||
ID: tc.ID,
|
||||
Type: tc.Type,
|
||||
Function: struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
}{
|
||||
Name: tc.Function.Name,
|
||||
Arguments: tc.Function.Arguments,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return &domain.LLMResponse{
|
||||
Content: apiResp.Choices[0].Message.Content,
|
||||
ToolCalls: toolCalls,
|
||||
}, nil
|
||||
}
|
||||
Reference in New Issue
Block a user