diff --git a/agent/tools.go b/agent/tools.go index 58c07ad..3173fc2 100644 --- a/agent/tools.go +++ b/agent/tools.go @@ -7,82 +7,222 @@ import ( "strings" ) -// Tool-Ergebnis das dem LLM zurückgegeben wird -type ToolResult struct { - Success bool - Output string +// ─── Tool Registry ──────────────────────────────────────── + +type Tool struct { + Name string + Description string + Usage string } -// Parst Tool-Calls aus der LLM-Antwort -// Erwartetes Format: -// TOOL:READ_FILE:path/to/file -// TOOL:WRITE_FILE:path/to/file:<<>> -// TOOL:LIST_FILES:. +var Registry = []Tool{ + { + Name: "READ_FILE", + Description: "Liest den Inhalt einer Datei", + Usage: "TOOL:READ_FILE:pfad/zur/datei", + }, + { + Name: "WRITE_FILE", + Description: "Schreibt Inhalt in eine Datei (mehrzeilig möglich)", + Usage: `TOOL:WRITE_FILE:pfad/zur/datei +<<< +dateiinhalt hier +>>>`, + }, + { + Name: "LIST_FILES", + Description: "Listet alle Dateien in einem Verzeichnis", + Usage: "TOOL:LIST_FILES:pfad", + }, +} -func ExecuteTools(response string, workDir string) (string, bool) { +// BuildToolPrompt generiert den Tool-Abschnitt für den System-Prompt +func BuildToolPrompt() string { + var sb strings.Builder + sb.WriteString("Du hast folgende Tools zur Verfügung:\n\n") + for _, t := range Registry { + sb.WriteString(fmt.Sprintf("### %s\n", t.Name)) + sb.WriteString(fmt.Sprintf("Beschreibung: %s\n", t.Description)) + sb.WriteString(fmt.Sprintf("Verwendung:\n%s\n\n", t.Usage)) + } + return sb.String() +} + +// ─── Tool Parsing ───────────────────────────────────────── + +type toolCall struct { + name string + path string + content string // nur für WRITE_FILE +} + +// parseToolCalls extrahiert alle Tool-Calls aus einer LLM-Antwort. +// Unterstützt mehrzeilige WRITE_FILE Blöcke: +// +// TOOL:WRITE_FILE:pfad +// <<< +// inhalt +// >>> +func parseToolCalls(response string) []toolCall { + var calls []toolCall lines := strings.Split(response, "\n") - var toolOutputs []string - hasToolCall := false - for _, line := range lines { - line = strings.TrimSpace(line) + i := 0 + for i < len(lines) { + line := strings.TrimSpace(lines[i]) + if !strings.HasPrefix(line, "TOOL:") { + i++ continue } - hasToolCall = true - result := executeTool(line, workDir) - toolOutputs = append(toolOutputs, result) + + parts := strings.SplitN(line, ":", 3) + if len(parts) < 3 { + i++ + continue + } + + toolName := parts[1] + toolPath := parts[2] + + // WRITE_FILE: Block-Inhalt lesen (<<<...>>>) + if toolName == "WRITE_FILE" { + content, newIndex := readContentBlock(lines, i+1) + calls = append(calls, toolCall{ + name: toolName, + path: toolPath, + content: content, + }) + i = newIndex + continue + } + + calls = append(calls, toolCall{ + name: toolName, + path: toolPath, + }) + i++ } - return strings.Join(toolOutputs, "\n"), hasToolCall + return calls } -func executeTool(toolCall string, workDir string) string { - parts := strings.SplitN(toolCall, ":", 4) - if len(parts) < 3 { - return "ERROR: Ungültiger Tool-Call" +// readContentBlock liest Zeilen zwischen <<< und >>> und gibt den +// bereinigten Inhalt sowie den neuen Zeilenindex zurück. +func readContentBlock(lines []string, startIndex int) (string, int) { + i := startIndex + + // Öffnendes <<< überspringen (optional, falls LLM es ausgibt) + if i < len(lines) && strings.TrimSpace(lines[i]) == "<<<" { + i++ } - toolName := parts[1] - arg1 := parts[2] + var contentLines []string + for i < len(lines) { + trimmed := strings.TrimSpace(lines[i]) + if trimmed == ">>>" { + i++ // >>> konsumieren + break + } + contentLines = append(contentLines, lines[i]) + i++ + } - switch toolName { + return strings.Join(contentLines, "\n"), i +} + +// ─── Tool Execution ─────────────────────────────────────── + +// ExecuteTools parst alle Tool-Calls aus der LLM-Antwort und führt sie aus. +// Gibt den kombinierten Output und true zurück wenn mindestens ein Tool aufgerufen wurde. +func ExecuteTools(response string, workDir string) (string, bool) { + calls := parseToolCalls(response) + if len(calls) == 0 { + return "", false + } + + var outputs []string + for _, call := range calls { + result := executeToolCall(call, workDir) + outputs = append(outputs, result) + } + + return strings.Join(outputs, "\n"), true +} + +func executeToolCall(call toolCall, workDir string) string { + // Sicherheits-Check: Path Traversal verhindern + safePath, err := sanitizePath(workDir, call.path) + if err != nil { + return fmt.Sprintf("ERROR: Ungültiger Pfad %q: %v", call.path, err) + } + + switch call.name { case "READ_FILE": - path := filepath.Join(workDir, arg1) - content, err := os.ReadFile(path) - if err != nil { - return fmt.Sprintf("READ_FILE ERROR: %v", err) - } - return fmt.Sprintf("READ_FILE %s:\n%s", arg1, string(content)) - + return readFile(safePath, call.path) case "WRITE_FILE": - if len(parts) < 4 { - return "ERROR: WRITE_FILE braucht Inhalt" - } - content := parts[3] - path := filepath.Join(workDir, arg1) - - // Verzeichnis anlegen falls nötig - if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { - return fmt.Sprintf("WRITE_FILE ERROR: %v", err) - } - if err := os.WriteFile(path, []byte(content), 0644); err != nil { - return fmt.Sprintf("WRITE_FILE ERROR: %v", err) - } - return fmt.Sprintf("WRITE_FILE OK: %s geschrieben", arg1) - + return writeFile(safePath, call.path, call.content) case "LIST_FILES": - path := filepath.Join(workDir, arg1) - entries, err := os.ReadDir(path) - if err != nil { - return fmt.Sprintf("LIST_FILES ERROR: %v", err) - } - var files []string - for _, e := range entries { + return listFiles(safePath, call.path) + default: + return fmt.Sprintf("ERROR: Unbekanntes Tool %q", call.name) + } +} + +// ─── Einzelne Tool-Implementierungen ───────────────────── + +func readFile(absPath, displayPath string) string { + content, err := os.ReadFile(absPath) + if err != nil { + return fmt.Sprintf("READ_FILE ERROR: %v", err) + } + return fmt.Sprintf("READ_FILE %s:\n%s", displayPath, string(content)) +} + +func writeFile(absPath, displayPath, content string) string { + if err := os.MkdirAll(filepath.Dir(absPath), 0755); err != nil { + return fmt.Sprintf("WRITE_FILE ERROR: Verzeichnis anlegen fehlgeschlagen: %v", err) + } + if err := os.WriteFile(absPath, []byte(content), 0644); err != nil { + return fmt.Sprintf("WRITE_FILE ERROR: %v", err) + } + return fmt.Sprintf("WRITE_FILE OK: %s geschrieben (%d Bytes)", displayPath, len(content)) +} + +func listFiles(absPath, displayPath string) string { + entries, err := os.ReadDir(absPath) + if err != nil { + return fmt.Sprintf("LIST_FILES ERROR: %v", err) + } + if len(entries) == 0 { + return fmt.Sprintf("LIST_FILES %s: (leer)", displayPath) + } + var files []string + for _, e := range entries { + if e.IsDir() { + files = append(files, e.Name()+"/") + } else { files = append(files, e.Name()) } - return fmt.Sprintf("LIST_FILES %s:\n%s", arg1, strings.Join(files, "\n")) + } + return fmt.Sprintf("LIST_FILES %s:\n%s", displayPath, strings.Join(files, "\n")) +} + +// ─── Sicherheit ─────────────────────────────────────────── + +// sanitizePath stellt sicher dass der Pfad innerhalb des workDir bleibt. +// Verhindert Directory Traversal wie ../../etc/passwd +func sanitizePath(workDir, relPath string) (string, error) { + // Absoluten Zielpfad berechnen + abs := filepath.Join(workDir, relPath) + abs = filepath.Clean(abs) + + // Muss mit workDir beginnen + workDirClean := filepath.Clean(workDir) + if !strings.HasPrefix(abs, workDirClean+string(filepath.Separator)) && + abs != workDirClean { + return "", fmt.Errorf("Pfad außerhalb des Arbeitsverzeichnisses") } - return fmt.Sprintf("ERROR: Unbekanntes Tool: %s", toolName) + return abs, nil }