package usecase import ( "context" "encoding/json" "fmt" "log" "os" "path/filepath" "github.com/paramah/ai_devs4/s01e02/internal/domain" ) // AgentProcessorUseCase handles the processing of persons using LLM agent type AgentProcessorUseCase struct { personRepo domain.PersonRepository locationRepo domain.LocationRepository apiClient domain.APIClient llmProvider domain.LLMProvider apiKey string outputDir string } // NewAgentProcessorUseCase creates a new use case instance func NewAgentProcessorUseCase( personRepo domain.PersonRepository, locationRepo domain.LocationRepository, apiClient domain.APIClient, llmProvider domain.LLMProvider, apiKey string, outputDir string, ) *AgentProcessorUseCase { return &AgentProcessorUseCase{ personRepo: personRepo, locationRepo: locationRepo, apiClient: apiClient, llmProvider: llmProvider, apiKey: apiKey, outputDir: outputDir, } } // Execute processes all persons using LLM agent func (uc *AgentProcessorUseCase) Execute(ctx context.Context, inputFile string) error { // Load persons from file log.Printf("Loading persons from: %s", inputFile) persons, err := uc.personRepo.LoadPersons(ctx, inputFile) if err != nil { return fmt.Errorf("loading persons: %w", err) } log.Printf("Loaded %d persons", len(persons)) // Load power plant locations log.Printf("Loading power plant locations...") locations, err := uc.locationRepo.LoadLocations(ctx, uc.apiKey) if err != nil { return fmt.Errorf("loading locations: %w", err) } log.Printf("Loaded %d power plant locations", len(locations)) // Process each person with agent for i, person := range persons { log.Printf("\n[%d/%d] Processing: %s %s", i+1, len(persons), person.Name, person.Surname) if err := uc.processPerson(ctx, person, locations); err != nil { log.Printf("Error processing %s %s: %v", person.Name, person.Surname, err) continue } } log.Printf("\nProcessing completed!") return nil } // processPerson uses LLM agent to gather data for a person func (uc *AgentProcessorUseCase) processPerson(ctx context.Context, person domain.Person, powerPlants []domain.Location) error { tools := domain.GetToolDefinitions() // Initial system message systemPrompt := fmt.Sprintf(`You are an agent that gathers information about people. For the person %s %s (born: %d), you need to: 1. Call get_location to get their current location coordinates 2. Call get_access_level to get their access level (remember: birth_year parameter must be only the year as integer, e.g., %d) After gathering the data, respond with "DONE".`, person.Name, person.Surname, person.Born, person.Born) messages := []domain.LLMMessage{ { Role: "system", Content: systemPrompt, }, { Role: "user", Content: fmt.Sprintf("Please gather information for %s %s.", person.Name, person.Surname), }, } maxIterations := 10 var personLocation *domain.PersonLocation for iteration := 0; iteration < maxIterations; iteration++ { log.Printf(" [Iteration %d] Calling LLM...", iteration+1) resp, err := uc.llmProvider.Chat(ctx, domain.LLMRequest{ Messages: messages, Tools: tools, ToolChoice: "auto", Temperature: 0.0, }) if err != nil { return fmt.Errorf("LLM chat error: %w", err) } messages = append(messages, resp.Message) // Check if LLM wants to call functions if len(resp.Message.ToolCalls) > 0 { log.Printf(" → LLM requested %d tool call(s)", len(resp.Message.ToolCalls)) for _, toolCall := range resp.Message.ToolCalls { result, loc, err := uc.executeToolCall(ctx, person, toolCall) if err != nil { return fmt.Errorf("executing tool call: %w", err) } // Store person location if we got it from get_location if loc != nil { personLocation = loc } messages = append(messages, domain.LLMMessage{ Role: "tool", Content: result, ToolCallID: toolCall.ID, }) } } else if resp.FinishReason == "stop" { log.Printf(" ✓ Agent completed gathering data") break } } // Calculate distances if we have location if personLocation != nil && len(powerPlants) > 0 { log.Printf(" → Calculating distances to power plants...") closest := domain.FindClosestLocation(*personLocation, powerPlants) if closest != nil { log.Printf(" ✓ Closest power plant: %s (%.2f km)", closest.Location, closest.DistanceKm) // Save distance result distanceFile := filepath.Join(uc.outputDir, "distances", fmt.Sprintf("%s_%s.json", person.Name, person.Surname)) distanceData, _ := json.MarshalIndent(closest, "", " ") os.WriteFile(distanceFile, distanceData, 0644) } } return nil } // executeToolCall executes a tool call from the LLM func (uc *AgentProcessorUseCase) executeToolCall(ctx context.Context, person domain.Person, toolCall domain.ToolCall) (string, *domain.PersonLocation, error) { log.Printf(" → Executing: %s", toolCall.Function.Name) var args map[string]interface{} if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil { return "", nil, fmt.Errorf("parsing arguments: %w", err) } switch toolCall.Function.Name { case "get_location": name, _ := args["name"].(string) surname, _ := args["surname"].(string) req := domain.LocationRequest{ APIKey: uc.apiKey, Name: name, Surname: surname, } response, err := uc.apiClient.GetLocation(ctx, req) if err != nil { return fmt.Sprintf("Error: %v", err), nil, nil } // Save response fileName := fmt.Sprintf("%s_%s.json", name, surname) filePath := filepath.Join(uc.outputDir, "locations", fileName) os.WriteFile(filePath, response, 0644) log.Printf(" ✓ Saved to: %s", filePath) // Parse location to get coordinates (API returns array of locations) var locationData []map[string]interface{} if err := json.Unmarshal(response, &locationData); err == nil && len(locationData) > 0 { // Take first location from the array firstLoc := locationData[0] if lat, ok := firstLoc["latitude"].(float64); ok { if lon, ok := firstLoc["longitude"].(float64); ok { personLoc := &domain.PersonLocation{ Name: name, Surname: surname, Latitude: lat, Longitude: lon, } return string(response), personLoc, nil } } } return string(response), nil, nil case "get_access_level": name, _ := args["name"].(string) surname, _ := args["surname"].(string) birthYear, _ := args["birth_year"].(float64) req := domain.AccessLevelRequest{ APIKey: uc.apiKey, Name: name, Surname: surname, BirthYear: int(birthYear), } response, err := uc.apiClient.GetAccessLevel(ctx, req) if err != nil { return fmt.Sprintf("Error: %v", err), nil, nil } // Save response fileName := fmt.Sprintf("%s_%s.json", name, surname) filePath := filepath.Join(uc.outputDir, "accesslevel", fileName) os.WriteFile(filePath, response, 0644) log.Printf(" ✓ Saved to: %s", filePath) return string(response), nil, nil default: return fmt.Sprintf("Unknown function: %s", toolCall.Function.Name), nil, nil } }