Agent
Agent
import (
"context"
"errors"
"fmt"
"log/slog"
"slices"
"strings"
"time"
"github.com/charmbracelet/catwalk/pkg/catwalk"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/csync"
"github.com/charmbracelet/crush/internal/history"
"github.com/charmbracelet/crush/internal/llm/prompt"
"github.com/charmbracelet/crush/internal/llm/provider"
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/log"
"github.com/charmbracelet/crush/internal/lsp"
"github.com/charmbracelet/crush/internal/message"
"github.com/charmbracelet/crush/internal/permission"
"github.com/charmbracelet/crush/internal/pubsub"
"github.com/charmbracelet/crush/internal/session"
"github.com/charmbracelet/crush/internal/shell"
)
// Common errors returned by the agent.
var (
	// ErrRequestCancelled reports that an in-flight request was cancelled,
	// e.g. because its context was cancelled by the user.
	ErrRequestCancelled = errors.New("request canceled by user")
	// ErrSessionBusy reports that a session is already processing another
	// request and cannot accept a new one yet.
	ErrSessionBusy = errors.New("session is currently processing another request")
)
// Values used for the Type field of an AgentEvent.
const (
	// AgentEventTypeResponse tags an event that carries an assistant reply.
	AgentEventTypeResponse  AgentEventType = "response"
	// AgentEventTypeError tags an event that carries a failure.
	AgentEventTypeError     AgentEventType = "error"
	// AgentEventTypeSummarize tags a progress update from summarization.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
// When summarizing
SessionID string
Progress string
Done bool
}
tools *csync.LazySlice[tools.BaseTool]
provider provider.Provider
providerID string
titleProvider provider.Provider
summarizeProvider provider.Provider
summarizeProviderID string
// NewAgent constructs the agent Service described by agentCfg.
//
// It looks up the provider and model configured for agentCfg.Model and
// returns an error if either is missing. It builds the main provider with
// the system prompt mapped to agentCfg.ID in agentPromptMap, falling back to
// prompt.PromptDefault when there is no mapping. It then builds two more
// providers on the small model: one with the title prompt and one with the
// summarizer prompt. permissions, history and lspClients are handed to the
// built-in tools. sessions and messages are stored on the returned agent.
//
// NOTE(review): this copy of the function looks truncated. toolFn and
// agentTool are referenced but not defined here. The `return allTools` near
// the end also does not match the (Service, error) result. It reads like the
// body of a tool-building closure whose header was lost. Confirm against the
// full file before relying on these comments.
func NewAgent(
ctx context.Context,
agentCfg config.Agent,
// These services are needed in the tools
permissions permission.Service,
sessions session.Service,
messages message.Service,
history history.Service,
lspClients map[string]*lsp.Client,
) (Service, error) {
// NOTE(review): cfg is captured here, yet config.Get() is called again
// below. Reusing cfg would be more consistent.
cfg := config.Get()
providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
if providerCfg == nil {
return nil, fmt.Errorf("provider for agent %s not found in config",
agentCfg.Name)
}
model := config.Get().GetModelByType(agentCfg.Model)
if model == nil {
return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
}
// Pick the system prompt for this agent ID, defaulting when unmapped.
promptID := agentPromptMap[agentCfg.ID]
if promptID == "" {
promptID = prompt.PromptDefault
}
opts := []provider.ProviderClientOption{
provider.WithModel(agentCfg.Model),
provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID,
config.Get().Options.ContextPaths...)),
}
agentProvider, err := provider.NewProvider(*providerCfg, opts...)
if err != nil {
return nil, err
}
// Resolve the provider for the small model. Reuse the main provider config
// when the small model names the same provider ID.
smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
var smallModelProviderCfg *config.ProviderConfig
if smallModelCfg.Provider == providerCfg.ID {
smallModelProviderCfg = providerCfg
} else {
smallModelProviderCfg =
cfg.GetProviderForModel(config.SelectedModelTypeSmall)
// NOTE(review): GetProviderForModel can return nil; the main-provider
// lookup above checks for that. In that case this .ID access would panic
// instead of returning the error below.
if smallModelProviderCfg.ID == "" {
return nil, fmt.Errorf("provider %s not found in config",
smallModelCfg.Provider)
}
}
// NOTE(review): GetModelByType is nil-checked for the main model above but
// not here. A nil result would panic on .ID.
smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
if smallModel.ID == "" {
return nil, fmt.Errorf("model %s not found in provider %s",
smallModelCfg.Model, smallModelProviderCfg.ID)
}
// The title and summarize providers both run on the small model, each with
// its own system prompt.
titleOpts := []provider.ProviderClientOption{
provider.WithModel(config.SelectedModelTypeSmall),
provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle,
smallModelProviderCfg.ID)),
}
titleProvider, err := provider.NewProvider(*smallModelProviderCfg,
titleOpts...)
if err != nil {
return nil, err
}
summarizeOpts := []provider.ProviderClientOption{
provider.WithModel(config.SelectedModelTypeSmall),
provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer,
smallModelProviderCfg.ID)),
}
summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg,
summarizeOpts...)
if err != nil {
return nil, err
}
// Built-in tools. Most of them are given the configured working directory.
cwd := cfg.WorkingDir()
allTools := []tools.BaseTool{
tools.NewBashTool(permissions, cwd),
tools.NewDownloadTool(permissions, cwd),
tools.NewEditTool(lspClients, permissions, history, cwd),
tools.NewMultiEditTool(lspClients, permissions, history, cwd),
tools.NewFetchTool(permissions, cwd),
tools.NewGlobTool(cwd),
tools.NewGrepTool(cwd),
tools.NewLsTool(permissions, cwd),
tools.NewSourcegraphTool(),
tools.NewViewTool(lspClients, permissions, cwd),
tools.NewWriteTool(lspClients, permissions, history, cwd),
}
// Fetch MCP tools through mcpToolsOnce and append the cached result.
// mcpToolsOnce is presumably a sync.Once and mcpTools package-level state.
// Verify where they are declared.
mcpToolsOnce.Do(func() {
mcpTools = doGetMCPTools(ctx, permissions, cfg)
})
allTools = append(allTools, mcpTools...)
// Only add the diagnostics tool when LSP clients are available.
if len(lspClients) > 0 {
allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
}
if agentTool != nil {
allTools = append(allTools, agentTool)
}
// A nil AllowedTools means the full tool list is returned unfiltered.
// NOTE(review): this return does not type-check against NewAgent's
// (Service, error) result. Any filtering by AllowedTools that follows it
// is not visible here.
if agentCfg.AllowedTools == nil {
return allTools
}
// The tool list is supplied through toolFn via csync.NewLazySlice.
// toolFn is not defined in this view.
return &agent{
Broker: pubsub.NewBroker[AgentEvent](),
agentCfg: agentCfg,
provider: agentProvider,
providerID: string(providerCfg.ID),
messages: messages,
sessions: sessions,
titleProvider: titleProvider,
summarizeProvider: summarizeProvider,
summarizeProviderID: string(smallModelProviderCfg.ID),
activeRequests: csync.NewMap[string, context.CancelFunc](),
tools: csync.NewLazySlice(toolFn),
}, nil
}
if finalResponse == nil {
return fmt.Errorf("no response received from title provider")
}
session.Title = title
_, err = a.sessions.Save(ctx, session)
return err
}
a.activeRequests.Set(sessionID, cancel)
go func() {
slog.Debug("Request started", "sessionID", sessionID)
defer log.RecoverPanic("agent.Run", func() {
events <- a.err(fmt.Errorf("panic while running the agent"))
})
var attachmentParts []message.ContentPart
for _, attachment := range attachments {
attachmentParts = append(attachmentParts,
message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType,
Data: attachment.Content})
}
result := a.processGeneration(genCtx, sessionID, content,
attachmentParts)
if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled)
&& !errors.Is(result.Error, context.Canceled) {
slog.Error(result.Error.Error())
}
slog.Debug("Request completed", "sessionID", sessionID)
a.activeRequests.Del(sessionID)
cancel()
a.Publish(pubsub.CreatedEvent, result)
events <- result
close(events)
}()
return events, nil
}
for {
// Check for cancellation before each iteration
select {
case <-ctx.Done():
return a.err(ctx.Err())
default:
// Continue processing
}
agentMessage, toolResults, err := a.streamAndHandleEvents(ctx,
sessionID, msgHistory)
if err != nil {
if errors.Is(err, context.Canceled) {
agentMessage.AddFinish(message.FinishReasonCanceled,
"Request cancelled", "")
a.messages.Update(context.Background(), agentMessage)
return a.err(ErrRequestCancelled)
}
return a.err(fmt.Errorf("failed to process events: %w", err))
}
if cfg.Options.Debug {
slog.Info("Result", "message", agentMessage.FinishReason(),
"toolResults", toolResults)
}
if (agentMessage.FinishReason() == message.FinishReasonToolUse) &&
toolResults != nil {
// We are not done, we need to respond with the tool response
msgHistory = append(msgHistory, agentMessage, *toolResults)
continue
}
if agentMessage.FinishReason() == "" {
// Kujtim: could not track down where this is happening but this
means its cancelled
agentMessage.AddFinish(message.FinishReasonCanceled, "Request
cancelled", "")
_ = a.messages.Update(context.Background(), agentMessage)
return a.err(ErrRequestCancelled)
}
return AgentEvent{
Type: AgentEventTypeResponse,
Message: agentMessage,
Done: true,
}
}
}
// Add the session and message ID into the context if needed by tools.
ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
go func() {
response, err := tool.Run(ctx, tools.ToolCall{
ID: toolCall.ID,
Name: toolCall.Name,
Input: toolCall.Input,
})
resultChan <- toolExecResult{response: response, err: err}
}()
select {
case <-ctx.Done():
a.finishMessage(context.Background(), &assistantMsg,
message.FinishReasonCanceled, "Request cancelled", "")
// Mark remaining tool calls as cancelled
for j := i; j < len(toolCalls); j++ {
toolResults[j] = message.ToolResult{
ToolCallID: toolCalls[j].ID,
Content: "Tool execution canceled by user",
IsError: true,
}
}
goto out
case result := <-resultChan:
toolResponse = result.response
toolErr = result.err
}
if toolErr != nil {
slog.Error("Tool execution error", "toolCall", toolCall.ID,
"error", toolErr)
if errors.Is(toolErr, permission.ErrorPermissionDenied) {
toolResults[i] = message.ToolResult{
ToolCallID: toolCall.ID,
Content: "Permission denied",
IsError: true,
}
for j := i + 1; j < len(toolCalls); j++ {
toolResults[j] = message.ToolResult{
ToolCallID: toolCalls[j].ID,
Content: "Tool execution canceled by
user",
IsError: true,
}
}
a.finishMessage(ctx, &assistantMsg,
message.FinishReasonPermissionDenied, "Permission denied", "")
break
}
}
toolResults[i] = message.ToolResult{
ToolCallID: toolCall.ID,
Content: toolResponse.Content,
Metadata: toolResponse.Metadata,
IsError: toolResponse.IsError,
}
}
}
out:
if len(toolResults) == 0 {
return assistantMsg, nil, nil
}
parts := make([]message.ContentPart, 0)
for _, tr := range toolResults {
parts = append(parts, tr)
}
msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID,
message.CreateMessageParams{
Role: message.Tool,
Parts: parts,
Provider: a.providerID,
})
if err != nil {
return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool
message: %w", err)
}
switch event.Type {
case provider.EventThinkingDelta:
assistantMsg.AppendReasoningContent(event.Thinking)
return a.messages.Update(ctx, *assistantMsg)
case provider.EventSignatureDelta:
assistantMsg.AppendReasoningSignature(event.Signature)
return a.messages.Update(ctx, *assistantMsg)
case provider.EventContentDelta:
assistantMsg.FinishThinking()
assistantMsg.AppendContent(event.Content)
return a.messages.Update(ctx, *assistantMsg)
case provider.EventToolUseStart:
assistantMsg.FinishThinking()
slog.Info("Tool call started", "toolCall", event.ToolCall)
assistantMsg.AddToolCall(*event.ToolCall)
return a.messages.Update(ctx, *assistantMsg)
case provider.EventToolUseDelta:
assistantMsg.AppendToolCallInput(event.ToolCall.ID,
event.ToolCall.Input)
return a.messages.Update(ctx, *assistantMsg)
case provider.EventToolUseStop:
slog.Info("Finished tool call", "toolCall", event.ToolCall)
assistantMsg.FinishToolCall(event.ToolCall.ID)
return a.messages.Update(ctx, *assistantMsg)
case provider.EventError:
return event.Error
case provider.EventComplete:
assistantMsg.FinishThinking()
assistantMsg.SetToolCalls(event.Response.ToolCalls)
assistantMsg.AddFinish(event.Response.FinishReason, "", "")
if err := a.messages.Update(ctx, *assistantMsg); err != nil {
return fmt.Errorf("failed to update message: %w", err)
}
return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
}
return nil
}
cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
model.CostPer1MIn/1e6*float64(usage.InputTokens) +
model.CostPer1MOut/1e6*float64(usage.OutputTokens)
sess.Cost += cost
sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
go func() {
defer a.activeRequests.Del(sessionID + "-summarize")
defer cancel()
event := AgentEvent{
Type: AgentEventTypeSummarize,
Progress: "Starting summarization...",
}
a.Publish(pubsub.CreatedEvent, event)
// Get all messages from the session
msgs, err := a.messages.List(summarizeCtx, sessionID)
if err != nil {
event = AgentEvent{
Type: AgentEventTypeError,
Error: fmt.Errorf("failed to list messages: %w", err),
Done: true,
}
a.Publish(pubsub.CreatedEvent, event)
return
}
summarizeCtx = context.WithValue(summarizeCtx,
tools.SessionIDContextKey, sessionID)
if len(msgs) == 0 {
event = AgentEvent{
Type: AgentEventTypeError,
Error: fmt.Errorf("no messages to summarize"),
Done: true,
}
a.Publish(pubsub.CreatedEvent, event)
return
}
event = AgentEvent{
Type: AgentEventTypeSummarize,
Progress: "Analyzing conversation...",
}
a.Publish(pubsub.CreatedEvent, event)
event = AgentEvent{
Type: AgentEventTypeSummarize,
Progress: "Generating summary...",
}
a.Publish(pubsub.CreatedEvent, event)
summary := strings.TrimSpace(finalResponse.Content)
if summary == "" {
event = AgentEvent{
Type: AgentEventTypeError,
Error: fmt.Errorf("empty summary returned"),
Done: true,
}
a.Publish(pubsub.CreatedEvent, event)
return
}
shell := shell.GetPersistentShell(config.Get().WorkingDir())
summary += "\n\n**Current working directory of the persistent shell**\
n\n" + shell.GetWorkingDir()
event = AgentEvent{
Type: AgentEventTypeSummarize,
Progress: "Creating new session...",
}
a.Publish(pubsub.CreatedEvent, event)
oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
if err != nil {
event = AgentEvent{
Type: AgentEventTypeError,
Error: fmt.Errorf("failed to get session: %w", err),
Done: true,
}
a.Publish(pubsub.CreatedEvent, event)
return
}
// Create a message in the new session with the summary
msg, err := a.messages.Create(summarizeCtx, oldSession.ID,
message.CreateMessageParams{
Role: message.Assistant,
Parts: []message.ContentPart{
message.TextContent{Text: summary},
message.Finish{
Reason: message.FinishReasonEndTurn,
Time: time.Now().Unix(),
},
},
Model: a.summarizeProvider.Model().ID,
Provider: a.summarizeProviderID,
})
if err != nil {
event = AgentEvent{
Type: AgentEventTypeError,
Error: fmt.Errorf("failed to create summary message: %w",
err),
Done: true,
}
a.Publish(pubsub.CreatedEvent, event)
return
}
oldSession.SummaryMessageID = msg.ID
oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
oldSession.PromptTokens = 0
model := a.summarizeProvider.Model()
usage := finalResponse.Usage
cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens)
+
model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
model.CostPer1MIn/1e6*float64(usage.InputTokens) +
model.CostPer1MOut/1e6*float64(usage.OutputTokens)
oldSession.Cost += cost
_, err = a.sessions.Save(summarizeCtx, oldSession)
if err != nil {
event = AgentEvent{
Type: AgentEventTypeError,
Error: fmt.Errorf("failed to save session: %w", err),
Done: true,
}
a.Publish(pubsub.CreatedEvent, event)
}
event = AgentEvent{
Type: AgentEventTypeSummarize,
SessionID: oldSession.ID,
Progress: "Summary complete",
Done: true,
}
a.Publish(pubsub.CreatedEvent, event)
// Send final success event with the new session ID
}()
return nil
}
promptID := agentPromptMap[a.agentCfg.ID]
if promptID == "" {
promptID = prompt.PromptDefault
}
opts := []provider.ProviderClientOption{
provider.WithModel(a.agentCfg.Model),
provider.WithSystemMessage(prompt.GetPrompt(promptID,
currentProviderCfg.ID, cfg.Options.ContextPaths...)),
}
// Check if small model provider has changed (affects title and summarize
providers)
smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
var smallModelProviderCfg config.ProviderConfig
if smallModelProviderCfg.ID == "" {
return fmt.Errorf("provider %s not found in config",
smallModelCfg.Provider)
}
provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer,
smallModelProviderCfg.ID)),
}
newSummarizeProvider, err :=
provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
if err != nil {
return fmt.Errorf("failed to create new summarize provider: %w",
err)
}
return nil
}